
lan8510's Question
Computer Science
Posted 10 months ago

(40 points) Backpropagation of matrix multiplication. The key operation in a transformer is computing the query-key product matrix.
Forward input: $Q \in \mathbb{R}^{n \times d}$, $K \in \mathbb{R}^{n \times d}$. Forward output: $Z = QK^T$, where $Z \in \mathbb{R}^{n \times n}$.

Backward input:

$$
\frac{\partial J}{\partial Z} =
\begin{bmatrix}
\dfrac{\partial J}{\partial Z_{11}} & \dfrac{\partial J}{\partial Z_{12}} & \cdots & \dfrac{\partial J}{\partial Z_{1n}} \\
\vdots & \vdots & \ddots & \vdots \\
\dfrac{\partial J}{\partial Z_{n1}} & \dfrac{\partial J}{\partial Z_{n2}} & \cdots & \dfrac{\partial J}{\partial Z_{nn}}
\end{bmatrix}
$$

Find the backward outputs $\dfrac{\partial J}{\partial Q_{ij}}$ and $\dfrac{\partial J}{\partial K_{ij}}$ for $i, j = 1, \dots, n$.
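For concreteness, here is a minimal NumPy sketch of the forward setup described above; the sizes $n = 4$ and $d = 3$ are arbitrary illustrative choices, not part of the problem statement.

```python
import numpy as np

# Illustrative sizes (assumed values, not given in the problem).
n, d = 4, 3

rng = np.random.default_rng(0)
Q = rng.standard_normal((n, d))   # queries, Q in R^{n x d}
K = rng.standard_normal((n, d))   # keys,    K in R^{n x d}

# Forward pass: Z = Q K^T, so Z in R^{n x n}
Z = Q @ K.T
print(Z.shape)  # (4, 4)
```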
Answer from Sia
Posted 10 months ago
To solve the backpropagation problem for matrix multiplication in the context of transformers, we need to compute the gradients ∂J/∂Q and ∂J/∂K based on the forward output Z = QK^T.
The backward outputs are obtained from the chain rule as follows:


1. The forward output is $Z = QK^T$, with entries $Z_{ij} = \sum_{k=1}^{d} Q_{ik} K_{jk}$.

2. The backward input is the upstream gradient $\frac{\partial J}{\partial Z} \in \mathbb{R}^{n \times n}$.

3. For the queries, the chain rule gives
   $$\frac{\partial J}{\partial Q_{ij}} = \sum_{k=1}^{n} \frac{\partial J}{\partial Z_{ik}}\, K_{kj}, \qquad \text{i.e.} \qquad \frac{\partial J}{\partial Q} = \frac{\partial J}{\partial Z}\, K \in \mathbb{R}^{n \times d}.$$

4. For the keys, the chain rule gives
   $$\frac{\partial J}{\partial K_{ij}} = \sum_{k=1}^{n} \frac{\partial J}{\partial Z_{ki}}\, Q_{kj}, \qquad \text{i.e.} \qquad \frac{\partial J}{\partial K} = \left(\frac{\partial J}{\partial Z}\right)^{T} Q \in \mathbb{R}^{n \times d}.$$

Therefore, the backward outputs are $\frac{\partial J}{\partial Q} = \frac{\partial J}{\partial Z}\, K$ and $\frac{\partial J}{\partial K} = \left(\frac{\partial J}{\partial Z}\right)^{T} Q$, with the elementwise forms above holding for $i = 1, \dots, n$ and $j = 1, \dots, d$. A quick numerical check of both formulas is sketched below.
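Not part of the original answer: a small NumPy sketch that checks the two formulas against a central-difference approximation. It assumes the scalar objective $J = \sum_{i,j} G_{ij} Z_{ij}$ so that $\partial J/\partial Z = G$ by construction; the sizes and random seed are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d = 4, 3                      # illustrative sizes (assumed)
Q = rng.standard_normal((n, d))
K = rng.standard_normal((n, d))
G = rng.standard_normal((n, n))  # stands in for the upstream gradient dJ/dZ

# Closed-form backward outputs from steps 3 and 4
dQ = G @ K      # dJ/dQ = (dJ/dZ) K,   shape (n, d)
dK = G.T @ Q    # dJ/dK = (dJ/dZ)^T Q, shape (n, d)

# Scalar objective J = sum(G * (Q K^T)), chosen so that dJ/dZ = G exactly.
def J(Q, K):
    return np.sum(G * (Q @ K.T))

# Central-difference approximation of dJ/dQ and dJ/dK, entry by entry.
eps = 1e-6
dQ_num = np.zeros_like(Q)
dK_num = np.zeros_like(K)
for i in range(n):
    for j in range(d):
        Qp = Q.copy(); Qp[i, j] += eps
        Qm = Q.copy(); Qm[i, j] -= eps
        dQ_num[i, j] = (J(Qp, K) - J(Qm, K)) / (2 * eps)

        Kp = K.copy(); Kp[i, j] += eps
        Km = K.copy(); Km[i, j] -= eps
        dK_num[i, j] = (J(Q, Kp) - J(Q, Km)) / (2 * eps)

print(np.allclose(dQ, dQ_num, atol=1e-5))  # True
print(np.allclose(dK, dK_num, atol=1e-5))  # True
```

The same check works for any differentiable $J$; the entry-by-entry double loop is only for verification and would not be used in an actual backward pass, which applies the two matrix products directly.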
