From bf20573f60178618f6e498fe182755be49e72e46 Mon Sep 17 00:00:00 2001
From: bigximik
Date: Thu, 16 Apr 2026 09:04:23 +0000
Subject: [PATCH 1/2] Fix GPU memory leak: detach loss tensors to prevent
 autograd graph retention across microbatches

Loss scalars stored in context.losses retained FunctionBackward grad_fn
references from wrap_forward_backward, keeping C++ autograd nodes and their
CUDA tensor references alive across all microbatches. This caused
~164 MiB/microbatch growth, leading to OOM with
depth_first_micro_batches>=128.

Fix: .detach() total_loss in head.py and individual losses in loss.py
before appending to the losses dict. These values are only used for logging
(reduced to .item() at step end), so detaching is safe.

Also adds per-microbatch memory logging in the schedule runner and frees
batch data after the last forward stage.
---
 fast_llm/engine/schedule/runner.py          | 37 ++++++++++++++++++++-
 fast_llm/layers/language_model/head.py      |  2 +-
 fast_llm/layers/language_model/loss/loss.py |  2 +-
 3 files changed, 38 insertions(+), 3 deletions(-)

diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py
index b95d39463..52ee3778e 100644
--- a/fast_llm/engine/schedule/runner.py
+++ b/fast_llm/engine/schedule/runner.py
@@ -110,6 +110,8 @@ def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) ->
         self._distributed = distributed
         self._stages_on_device = [stage for stage in self._stages if stage.mode.on_device]
         self._stages_owned = [stage.mode.on_device and not stage.is_tied_weight_copy for stage in self._stages]
+        # Last stage index on this device, used to free batch data after last forward step.
+        self._last_stage_on_device = max(i for i, stage in enumerate(self._stages) if stage.mode.on_device)

         # Setup the streams
         self._compute_stream = self._get_current_stream()
@@ -199,9 +201,39 @@ def run_step(
         log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"Beginning of the schedule steps", str))

         # Run the steps according to the schedule
-        for step in schedule:
+        _prev_mb = -1
+        _mb_start_alloc = 0.0
+        for i, step in enumerate(schedule):
+            if step.type_ == StepType.forward and step.index != _prev_mb:
+                if _prev_mb >= 0:
+                    mb_peak = torch.cuda.max_memory_allocated() / (1024**3)
+                    mb_end_alloc = torch.cuda.memory_allocated() / (1024**3)
+                    finished_mb = _prev_mb
+                    log_pipeline_parallel_main_rank(
+                        lambda: logger.info(
+                            f"MB {finished_mb} done: per_mb_peak={mb_peak:.2f} GiB "
+                            f"end_alloc={mb_end_alloc:.2f} GiB start_alloc={_mb_start_alloc:.2f} GiB"
+                        )
+                    )
+                torch.cuda.empty_cache()
+                torch.cuda.reset_peak_memory_stats()
+                _mb_start_alloc = torch.cuda.memory_allocated() / (1024**3)
+                _prev_mb = step.index
+
             self._train_step(context, step)

+        # Log final microbatch
+        if _prev_mb >= 0:
+            mb_peak = torch.cuda.max_memory_allocated() / (1024**3)
+            mb_end_alloc = torch.cuda.memory_allocated() / (1024**3)
+            finished_mb = _prev_mb
+            log_pipeline_parallel_main_rank(
+                lambda: logger.info(
+                    f"MB {finished_mb} done: per_mb_peak={mb_peak:.2f} GiB "
+                    f"end_alloc={mb_end_alloc:.2f} GiB start_alloc={_mb_start_alloc:.2f} GiB"
+                )
+            )
+
         # Make sure we used all the data. This also ensures the generator terminates and prevents a memory leak.
         try:
             next(context.data_iterator)
@@ -404,6 +436,9 @@ def _forward(self, context: BatchContext, step: Step) -> torch.Tensor | None:
         )
         if step.backward_step is not None:
             context.contexts[step.backward_step.global_index] = grad_context
+        # Free batch data after last forward stage — no backward stage reads context.batch.
+        if step.stage == self._last_stage_on_device:
+            del context.batch[step.index]
         self._record_compute(context, step)
         return output

diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py
index d57b465bf..95be18035 100644
--- a/fast_llm/layers/language_model/head.py
+++ b/fast_llm/layers/language_model/head.py
@@ -230,7 +230,7 @@ def _logits_loss_forward_backward(
             all_reduce(total_loss, op=ReduceOp.SUM, group=self._parallel_dim.group)

         if losses is not None:
-            losses[self._total_loss_name].append(total_loss)
+            losses[self._total_loss_name].append(total_loss.detach())

         return total_loss, input_grad

diff --git a/fast_llm/layers/language_model/loss/loss.py b/fast_llm/layers/language_model/loss/loss.py
index 9a92661c9..3cab2bca8 100644
--- a/fast_llm/layers/language_model/loss/loss.py
+++ b/fast_llm/layers/language_model/loss/loss.py
@@ -83,7 +83,7 @@ def _register_loss(
         if self._sequence_parallel:
             # TODO: Async
             torch.distributed.all_reduce(value, op=reduce_op, group=self._parallel_dim.group)
-        losses[name].append(value)
+        losses[name].append(value.detach())

     @property
     def name(self) -> str:
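Note (reviewer addition, not part of the patch): the retention mechanism in
PATCH 1/2 is easy to reproduce outside Fast-LLM. Below is a minimal sketch;
it assumes a CUDA device is available, and none of the names are Fast-LLM
APIs. Each undetached scalar's grad_fn keeps its autograd graph alive, and
with it the operands the graph saved for backward, which is exactly the
reference chain that .detach() severs.

    import torch

    def make_loss(detach: bool) -> torch.Tensor:
        x = torch.randn(1024, 1024, device="cuda", requires_grad=True)
        loss = (x @ x).sum()  # the matmul saves x (4 MiB) for the backward pass
        return loss.detach() if detach else loss

    for detach in (False, True):
        kept = [make_loss(detach) for _ in range(100)]
        # detach=False retains ~400 MiB of saved operands; detach=True keeps scalars only.
        print(f"detach={detach}: allocated={torch.cuda.memory_allocated() / 2**20:.0f} MiB")
        del kept
        torch.cuda.empty_cache()
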
From d79d75ec43f218e8aea984bb0da06fb7cb375b62 Mon Sep 17 00:00:00 2001
From: bigximik
Date: Thu, 16 Apr 2026 10:31:29 +0000
Subject: [PATCH 2/2] normalize grpo_new_logprobs metric by document count

The new_logprobs_mean metric was a raw sum that scaled linearly with batch
size (depth_first_micro_batches), making it incomparable across different
configurations. Divide by num_documents_in_batch so the logged metric
represents mean log-prob per document.

- Add return_document_count to get_preprocessing_config() so the data
  pipeline counts documents automatically
- Divide new_logprobs_mean by num_documents_in_batch before registering
---
 fast_llm/layers/language_model/loss/grpo.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py
index a933fec99..cc6cbf726 100644
--- a/fast_llm/layers/language_model/loss/grpo.py
+++ b/fast_llm/layers/language_model/loss/grpo.py
@@ -7,6 +7,7 @@
 from fast_llm.functional.config import TritonConfig
 from fast_llm.functional.entropy_loss import fused_predicted_logits_from_labels, fused_softmax_base
 from fast_llm.functional.utils import reduce_losses
+from fast_llm.layers.language_model.config import LanguageModelKwargs
 from fast_llm.layers.language_model.loss.config import LanguageModelGRPOLossConfig, LanguageModelLossKwargs
 from fast_llm.layers.language_model.loss.loss import LanguageModelLoss

@@ -45,6 +46,8 @@ def _forward_backward(
             divisor=self._get_label_count(kwargs),
         )

+        if new_logprobs_mean is not None:
+            new_logprobs_mean = new_logprobs_mean / kwargs[LanguageModelKwargs.num_documents_in_batch]
         self._register_loss(
             self._logprob_metric_name, new_logprobs_mean, losses, reduce_op=torch.distributed.ReduceOp.SUM
         )
@@ -56,7 +59,7 @@ def get_loss_definitions(self) -> list[LossDef]:
     def get_preprocessing_config(
         self,
     ) -> dict[str, typing.Any]:
-        return {"use_grpo_data": True, "return_label_counts": True}
+        return {"use_grpo_data": True, "return_label_counts": True, "return_document_count": True}

     @functools.cached_property
     def _logprob_metric_name(self) -> str:
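Note (reviewer addition, not part of the patch): the normalization in
PATCH 2/2 in miniature, with made-up numbers. Doubling the document count
doubles the raw sum but leaves the per-document mean unchanged, which is
what makes the logged metric comparable across depth_first_micro_batches
settings:

    # Hypothetical raw log-prob sums and document counts for two configurations.
    logprob_sum = {64: -8192.0, 128: -16384.0}
    num_documents = {64: 4096, 128: 8192}

    for micro_batches, raw in logprob_sum.items():
        per_document = raw / num_documents[micro_batches]
        print(f"{micro_batches} micro-batches: raw={raw:.0f} per_document={per_document:.2f}")
    # Both configurations print per_document=-2.00.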