Treinamento distribuido
Frameworks de Treinamento Distribuído
PyTorch
- *RL:*pytorch.org
- *ersão estável:*2.x (2024)
- *inguagem:*Python + C++/CUDA
- *adrão de facto*para pesquisa e produção de LLMs
Features Relevantes para LLMs
- *SDP2 (Fully Sharded Data Parallel):*Estratégia ZeRO-3 nativa; prefetching implícito
- *orch.compile:*Compilação JIT → 30–50% speedup automático
- *lexAttention:*API flexível para variantes de atenção
- *Tensor:*Primitiva de tensor distribuído (base do FSDP2, Tensor Parallel)
- *torch.distributed`:*NCCL, Gloo, MPI backends
JAX (Google)
- *RL:*github.comgooglejax
- *aradigma:*Funcional puro; arrays imutáveis; transformações (grad, jit, vmap, pmap)
- *LA (Accelerated Linear Algebra):*Compilador JIT para TPU e GPU
- *jit`:*Compila função para XLA → speedup automático
- *vmap`:*Vectorized map — batching automático de funções escalares
- *pmap`:*Parallel map — distribui computação em múltiplos dispositivos
*cossistema:*
- *lax:*Módulos de rede neural (NNX API, Linen)
- *ptax:*Optimizadores (Adam, AdaFactor, Lion)
- *rbax:*Checkpointing distribuído
- *rain:*Pipeline de dados para JAX
*uando usar:*Pesquisa Google; TPUs; workloads que se beneficiam de transformações funcionais
Keras (François Chollet)
- *RL:*keras.io · GitHub: keras-team/keras
- *riador:*François Chollet (ex
Google, hoje na Anthropic; mesmo criador do *RCAGI*— ver08-benchmarks/gerais-raciocinio.md) - *istória:*
- *eras 1.0 (2015):*API alto-nível para Theano (depois TensorFlow). Um dos motivos da adoção massiva de DL fora de pesquisa.
- *eras 2 (2017):*integrado ao TensorFlow como
tf.keras(default API do TF 2.x). - *eras 3 (2023, "Keras Core"):**ulti-backend*— mesma API roda sobre *ensorFlow, JAX e PyTorch* Reset estratégico após "Keras virou só wrapper de TF" perceived weakness.
- *aradigma:*Sequential / Functional / Subclassing API. Foco em ergonomia para beginner + production.
- *erasCV / KerasNLP / KerasHub:*componentes pré-treinados;
keras_hub.models.Llama3CausalLM.from_preset()interface tipo HuggingFace. - *erasTuner:*hyperparameter search.
*stado em LLM-land (2026):*
- *ão dominante.*PyTorch + HuggingFace é a stack hegemônica para LLM training (PEFT, TRL, Axolotl, Megatron
LM). JAX para TPU em Google. Keras 3 multibackend tenta retomar relevância via "escreva uma vez, escolha backend" mas adoção em LLMs frontier permanece marginal. - *nde Keras ainda brilha:*ensino, prototipagem rápida, modelos clássicos (CNN, RNN, transformers menores), edge deployment via TF Lite, integração com TensorBoard.
- *uando usar:*projetos pedagógicos; transição de DL clássico para LLM; quando você quer trocar de backend sem reescrever; integração com ecossistema TensorFlow legacy.
*or que está aqui:*referência cultural relevante (Chollet), e Keras 3 multi-backend é um exemplo de design que vale conhecer mesmo se você não usar diariamente.
DeepSpeed (Microsoft)
- *RL:*github.commicrosoftDeepSpeed
- *so:*ZeRO (123/Infinity), pipeline parallelism, activation checkpointing
- *ntegrações:*PyTorch nativo; HuggingFace
accelerate
ZeRO Optimizer (ver também 04-treinamento/pre-treino.md)
| Estágio | Particiona | Economia de Memória |
|---|---|---|
| ZeRO-1 | Optimizer states | ~4× |
| ZeRO-2 | + Gradients | ~8× |
| ZeRO-3 | + Parâmetros | N× (N = GPUs) |
| ZeRO-Infinity | + CPU/NVMe offload | Ilimitado* |
DeepSpeed Inference
- Kernel de atenção otimizado
- Quantização INT8 integrada
- *uando usar:*Se já usa DeepSpeed para treino; kernels customizados
FSDP — Fully Sharded Data Parallel (PyTorch Nativo)
- *lternativa ao ZeRO-3*sem dependência de DeepSpeed
- *SDP2 (2024):*API renovada;
fully_shard()por layer; melhor interoperabilidade comtorch.compile - *ocumentação:*pytorch.orgtutorialsintermediate/FSDP_tutorial.html
- *uando usar:*Treino de modelos 7B–70B em 2–16 GPUs; padrão para maioria dos times
Megatron-LM (NVIDIA)
- *RL:*github.comNVIDIAMegatron-LM
- *rigem:*NVIDIA Research (2019)
- *specialidade:*Tensor Parallelism + Pipeline Parallelism + Data Parallelism (3D parallelism)
- *erformance:*FlashAttention integrado; FP16BF16FP8; mixed precision
- *scala:*Modelos 1T+ parâmetros em clusters de 1000+ GPUs
*uando usar:*
- Pré-treino de modelos muito grandes (>70B parâmetros)
- Acesso a clusters NVIDIA (DGX, Selene)
- Reprodução de papers NVIDIA
ColossalAI
- *RL:*github.comhpcaitechColossalAI
- *oco:*Auto-parallelism; buscador automático de estratégia de paralelismo
- *eatures:*Sequência Parallelism, Tensor Parallelism, ZeRO alternativo
- *uando usar:*Experimentação com estratégias de paralelismo automático
Nanotron (HuggingFace)
- *RL:*github.comhuggingfacenanotron
- *oco:*Pré-treino minimalista e reprodutível
- *esign:*Simples de entender; Tensor + Pipeline + Data parallelism
- *so:*HuggingFace usa internamente; bom para pesquisa e reprodução
nanoGPT (Karpathy)
- *RL:*github.comkarpathynanoGPT
- *300 linhas de Python*— implementação minimal de GPT-2 treinável
- *alor:*Referência educacional; base para experimentos
- *lm.c:*Versão em C puro de Karpathy (2024) — GPT-2 em C sem dependências
LightSeq / Liger Kernel
LightSeq (ByteDance)
- Kernels CUDA otimizados para transformer (atenção, layer norm, embeddings)
- 1.5–2× speedup em fine-tuning
Liger Kernel (LinkedIn/Liger)
- *RL:*github.comlinkedinLiger-Kernel
- Kernels Triton para RMSNorm, RoPE, SwiGLU, CrossEntropy com chunking
- Reduz memória de ativações em 60%; compatível com HuggingFace
- Drop-in:
from liger_kernel.transformers import apply_liger_kernel_to_llama
Accelerate (HuggingFace)
- *RL:*github.comhuggingfaceaccelerate
- *apel:*Camada de abstração sobre FSDP, DeepSpeed, TPU, multi-GPU
- *so:*
Accelerator()→ mesmo código roda em 1 GPU, 8 GPUs, TPU - *ntegração:*TRL, Axolotl, LLaMA-Factory usam internamente
Stack de Pré-Treino Recomendado (2026)
| Escala | Framework | Paralelismo |
|---|---|---|
| 1–2 GPUs | PyTorch + FSDP2 | DP |
| 4–8 GPUs consumer | PyTorch + FSDP2 + DeepSpeed ZeRO-3 | DP + ZeRO |
| 8–64 GPUs A100/H100 | Megatron-LM ou Nanotron | TP + PP + DP |
| 100+ GPUs H100/B200 | Megatron-LM | 3D parallelism |
| TPU | JAX + Flax | pmap / mesh |
Monitoring de Treino
- *eights & Biases (wandb):*Padrão de facto; curvas de loss, LR, gradients
- *ensorBoard:*Integrado ao PyTorch; menor overhead
- *Lflow:*Open-source; experiment tracking; model registry
- *omet ML:*Alternativa ao wandb com foco enterprise
*ntegração mínima:*
import wandb
wandb.init(project="kode-pretraining")
wandb.log({"loss": loss, "lr": scheduler.get_last_lr()[0]})