#P4498. 交叉注意力

交叉注意力

题目描述

给定 Query 序列

XQRTQ×dmodelX_Q \in \mathbb{R}^{T_Q \times d_{\text{model}}}

和 Key/Value 序列

XKRTK×dmodelX_K \in \mathbb{R}^{T_K \times d_{\text{model}}}

请你实现交叉注意力(Cross Attention)的完整计算流程。

1. 线性变换

Q=XQWQ+bQ,K=XKWK+bK,V=XKWV+bVQ = X_Q W_Q + b_Q,\quad K = X_K W_K + b_K,\quad V = X_K W_V + b_V

2. 计算注意力分数

S=QKdkS = \frac{Q K^\top}{\sqrt{d_k}}

3. 对每个 Query 行进行 softmax

A=softmax(S)A = \text{softmax}(S)

4. 得到加权输出

H=AVH = A V

5. 输出层线性变换

O=HWO+bOO = H W_O + b_O

输入参数

  • X_Q:形状为 TQ×dmodelT_Q \times d_{\text{model}} 的 Query 序列
  • X_K:形状为 TK×dmodelT_K \times d_{\text{model}} 的 Key/Value 序列
  • W_Q, W_K, W_V:形状为 dmodel×dmodeld_{\text{model}} \times d_{\text{model}} 的线性投影矩阵
  • b_Q, b_K, b_V:长度为 dmodeld_{\text{model}} 的偏置向量
  • W_O:形状为 dmodel×dmodeld_{\text{model}} \times d_{\text{model}} 的输出层权重
  • b_O:长度为 dmodeld_{\text{model}} 的输出偏置

返回值

  • O:形状为 TQ×dmodelT_Q \times d_{\text{model}} 的最终交叉注意力输出矩阵

示例

输入:

X_Q =
[[1, 0, 1, 0],
 [0, 1, 0, 1]]

X_K =
[[1, 1, 0, 0],
 [0, 1, 1, 0]]

W_Q =
[[1, 0, 1, 0],
 [0, 1, 0, 1],
 [1, 0, 0, 1],
 [0, 1, 1, 0]]

b_Q = [0, 0, 0, 0]

W_K =
[[1, 1, 0, 0],
 [0, 1, 1, 0],
 [1, 0, 0, 1],
 [0, 0, 1, 1]]

b_K = [0, 0, 0, 0]

W_V =
[[1, 0, 0, 1],
 [1, 1, 0, 0],
 [0, 1, 1, 0],
 [0, 0, 1, 1]]

b_V = [0, 0, 0, 0]

W_O =
[[1, 0, 1, 0],
 [0, 1, 0, 1],
 [1, 0, 0, 1],
 [0, 1, 1, 0]]

b_O = [0, 0, 0, 0]

输出:

O =
[[2.00, 2.00, 1.76, 2.24],
 [2.00, 2.00, 2.24, 1.76]]

提示

  • 输入序列范围: 1000XQ[i,j], XK[i,j]1000-1000 \le X_Q[i,j],\ X_K[i,j] \le 1000
  • 权重矩阵范围: 10WQ, WK, WV, WO10-10 \le W_Q,\ W_K,\ W_V,\ W_O \le 10
  • 偏置项范围: 5bQ, bK, bV, bO5-5 \le b_Q,\ b_K,\ b_V,\ b_O \le 5
  • softmax 输出满足: 0A[i,j]10 \le A[i,j] \le 1 jA[i,j]=1\sum_{j} A[i,j] = 1
  • 输出 OO 可为任意实数矩阵