I want to prove $\nabla_A J = \nabla_Z J \cdot B^T$, where $Z = AB$. Here $A$ is an $m \times n$ matrix and $B$ is an $n \times k$ matrix. The function $J$ is not given to me.
I began this proof by first writing $B^T$, taking the $3 \times 3$ case for concreteness, as
$$ B^T = \begin{pmatrix} b_{00} & b_{10} & b_{20} \\ b_{01} & b_{11} & b_{21} \\ b_{02} & b_{12} & b_{22} \\ \end{pmatrix} $$
Since $J$ is a loss function, it is scalar-valued. Its gradient with respect to $Z$, i.e. $\nabla_Z J$, is therefore a matrix of the same shape as $Z$:
$$ \nabla_{Z} J = \begin{pmatrix} \frac{\partial J}{\partial z_{00}} & \frac{\partial J}{\partial z_{01}} & \frac{\partial J}{\partial z_{02}} \\ \frac{\partial J}{\partial z_{10}} & \frac{\partial J}{\partial z_{11}} & \frac{\partial J}{\partial z_{12}} \\ \frac{\partial J}{\partial z_{20}} & \frac{\partial J}{\partial z_{21}} & \frac{\partial J}{\partial z_{22}} \\ \end{pmatrix} $$
Multiplying these two matrices as $\nabla_Z J \cdot B^T$, I get:
$$ \begin{pmatrix} \frac{\partial J}{\partial z_{00}}b_{00} + \frac{\partial J}{\partial z_{01}}b_{01} + \frac{\partial J}{\partial z_{02}}b_{02}& \frac{\partial J}{\partial z_{00}}b_{10} + \frac{\partial J}{\partial z_{01}}b_{11} + \frac{\partial J}{\partial z_{02}}b_{12} & \frac{\partial J}{\partial z_{00}}b_{20} + \frac{\partial J}{\partial z_{01}}b_{21} + \frac{\partial J}{\partial z_{02}}b_{22} \\ \frac{\partial J}{\partial z_{10}}b_{00} + \frac{\partial J}{\partial z_{11}}b_{01} + \frac{\partial J}{\partial z_{12}}b_{02}& \frac{\partial J}{\partial z_{10}}b_{10} + \frac{\partial J}{\partial z_{11}}b_{11} + \frac{\partial J}{\partial z_{12}}b_{12} & \frac{\partial J}{\partial z_{10}}b_{20} + \frac{\partial J}{\partial z_{11}}b_{21} + \frac{\partial J}{\partial z_{12}}b_{22} \\ \frac{\partial J}{\partial z_{20}}b_{00} + \frac{\partial J}{\partial z_{21}}b_{01} + \frac{\partial J}{\partial z_{22}}b_{02}& \frac{\partial J}{\partial z_{20}}b_{10} + \frac{\partial J}{\partial z_{21}}b_{11} + \frac{\partial J}{\partial z_{22}}b_{12} & \frac{\partial J}{\partial z_{20}}b_{20} + \frac{\partial J}{\partial z_{21}}b_{21} + \frac{\partial J}{\partial z_{22}}b_{22} \\ \end{pmatrix} $$
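The entries of this product follow a single pattern, which may be easier to work with in index form. This is only a restatement of the product, with the summation index $l$ introduced here as a label; each entry pairs row $i$ of $\nabla_Z J$ with row $j$ of $B$:
$$ \left( \nabla_Z J \cdot B^T \right)_{ij} = \sum_{l=0}^{k-1} \frac{\partial J}{\partial z_{il}} \, b_{jl} $$
So a proof would need to show that $\frac{\partial J}{\partial a_{ij}}$ equals this sum for every $i$ and $j$.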
At this point I am stuck and not sure how to proceed. What is the next step? Note that I am not well versed in matrix math, and it has been a while since I dealt with proofs. I am eager to learn, which is why I'm asking for the next step, assuming I have done it correctly so far.
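Before attempting the proof, the identity can be sanity-checked numerically. The sketch below is not a proof and rests on an assumption: since $J$ is not given, it uses a made-up loss $J(Z) = \tfrac{1}{2}\sum_{p,q} z_{pq}^2$, whose gradient with respect to $Z$ is simply $Z$. The function names and the sizes $m=4$, $n=3$, $k=2$ are likewise chosen only for illustration. It compares $\nabla_Z J \cdot B^T$ against a finite-difference estimate of $\nabla_A J$.

```python
import numpy as np

def J(Z):
    # Hypothetical scalar loss, chosen only for illustration (J is not given).
    return 0.5 * np.sum(Z ** 2)

def grad_Z(Z):
    # Analytic gradient of the hypothetical loss above with respect to Z.
    return Z

def numerical_grad_A(A, B, eps=1e-6):
    # Central finite differences of J(A @ B) with respect to each entry of A.
    G = np.zeros_like(A)
    for i in range(A.shape[0]):
        for j in range(A.shape[1]):
            E = np.zeros_like(A)
            E[i, j] = eps
            G[i, j] = (J((A + E) @ B) - J((A - E) @ B)) / (2 * eps)
    return G

rng = np.random.default_rng(0)
m, n, k = 4, 3, 2  # deliberately non-square to exercise the shapes
A = rng.standard_normal((m, n))
B = rng.standard_normal((n, k))

claimed = grad_Z(A @ B) @ B.T   # (m x k) @ (k x n) gives an m x n matrix
numeric = numerical_grad_A(A, B)

print(claimed.shape)
print(np.allclose(claimed, numeric, atol=1e-5))
```

Using non-square sizes also confirms that the shapes work out: $\nabla_Z J$ is $m \times k$, $B^T$ is $k \times n$, and their product is $m \times n$, the same shape as $A$.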