Comment by godelski 2 days ago

       self.q = nn.Linear(embed_size, embed_size, bias=False)
       self.k = nn.Linear(embed_size, embed_size, bias=False)
       self.v = nn.Linear(embed_size, embed_size, bias=False)

Try

       self.qkv = nn.Linear(embed_size, 3 * embed_size, bias=False)

    def forward(...):
       ...
       qkv = self.qkv(x)
       q, k, v = qkv.chunk(3, dim=-1)
jszymborski 2 days ago

This adds connections between the parameters of q, k, and v whereas the original doesn't, unless my very tired brain is missing something.

  • smus 2 days ago

    Nope, each projection still depends only on x in both versions; the fused layer computes the same three independent outputs.

  • godelski 2 days ago

    It is actually really common practice. A linear layer has no connections between its output nodes, so stacking the three weight matrices into one layer computes exactly the same q, k, and v. The reason to do this is efficiency: one larger matmul is a bit cheaper than three smaller ones.

    tldr: concatenating linear layers along the output dimension is still a single linear layer
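
A minimal sketch of the equivalence being claimed, assuming PyTorch (names and shapes here are illustrative, not from the original comment): copying the three separate weight matrices into one stacked weight makes the fused layer reproduce q, k, and v exactly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_size = 8
x = torch.randn(2, 4, embed_size)  # (batch, seq, embed)

# Three separate projections, as in the original code
q_lin = nn.Linear(embed_size, embed_size, bias=False)
k_lin = nn.Linear(embed_size, embed_size, bias=False)
v_lin = nn.Linear(embed_size, embed_size, bias=False)

# One fused projection: its weight is the three weights stacked
# along the output dimension (nn.Linear stores weight as (out, in))
qkv = nn.Linear(embed_size, 3 * embed_size, bias=False)
with torch.no_grad():
    qkv.weight.copy_(torch.cat([q_lin.weight, k_lin.weight, v_lin.weight], dim=0))

# Split the fused output back into q, k, v
q2, k2, v2 = qkv(x).chunk(3, dim=-1)

# Identical results, but one matmul instead of three
assert torch.allclose(q_lin(x), q2, atol=1e-6)
assert torch.allclose(k_lin(x), k2, atol=1e-6)
assert torch.allclose(v_lin(x), v2, atol=1e-6)
```

In practice the fused layer is trained directly rather than built from three separate ones; the copy above just demonstrates that the two parameterizations span the same functions.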