Treinamento distribuido

Frameworks de Treinamento Distribuído

PyTorch

  • *RL:*pytorch.org
  • *ersão estável:*2.x (2024)
  • *inguagem:*Python + C++/CUDA
  • *adrão de facto*para pesquisa e produção de LLMs

Features Relevantes para LLMs

  • *SDP2 (Fully Sharded Data Parallel):*Estratégia ZeRO-3 nativa; prefetching implícito
  • *orch.compile:*Compilação JIT → 30–50% speedup automático
  • *lexAttention:*API flexível para variantes de atenção
  • *Tensor:*Primitiva de tensor distribuído (base do FSDP2, Tensor Parallel)
  • *torch.distributed`:*NCCL, Gloo, MPI backends

JAX (Google)

  • *RL:*github.comgooglejax
  • *aradigma:*Funcional puro; arrays imutáveis; transformações (grad, jit, vmap, pmap)
  • *LA (Accelerated Linear Algebra):*Compilador JIT para TPU e GPU
  • *jit`:*Compila função para XLA → speedup automático
  • *vmap`:*Vectorized map — batching automático de funções escalares
  • *pmap`:*Parallel map — distribui computação em múltiplos dispositivos

*cossistema:*

  • *lax:*Módulos de rede neural (NNX API, Linen)
  • *ptax:*Optimizadores (Adam, AdaFactor, Lion)
  • *rbax:*Checkpointing distribuído
  • *rain:*Pipeline de dados para JAX

*uando usar:*Pesquisa Google; TPUs; workloads que se beneficiam de transformações funcionais


Keras (François Chollet)

  • *RL:*keras.io · GitHub: keras-team/keras
  • *riador:*François Chollet (exGoogle, hoje na Anthropic; mesmo criador do *RCAGI*— ver 08-benchmarks/gerais-raciocinio.md)
  • *istória:*
    • *eras 1.0 (2015):*API alto-nível para Theano (depois TensorFlow). Um dos motivos da adoção massiva de DL fora de pesquisa.
    • *eras 2 (2017):*integrado ao TensorFlow como tf.keras (default API do TF 2.x).
    • *eras 3 (2023, "Keras Core"):**ulti-backend*— mesma API roda sobre *ensorFlow, JAX e PyTorch* Reset estratégico após "Keras virou só wrapper de TF" perceived weakness.
  • *aradigma:*Sequential / Functional / Subclassing API. Foco em ergonomia para beginner + production.
  • *erasCV / KerasNLP / KerasHub:*componentes pré-treinados; keras_hub.models.Llama3CausalLM.from_preset() interface tipo HuggingFace.
  • *erasTuner:*hyperparameter search.

*stado em LLM-land (2026):*

  • *ão dominante.*PyTorch + HuggingFace é a stack hegemônica para LLM training (PEFT, TRL, Axolotl, MegatronLM). JAX para TPU em Google. Keras 3 multibackend tenta retomar relevância via "escreva uma vez, escolha backend" mas adoção em LLMs frontier permanece marginal.
  • *nde Keras ainda brilha:*ensino, prototipagem rápida, modelos clássicos (CNN, RNN, transformers menores), edge deployment via TF Lite, integração com TensorBoard.
  • *uando usar:*projetos pedagógicos; transição de DL clássico para LLM; quando você quer trocar de backend sem reescrever; integração com ecossistema TensorFlow legacy.

*or que está aqui:*referência cultural relevante (Chollet), e Keras 3 multi-backend é um exemplo de design que vale conhecer mesmo se você não usar diariamente.


DeepSpeed (Microsoft)

  • *RL:*github.commicrosoftDeepSpeed
  • *so:*ZeRO (123/Infinity), pipeline parallelism, activation checkpointing
  • *ntegrações:*PyTorch nativo; HuggingFace accelerate

ZeRO Optimizer (ver também 04-treinamento/pre-treino.md)

Estágio Particiona Economia de Memória
ZeRO-1 Optimizer states ~4×
ZeRO-2 + Gradients ~8×
ZeRO-3 + Parâmetros N× (N = GPUs)
ZeRO-Infinity + CPU/NVMe offload Ilimitado*

DeepSpeed Inference

  • Kernel de atenção otimizado
  • Quantização INT8 integrada
  • *uando usar:*Se já usa DeepSpeed para treino; kernels customizados

FSDP — Fully Sharded Data Parallel (PyTorch Nativo)

  • *lternativa ao ZeRO-3*sem dependência de DeepSpeed
  • *SDP2 (2024):*API renovada; fully_shard() por layer; melhor interoperabilidade com torch.compile
  • *ocumentação:*pytorch.orgtutorialsintermediate/FSDP_tutorial.html
  • *uando usar:*Treino de modelos 7B–70B em 2–16 GPUs; padrão para maioria dos times

Megatron-LM (NVIDIA)

  • *RL:*github.comNVIDIAMegatron-LM
  • *rigem:*NVIDIA Research (2019)
  • *specialidade:*Tensor Parallelism + Pipeline Parallelism + Data Parallelism (3D parallelism)
  • *erformance:*FlashAttention integrado; FP16BF16FP8; mixed precision
  • *scala:*Modelos 1T+ parâmetros em clusters de 1000+ GPUs

*uando usar:*

  • Pré-treino de modelos muito grandes (>70B parâmetros)
  • Acesso a clusters NVIDIA (DGX, Selene)
  • Reprodução de papers NVIDIA

ColossalAI

  • *RL:*github.comhpcaitechColossalAI
  • *oco:*Auto-parallelism; buscador automático de estratégia de paralelismo
  • *eatures:*Sequência Parallelism, Tensor Parallelism, ZeRO alternativo
  • *uando usar:*Experimentação com estratégias de paralelismo automático

Nanotron (HuggingFace)

  • *RL:*github.comhuggingfacenanotron
  • *oco:*Pré-treino minimalista e reprodutível
  • *esign:*Simples de entender; Tensor + Pipeline + Data parallelism
  • *so:*HuggingFace usa internamente; bom para pesquisa e reprodução

nanoGPT (Karpathy)

  • *RL:*github.comkarpathynanoGPT
  • *300 linhas de Python*— implementação minimal de GPT-2 treinável
  • *alor:*Referência educacional; base para experimentos
  • *lm.c:*Versão em C puro de Karpathy (2024) — GPT-2 em C sem dependências

LightSeq / Liger Kernel

LightSeq (ByteDance)

  • Kernels CUDA otimizados para transformer (atenção, layer norm, embeddings)
  • 1.5–2× speedup em fine-tuning

Liger Kernel (LinkedIn/Liger)

  • *RL:*github.comlinkedinLiger-Kernel
  • Kernels Triton para RMSNorm, RoPE, SwiGLU, CrossEntropy com chunking
  • Reduz memória de ativações em 60%; compatível com HuggingFace
  • Drop-in: from liger_kernel.transformers import apply_liger_kernel_to_llama

Accelerate (HuggingFace)

  • *RL:*github.comhuggingfaceaccelerate
  • *apel:*Camada de abstração sobre FSDP, DeepSpeed, TPU, multi-GPU
  • *so:*Accelerator() → mesmo código roda em 1 GPU, 8 GPUs, TPU
  • *ntegração:*TRL, Axolotl, LLaMA-Factory usam internamente

Stack de Pré-Treino Recomendado (2026)

Escala Framework Paralelismo
1–2 GPUs PyTorch + FSDP2 DP
4–8 GPUs consumer PyTorch + FSDP2 + DeepSpeed ZeRO-3 DP + ZeRO
8–64 GPUs A100/H100 Megatron-LM ou Nanotron TP + PP + DP
100+ GPUs H100/B200 Megatron-LM 3D parallelism
TPU JAX + Flax pmap / mesh

Monitoring de Treino

  • *eights & Biases (wandb):*Padrão de facto; curvas de loss, LR, gradients
  • *ensorBoard:*Integrado ao PyTorch; menor overhead
  • *Lflow:*Open-source; experiment tracking; model registry
  • *omet ML:*Alternativa ao wandb com foco enterprise

*ntegração mínima:*

import wandb
wandb.init(project="kode-pretraining")
wandb.log({"loss": loss, "lr": scheduler.get_last_lr()[0]})

Source: ../home/koder/dev/koder/meta/docs/ia/compendium/07-frameworks/treinamento-distribuido.md