
A Short Tutorial on Matrix Derivative

The University of Texas at Austin

I noticed that matrix derivative techniques might be useful for the contents of this course. Fortunately, I know a good tutorial introducing the technique; unfortunately, it is written in Chinese, so I have translated and summarized its main ideas here in the hope that they are helpful.

The content below is borrowed heavily from the post by 长躯鬼侠 (an ECE PhD at CMU) on Zhihu (the Chinese Quora).

The relation between differential and matrix derivative

Consider the matrix function $f(X): \mathcal{X}\to\mathbb{R}$, which maps a matrix $X\in\mathbb{R}^{n\times d}$ to a real number. This is the common setting for the calculus of variations (e.g., backpropagation of neural networks, optimal control), as the loss/control objective should always be a real number to be optimized and evaluated.

Recall why we are good at calculating ordinary derivatives: we build them up from the composition of simple rules. In the matrix case, the chain rule for derivatives can be error-prone, and that is what makes the problem complicated. Fortunately, the composition of differentials still holds. Indeed, the matrix derivative and the differential are connected by the following relation

$$df = \textrm{tr}\left(\dfrac{\partial f}{\partial X}^T dX\right),$$

where $df\in\mathbb{R}$ has the same shape as $f$, $dX, \dfrac{\partial f}{\partial X}\in\mathbb{R}^{n\times d}$ have the same shape as $X$, and $\textrm{tr}$ denotes the trace operator. Note that since $df$ is a scalar, $\textrm{tr}(df)=df$, so we can add a trace on both sides.
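As a quick numerical sanity check of this relation, here is a minimal NumPy sketch (my own illustration, not from the original post), taking $f(X)=\frac{1}{2}\|X\|_F^2$, whose matrix derivative is $X$ itself:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 3))

# Illustrative choice: f(X) = ||X||_F^2 / 2, whose matrix derivative is X itself.
f = lambda M: 0.5 * np.sum(M ** 2)
grad = X

dX = 1e-6 * rng.standard_normal(X.shape)  # a small perturbation playing the role of dX

lhs = f(X + dX) - f(X)         # df, up to O(||dX||^2)
rhs = np.trace(grad.T @ dX)    # tr((∂f/∂X)^T dX)
print(lhs, rhs)                # the two values agree up to O(||dX||^2)
```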

In summary, to calculate the matrix derivative $\dfrac{\partial f}{\partial X}$, the plan is:

  • Take the differential of $f$ w.r.t. $X$, using the composition of differentials.
  • Add a trace on both sides, and arrange the terms into the key relation using the trace tricks.
  • Read off the derivative directly from the relation between the differential and the matrix derivative.

Composition of differentials

  1. Addition: $d(X\pm Y) = dX\pm dY$; matrix multiplication: $d(XY)=(dX)Y + X\,dY$; transpose: $d(X^T)=(dX)^T$; trace: $d\,\textrm{tr}(X)=\textrm{tr}(dX)$.
  2. Inverse: $dX^{-1}=-X^{-1}\,dX\,X^{-1}$.
  3. Determinant: $d|X|=\textrm{tr}(X^{*}\,dX)$, where $X^{*}$ denotes the adjugate of $X$; for invertible $X$, this gives $d|X|=|X|\,\textrm{tr}(X^{-1}\,dX)$.
  4. Element-wise multiplication: $d(X\odot Y)=dX\odot Y + X\odot dY$.
  5. Element-wise function: $d\sigma(X)=\sigma'(X)\odot dX$, where $\sigma$ and $\sigma'$ are applied element-wise.
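These rules can each be verified against finite differences. Below is a minimal NumPy check of rules 2 and 3 (my own illustration; the shift by $3I$ just keeps $X$ well-conditioned):

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.standard_normal((3, 3)) + 3.0 * np.eye(3)  # shifted to keep X invertible
dX = 1e-6 * rng.standard_normal((3, 3))
Xinv = np.linalg.inv(X)

# Rule 2 (inverse): d(X^{-1}) = -X^{-1} dX X^{-1}
lhs = np.linalg.inv(X + dX) - Xinv
rhs = -Xinv @ dX @ Xinv
print(np.abs(lhs - rhs).max())  # O(||dX||^2), i.e. tiny

# Rule 3 (determinant): d|X| = tr(X^* dX) = |X| tr(X^{-1} dX) for invertible X
lhs = np.linalg.det(X + dX) - np.linalg.det(X)
rhs = np.linalg.det(X) * np.trace(Xinv @ dX)
print(lhs, rhs)                 # agree up to O(||dX||^2)
```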

Trace tricks

  1. Scalar: $a=\textrm{tr}(a)$.
  2. Transpose: $\textrm{tr}(A^T)=\textrm{tr}(A)$.
  3. Linearity: $\textrm{tr}(A\pm B)=\textrm{tr}(A)\pm \textrm{tr}(B)$.
  4. Cyclic permutation: $\textrm{tr}(AB)=\textrm{tr}(BA)$, where $A$ and $B^T$ have the same shape.
  5. Exchange between matrix and element-wise multiplication: $\textrm{tr}(A^T(B\odot C))=\textrm{tr}((A\odot B)^TC)$, where $A$, $B$, and $C$ have the same shape.
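Tricks 4 and 5 do the real work in the example below; a minimal NumPy check (my own illustration) confirms both identities on random matrices:

```python
import numpy as np

rng = np.random.default_rng(2)

# Trick 4: tr(AB) = tr(BA) for A (4x3) and B (3x4)
A = rng.standard_normal((4, 3))
B = rng.standard_normal((3, 4))
print(np.trace(A @ B), np.trace(B @ A))                  # identical

# Trick 5: tr(A^T (B ⊙ C)) = tr((A ⊙ B)^T C) for same-shaped A, B, C
A, B, C = (rng.standard_normal((4, 3)) for _ in range(3))
print(np.trace(A.T @ (B * C)), np.trace((A * B).T @ C))  # identical
```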

Example: non-linear regression

Consider a system $Y=\sigma(WX)$, where $X\in\mathbb{R}^{d\times n}$ consists of $n$ data points, each a $d$-dimensional vector, $W\in\mathbb{R}^{m\times d}$, $Y\in\mathbb{R}^{m\times n}$, and $\sigma(\cdot)$ is a nonlinear function applied element-wise. Denote the label matrix by $\hat{Y}\in\mathbb{R}^{m\times n}$; then the MSE loss is

$$L = \frac{1}{2n}\textrm{tr}\left[(Y-\hat{Y})^T(Y-\hat{Y})\right].$$

Calculate $\dfrac{\partial L}{\partial W}$.

By the composition of differentials,

$$dL = \frac{1}{2n}\textrm{tr}\left[dY^T(Y-\hat{Y})+(Y-\hat{Y})^T dY\right]=\frac{1}{n}\textrm{tr}\left[(Y-\hat{Y})^T dY\right]$$
$$=\frac{1}{n}\textrm{tr}\left[(Y-\hat{Y})^T d\sigma(WX)\right] = \frac{1}{n}\textrm{tr}\left[(Y-\hat{Y})^T\left(\sigma'(WX)\odot (dW\, X)\right)\right],$$

where the first equality combines the two terms using trace tricks 2 and 3, and the last step uses differential rules 5 and 1 (note $dX=0$, since $X$ is data rather than the variable).

By trace trick 5,

$$dL = \frac{1}{n}\textrm{tr}\left[\left[(Y-\hat{Y})\odot\sigma'(WX)\right]^T dW\, X\right].$$

By trace trick 4,

$$dL = \frac{1}{n}\textrm{tr}\left[X\left[(Y-\hat{Y})\odot\sigma'(WX)\right]^T dW\right].$$

Therefore, from the relation between the differential and the matrix derivative, we can read off

$$\dfrac{\partial L}{\partial W} = \frac{1}{n}\left[(Y-\hat{Y})\odot\sigma'(WX)\right]X^T.$$
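To gain confidence in the result, the following sketch (my own check, taking $\sigma=\tanh$ as an illustrative nonlinearity) compares the closed-form gradient with a central finite-difference approximation:

```python
import numpy as np

rng = np.random.default_rng(3)
d, n, m = 3, 5, 2
X = rng.standard_normal((d, n))
W = rng.standard_normal((m, d))
Y_hat = rng.standard_normal((m, n))

sigma = np.tanh                           # illustrative element-wise nonlinearity
dsigma = lambda Z: 1.0 - np.tanh(Z) ** 2  # its element-wise derivative

def loss(W):
    R = sigma(W @ X) - Y_hat
    return np.trace(R.T @ R) / (2 * n)

# Closed-form gradient derived above: (1/n) [(Y - Y_hat) ⊙ σ'(WX)] X^T
grad = ((sigma(W @ X) - Y_hat) * dsigma(W @ X)) @ X.T / n

# Central finite differences, one entry of W at a time
eps = 1e-6
num = np.zeros_like(W)
for i in range(m):
    for j in range(d):
        E = np.zeros_like(W)
        E[i, j] = eps
        num[i, j] = (loss(W + E) - loss(W - E)) / (2 * eps)

print(np.abs(grad - num).max())           # tiny: the formula checks out
```

The check also confirms the $\frac{1}{n}$ factor inherited from the loss, which is easy to drop by accident when reading off the derivative.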