Comment by canyon289 15 days ago
I work at Google on these systems every day (caveat: these are my own words, not my employer's). So I can simultaneously tell you that it's smart people really thinking about every facet of the problem, and that I can't tell you much more than that.
However, I can share this, written by my colleagues! You'll find great explanations of accelerator architectures and the considerations made to make things fast.
https://jax-ml.github.io/scaling-book/
In particular, your questions are around inference, which is the focus of this chapter: https://jax-ml.github.io/scaling-book/inference/
Edit: Another great resource to look at is the Unsloth guides. These folks are incredibly good at digging deep into various models and finding optimizations, and they're very good at writing it up. Here's the Gemma 3n guide, and you'll find others as well.
https://docs.unsloth.ai/basics/gemma-3n-how-to-run-and-fine-...
Same explanation but with less mysticism:
Inference is (mostly) stateless. Unlike training, where you need memory coherence across something like 100k machines and somehow have to avoid the near-certainty of machine failure, you just need to route mostly small amounts of data to a bunch of big machines.
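To make "stateless routing" concrete, here's a minimal sketch (my own illustration, with hypothetical host names): because each request carries everything the model needs, any replica can serve it, so the router can be as dumb as round-robin.

```python
import itertools

# Hypothetical inference servers, each a big multi-GPU box holding
# a full copy of the model.
REPLICAS = ["infer-host-0:8000", "infer-host-1:8000", "infer-host-2:8000"]

_round_robin = itertools.cycle(REPLICAS)

def route(request: dict) -> str:
    """Pick a replica for one request.

    Inference is (mostly) stateless: the request itself carries what's
    needed (prompt, sampling params), so any healthy replica will do --
    no shared memory or cross-machine coordination required.
    """
    return next(_round_robin)

# Each prompt is a small payload routed independently.
for prompt in ["hello", "explain KV caches", "write a haiku"]:
    host = route({"prompt": prompt, "max_tokens": 128})
    print(f"{prompt!r} -> {host}")
```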
I don't know what the specs of their inference machines are, but where I worked, the machines research used were all 8-GPU monsters. As long as your model fit in (combined) VRAM, your job was a good'un.
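As a rough back-of-the-envelope check (my numbers, not the commenter's), "fits in combined VRAM" is just arithmetic on parameter count, bytes per parameter, and GPU memory; the overhead fraction and 8x80 GB box are assumptions for illustration.

```python
def fits_in_vram(n_params_billion: float,
                 bytes_per_param: int = 2,       # fp16/bf16 weights
                 num_gpus: int = 8,
                 gpu_vram_gb: float = 80.0,
                 overhead_fraction: float = 0.3  # headroom for KV cache, activations
                 ) -> bool:
    """Crude estimate of whether a model's weights (plus headroom)
    fit in one multi-GPU server's combined VRAM."""
    weights_gb = n_params_billion * bytes_per_param  # 1B params * 2 bytes ~ 2 GB
    needed_gb = weights_gb * (1 + overhead_fraction)
    available_gb = num_gpus * gpu_vram_gb
    return needed_gb <= available_gb

# A 70B model in bf16 on an 8x80GB box: ~140 GB of weights vs 640 GB of VRAM.
print(fits_in_vram(70))     # True
# A 1T-parameter model in bf16 would need sharding across boxes.
print(fits_in_vram(1000))   # False
```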
To scale, the secret ingredient was industrial amounts of cash. Sure, we had DGXs (fun fact: Nvidia sent literal gold-plated DGX machines), but they weren't dense, and they were very expensive.
Most large companies have robust RPC and orchestration, which means the hard part isn't routing the message, it's making the model fit in the boxes you have. (That's not my area of expertise, though.)