Comment by jmward01 2 hours ago

5 replies

RNNs have two huge issues:

- Long context. Recurrence degrades the signal for the same reason that 'deep' NN architectures don't go much past 3-4 layers before you need residual connections and the like.
- (This is the big one.) Training performance is terrible, since you can't parallelize them across a sequence like you can with causal masked attention in transformers.

On the huge benefit side, though, you get:

- Guaranteed state size, so perfect batch packing, perfect memory use, easy load/unload from a batch, and O(1) cost per generated token, so generally massive performance gains in inference.
- Unlimited context (well, no need for a concept of a position embedding or similar system).

Taking the best of both worlds is definitely where it is at for the future: an architecture that can train in parallel, has a fixed state size so you can load/unload and pack batches perfectly, has unlimited context (with perfect recall), etc. That is the real architecture to go for.

zozbot234 2 hours ago

RNN training cannot be parallelized along the sequence dimension like attention can, but it can still be trained in batches on multiple sequences simultaneously. Given the sizes of modern training sets and the limits on context size for transformer-based models, it's not clear to what extent this is an important limitation nowadays. It may have been more relevant in the early days of attention-based models where being able to do experimental training runs quickly on relatively small sizes of training data may have been important.

  • DeveloperErrata 34 minutes ago

    Not quite: most of the recent work on modern RNNs has been addressing this exact limitation. For instance, linear attention yields formulations that can be equivalently interpreted either as a parallel operation or a recursive one. The trade-off is that these parallelizable versions of RNNs are often "less expressive per param" than their old-school non-parallelizable RNN counterparts, though you could argue that they make up for that in practice by being more powerful per unit of training compute via much better training efficiency.
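
To make the "parallel or recursive" point concrete, here is a minimal sketch in plain Python. It assumes the simplest possible variant: causal linear attention with no softmax, no feature map, and no normalization, which is not necessarily what any particular paper uses. The function names are made up for illustration. In the first form every output row depends only on the inputs, so all rows could be computed at once; the second form carries a fixed-size state matrix `S` from token to token.

```python
def causal_linear_attention_parallel(Q, K, V):
    """Masked form: out[t] = sum over s <= t of (q_t . k_s) * v_s.

    Each output row depends only on Q, K, V (not on other output rows),
    so the rows are independent and could be computed in parallel.
    """
    d_v = len(V[0])
    out = []
    for t in range(len(Q)):
        row = [0.0] * d_v
        for s in range(t + 1):  # causal mask: only positions s <= t
            score = sum(q * k for q, k in zip(Q[t], K[s]))
            for j in range(d_v):
                row[j] += score * V[s][j]
        out.append(row)
    return out


def causal_linear_attention_recurrent(Q, K, V):
    """RNN form: S_t = S_{t-1} + k_t v_t^T, then out[t] = q_t^T S_t.

    The state S has a fixed d_k x d_v size regardless of sequence length.
    """
    d_k, d_v = len(K[0]), len(V[0])
    S = [[0.0] * d_v for _ in range(d_k)]
    out = []
    for q, k, v in zip(Q, K, V):
        for i in range(d_k):
            for j in range(d_v):
                S[i][j] += k[i] * v[j]  # accumulate the outer product k v^T
        out.append([sum(q[i] * S[i][j] for i in range(d_k)) for j in range(d_v)])
    return out
```

Both functions return the same outputs up to floating-point rounding, which is the equivalence described above: train with the first form, generate with the second.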

  • jmward01 an hour ago

    To get a similar tokens/sec in training, though, you would need to swap batch size and sequence length. You could have the massive batch size, but then won't you start hitting memory issues with any reasonable sequence length? You would have to do something similar to a minibatch along the sequence and cut the gradients after a short number of tokens on each sequence. So how will they learn truly long sequences for recall? Or is there a different trick I am missing here?
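
A toy illustration of what "cut the gradients" does, assuming a scalar recurrence `h_t = a * h_{t-1} + x_t` with a squared-error loss. Both are chosen only because the gradient can be written by hand; the function names are hypothetical. The hidden state still flows across chunk boundaries, but the sensitivity of the state to the parameter is reset there, so the gradient never sees dependencies longer than one chunk. The sketch tracks the derivative forward in time for simplicity rather than using backprop, so it shows the truncation itself, not the memory savings.

```python
def loss(a, xs, ys):
    """Sum of squared errors for the recurrence h_t = a * h_{t-1} + x_t."""
    h, total = 0.0, 0.0
    for x, y in zip(xs, ys):
        h = a * h + x
        total += (h - y) ** 2
    return total


def truncated_grad(a, xs, ys, chunk_len):
    """d(loss)/da with the gradient path cut at every chunk boundary."""
    h = 0.0      # hidden state: carried across chunk boundaries
    dh_da = 0.0  # sensitivity of h to a: reset at chunk boundaries
    grad = 0.0
    for t, (x, y) in enumerate(zip(xs, ys)):
        if t % chunk_len == 0:
            dh_da = 0.0  # "detach": treat the incoming state as a constant
        dh_da = h + a * dh_da  # product rule, using h_{t-1} before the update
        h = a * h + x
        grad += 2.0 * (h - y) * dh_da
    return grad
```

With `chunk_len` equal to the sequence length this is the full gradient; with shorter chunks the contributions from earlier chunks are dropped, which is exactly the long-range signal the question is about.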

cs702 an hour ago

Linear RNNs overcome both issues. All the RNNs I mentioned are linear RNNs.
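
One way to see why a linear recurrence can be parallelized along the sequence: each step `h -> a_t * h + b_t` is an affine map, and composing affine maps is associative, so prefixes can be combined in any grouping. Below is a minimal sketch for a generic scalar recurrence, not any specific model from this thread; the function names are made up for illustration. The divide-and-conquer version is still executed serially here, but its two halves are independent of each other.

```python
def combine(f, g):
    """Compose two affine maps h -> a*h + b, applying f first, then g."""
    a1, b1 = f
    a2, b2 = g
    return (a2 * a1, a2 * b1 + b2)


def sequential_scan(steps, h0=0.0):
    """Reference: run h_t = a_t * h_{t-1} + b_t one step at a time."""
    hs, h = [], h0
    for a, b in steps:
        h = a * h + b
        hs.append(h)
    return hs


def prefix_compose(steps):
    """Inclusive prefix composition by divide and conquer.

    The two recursive calls do not depend on each other, so on parallel
    hardware they could run at the same time.
    """
    if len(steps) <= 1:
        return list(steps)
    mid = len(steps) // 2
    left = prefix_compose(steps[:mid])
    right = prefix_compose(steps[mid:])
    carry = left[-1]  # composition of every step in the left half
    return left + [combine(carry, r) for r in right]


def scan_via_composition(steps, h0=0.0):
    """Apply each prefix map to h0 to recover all hidden states."""
    return [a * h0 + b for a, b in prefix_compose(steps)]
```

`scan_via_composition` matches `sequential_scan` up to floating-point rounding. A nonlinearity between steps would break the associativity this relies on, which is why the trick applies to linear recurrences.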