diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py
index b95d39463..52ee3778e 100644
--- a/fast_llm/engine/schedule/runner.py
+++ b/fast_llm/engine/schedule/runner.py
@@ -110,6 +110,8 @@ def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) ->
         self._distributed = distributed
         self._stages_on_device = [stage for stage in self._stages if stage.mode.on_device]
         self._stages_owned = [stage.mode.on_device and not stage.is_tied_weight_copy for stage in self._stages]
+        # Last stage index on this device, used to free batch data after the last forward step.
+        self._last_stage_on_device = max(i for i, stage in enumerate(self._stages) if stage.mode.on_device)
 
         # Setup the streams
         self._compute_stream = self._get_current_stream()
@@ -199,9 +201,39 @@ def run_step(
         log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"Beginning of the schedule steps", str))
 
         # Run the steps according to the schedule
-        for step in schedule:
+        _prev_mb = -1
+        _mb_start_alloc = 0.0
+        for i, step in enumerate(schedule):
+            if step.type_ == StepType.forward and step.index != _prev_mb:
+                if _prev_mb >= 0:
+                    mb_peak = torch.cuda.max_memory_allocated() / (1024**3)
+                    mb_end_alloc = torch.cuda.memory_allocated() / (1024**3)
+                    finished_mb = _prev_mb
+                    log_pipeline_parallel_main_rank(
+                        lambda: logger.info(
+                            f"MB {finished_mb} done: per_mb_peak={mb_peak:.2f} GiB "
+                            f"end_alloc={mb_end_alloc:.2f} GiB start_alloc={_mb_start_alloc:.2f} GiB"
+                        )
+                    )
+                    torch.cuda.empty_cache()
+                torch.cuda.reset_peak_memory_stats()
+                _mb_start_alloc = torch.cuda.memory_allocated() / (1024**3)
+                _prev_mb = step.index
+
             self._train_step(context, step)
 
+        # Log final microbatch
+        if _prev_mb >= 0:
+            mb_peak = torch.cuda.max_memory_allocated() / (1024**3)
+            mb_end_alloc = torch.cuda.memory_allocated() / (1024**3)
+            finished_mb = _prev_mb
+            log_pipeline_parallel_main_rank(
+                lambda: logger.info(
+                    f"MB {finished_mb} done: per_mb_peak={mb_peak:.2f} GiB "
+                    f"end_alloc={mb_end_alloc:.2f} GiB start_alloc={_mb_start_alloc:.2f} GiB"
+                )
+            )
+
         # Make sure we used all the data. This also ensures the generator terminates and prevents a memory leak.
         try:
             next(context.data_iterator)
@@ -404,6 +436,9 @@ def _forward(self, context: BatchContext, step: Step) -> torch.Tensor | None:
         )
         if step.backward_step is not None:
             context.contexts[step.backward_step.global_index] = grad_context
+        # Free batch data after the last forward stage: no backward stage reads context.batch.
+        if step.stage == self._last_stage_on_device:
+            del context.batch[step.index]
 
         self._record_compute(context, step)
         return output
diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py
index d57b465bf..95be18035 100644
--- a/fast_llm/layers/language_model/head.py
+++ b/fast_llm/layers/language_model/head.py
@@ -230,7 +230,7 @@ def _logits_loss_forward_backward(
             all_reduce(total_loss, op=ReduceOp.SUM, group=self._parallel_dim.group)
 
         if losses is not None:
-            losses[self._total_loss_name].append(total_loss)
+            losses[self._total_loss_name].append(total_loss.detach())
 
         return total_loss, input_grad
 
diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py
index a933fec99..cc6cbf726 100644
--- a/fast_llm/layers/language_model/loss/grpo.py
+++ b/fast_llm/layers/language_model/loss/grpo.py
@@ -7,6 +7,7 @@
 from fast_llm.functional.config import TritonConfig
 from fast_llm.functional.entropy_loss import fused_predicted_logits_from_labels, fused_softmax_base
 from fast_llm.functional.utils import reduce_losses
+from fast_llm.layers.language_model.config import LanguageModelKwargs
 from fast_llm.layers.language_model.loss.config import LanguageModelGRPOLossConfig, LanguageModelLossKwargs
 from fast_llm.layers.language_model.loss.loss import LanguageModelLoss
 
@@ -45,6 +46,8 @@ def _forward_backward(
             divisor=self._get_label_count(kwargs),
         )
 
+        if new_logprobs_mean is not None:
+            new_logprobs_mean = new_logprobs_mean / kwargs[LanguageModelKwargs.num_documents_in_batch]
         self._register_loss(
             self._logprob_metric_name, new_logprobs_mean, losses, reduce_op=torch.distributed.ReduceOp.SUM
         )
@@ -56,7 +59,7 @@ def get_loss_definitions(self) -> list[LossDef]:
     def get_preprocessing_config(
         self,
     ) -> dict[str, typing.Any]:
-        return {"use_grpo_data": True, "return_label_counts": True}
+        return {"use_grpo_data": True, "return_label_counts": True, "return_document_count": True}
 
     @functools.cached_property
     def _logprob_metric_name(self) -> str:
diff --git a/fast_llm/layers/language_model/loss/loss.py b/fast_llm/layers/language_model/loss/loss.py
index 9a92661c9..3cab2bca8 100644
--- a/fast_llm/layers/language_model/loss/loss.py
+++ b/fast_llm/layers/language_model/loss/loss.py
@@ -83,7 +83,7 @@ def _register_loss(
         if self._sequence_parallel:
             # TODO: Async
             torch.distributed.all_reduce(value, op=reduce_op, group=self._parallel_dim.group)
-        losses[name].append(value)
+        losses[name].append(value.detach())
 
     @property
     def name(self) -> str: