Comment by jmward01
To get similar tokens/sec in training, though, you would need to swap batch size and sequence length. You could have the massive batch size, but won't you start hitting memory issues with any reasonable sequence length? You would have to do something like a minibatch along the sequence and cut the gradients after a small number of tokens on each sequence. So how will they learn truly long sequences for recall? Or is there a different trick I am missing here?
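
To make the question concrete, here is a rough sketch of what I mean. The numbers and helper names are made up for illustration, not taken from any real training setup. It shows that swapping batch size and sequence length keeps tokens per step the same, and how cutting the sequence into short windows limits how far back a gradient could reach.

```python
def tokens_per_step(batch_size, seq_len):
    # Tokens processed in one training step.
    return batch_size * seq_len


def gradient_windows(seq_len, window_len):
    # Split positions 0..seq_len-1 into consecutive windows.
    # Assumption: gradients are cut at every window boundary,
    # so a loss at one position cannot reach tokens in an earlier window.
    return [(start, min(start + window_len, seq_len))
            for start in range(0, seq_len, window_len)]


# Hypothetical configurations with batch size and sequence length swapped.
long_seq = tokens_per_step(batch_size=8, seq_len=4096)
big_batch = tokens_per_step(batch_size=4096, seq_len=8)

# Hypothetical 20-token sequence cut into windows of 8 tokens.
windows = gradient_windows(seq_len=20, window_len=8)

# Longest distance a gradient can travel back inside a single window.
max_reach = max(end - start for start, end in windows) - 1

print(long_seq, big_batch)  # same token count per step
print(windows, max_reach)
```

With windows like these, nothing in the gradient directly rewards recalling a token from more than a few positions back, which is the part I don't see how to get around.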