-
That's actually a really, really good question. I think you mean one can rewrite X*W_q*W_K^T*X^T as X*S*X^T with S = W_q*W_K^T, correct? I think this would work, but then the transformation becomes the same for keys and queries, and it would not be possible to distinguish them. So basically, while the separate version, X*W_q*W_K^T*X^T, and the merged version, X*S*X^T, come out as the same end result, the training dynamics would be different: in the first case the weight parameters W_q and W_K are updated separately, and in the second case you lose that distinction and lose degrees of freedom. But you are welcome to try this in Chapter 5, for example, and compare the training losses with and without the merging.
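A minimal sketch of that equivalence (toy random matrices with assumed sizes, not code from the book):

import torch

torch.manual_seed(123)
L, d_in, d_out = 6, 8, 4              # toy sequence length and dimensions (assumed)

X = torch.randn(L, d_in)              # token embeddings
W_q = torch.randn(d_in, d_out)        # query projection
W_k = torch.randn(d_in, d_out)        # key projection

# Separate parameterization: scores = (X W_q)(X W_k)^T
scores_separate = (X @ W_q) @ (X @ W_k).T

# Merged parameterization: S = W_q W_k^T, scores = X S X^T
S = W_q @ W_k.T
scores_merged = X @ S @ X.T

# Identical up to floating-point error, so the softmax weights match as well
print(torch.allclose(scores_separate, scores_merged, atol=1e-5))  # True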
-
Thanks for the reply. Yes, this is exactly what I mean. The training dynamics would be different, as there would be no keys K or queries Q any more. The thing is, do we really need them? Optimizing two matrices (W_q and W_K) seems redundant when only their product W_q*W_K^T enters the attention scores.
-
Here are the experiment results using the Chapter 5 example:
class CA_alt1(nn.Module):
    """Causal attention with the query and key projections merged into a single matrix S."""
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.S = nn.Linear(d_in, d_in, bias=qkv_bias)        # replaces W_query and W_key
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        values = self.W_value(x)
        attn_scores = self.S(x) @ x.transpose(1, 2)           # X*S*X^T instead of Q*K^T
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / values.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights @ values
        return context_vec
class MHA_alt1(nn.Module):
    """Multi-head wrapper around CA_alt1: one merged matrix S per head."""
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"
        head_dim = d_out // num_heads
        self.heads = nn.ModuleList(
            [CA_alt1(d_in, head_dim, context_length, dropout, qkv_bias) for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

and in TransformerBlock:

class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MHA_alt1(   # arguments completed with the book's standard cfg keys (assumed)
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
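As a quick sanity check of the modules above, a minimal sketch with a dummy batch (the sizes below are assumed toy values, not the Chapter 5 config):

import torch

torch.manual_seed(123)
batch_size, num_tokens, d_in = 2, 5, 12
d_out, num_heads = 12, 3              # head_dim = 4 per head

x = torch.randn(batch_size, num_tokens, d_in)
mha = MHA_alt1(d_in, d_out, context_length=num_tokens, dropout=0.0, num_heads=num_heads)

out = mha(x)
print(out.shape)  # torch.Size([2, 5, 12]) -- num_heads * head_dim concatenated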
-
This is such an interesting discussion!

Standard attention: the scores are X*W_q*W_K^T*X^T, where W_q and W_K are H x D projection matrices, so the score matrix has rank at most D.

Proposed covariance-style attention: the scores are X*S*X^T, where S is an unconstrained H x H matrix (H being the embedding dimension), so for a sequence of length L the score matrix can have rank up to min(L, H).

So, basically, while the standard attention mechanism explicitly constrains the rank through the projection dimension D, the covariance formulation allows for potentially richer attention patterns up to rank min(L, H). However, it's quite interesting: maybe the low-rank constraint in standard attention is beneficial as an inductive bias during training, even though it limits the theoretical expressiveness. I would be very interested to hear your thoughts!
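A small numerical illustration of that rank argument (random toy matrices with assumed sizes; ranks computed with torch.linalg.matrix_rank):

import torch

torch.manual_seed(123)
L, H, D = 32, 16, 4                   # sequence length, embedding dim, projection dim (assumed)

X = torch.randn(L, H)
W_q, W_k = torch.randn(H, D), torch.randn(H, D)
S = torch.randn(H, H)                 # unconstrained covariance-style matrix

scores_standard = X @ W_q @ W_k.T @ X.T    # rank limited by the projection dimension D
scores_covariance = X @ S @ X.T            # rank limited only by min(L, H)

print(torch.linalg.matrix_rank(scores_standard))    # typically tensor(4)
print(torch.linalg.matrix_rank(scores_covariance))  # typically tensor(16)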
-
When computing the context vector in the attention algorithm, three weight matrices are introduced. It has been discussed in #454 that the value matrix W_V is not necessary. For the remaining two, the query matrix and the key matrix, keeping both of them does not seem necessary either. The context vector can be expressed as

X*W_q*W_K^T*X^T*X*W_V

where * stands for matrix multiplication. Is it possible to merge the part W_q*W_K^T into a single covariance matrix S, so that the context vector becomes X*S*X^T*X*W_V? This merge could potentially reduce nuisance parameters and improve computational performance.
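On the parameter count, a back-of-the-envelope sketch of how the merge changes the number of trainable weights, using assumed GPT-2-small-like numbers (not figures from the book):

# bias terms ignored throughout
d_in = 768                                 # embedding dimension (assumed)

# single-head case with d_out == d_in: the merge roughly halves the count
print(2 * d_in * d_in, d_in * d_in)        # 1179648 589824

# per-head case with head_dim = d_in // 12, as in the MHA_alt1 experiment above
head_dim = d_in // 12                      # 64
print(2 * d_in * head_dim, d_in * d_in)    # 98304 589824 -- here the merge adds parameters per head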