Comment by godelski
self.q = nn.Linear(embed_size, embed_size, bias = False)
self.k = nn.Linear(embed_size, embed_size, bias = False)
self.v = nn.Linear(embed_size, embed_size, bias = False)
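Try a single fused projection instead:
self.qkv = nn.Linear(embed_size, 3 * embed_size, bias = False)
def forward(...):
    ...
    qkv = self.qkv(x)
    q, k, v = qkv.chunk(3, dim = -1)  # split the fused output back into q, k, v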
This adds connections between the parameters of q, k, and v whereas the original doesn't, unless my very tired brain is missing something.
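A quick numerical check can settle this either way. Below is a minimal sketch, assuming PyTorch and a made-up embed_size of 8 (the variable names are hypothetical, not from the original code). It copies the weights of three separate projections into one fused layer and compares the outputs. Each output unit of nn.Linear depends only on its own row of the weight matrix, so stacking the three weight matrices along the output dimension should reproduce the separate projections.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_size = 8  # hypothetical size, chosen only for illustration

# Three separate projections, as in the original code.
q = nn.Linear(embed_size, embed_size, bias=False)
k = nn.Linear(embed_size, embed_size, bias=False)
v = nn.Linear(embed_size, embed_size, bias=False)

# One fused projection, as in the suggestion.
qkv = nn.Linear(embed_size, 3 * embed_size, bias=False)

# nn.Linear stores weight as (out_features, in_features), so stacking
# along dim 0 places q's rows first, then k's, then v's.
with torch.no_grad():
    qkv.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))

x = torch.randn(2, 5, embed_size)  # (batch, sequence, embed)
q_fused, k_fused, v_fused = qkv(x).chunk(3, dim=-1)

# Compare each fused slice against the matching separate projection.
same = all(
    torch.allclose(fused, separate(x), atol=1e-5)
    for fused, separate in [(q_fused, q), (k_fused, k), (v_fused, v)]
)
print(same)
```

If this prints True, the fused layer computes the same three projections in one matmul. Any remaining difference between the two versions would come from things like default initialization (the fused layer is initialized as a single matrix), not from extra connections between q, k, and v.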