37 changes: 36 additions & 1 deletion fast_llm/engine/schedule/runner.py
@@ -110,6 +110,8 @@ def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) ->
self._distributed = distributed
self._stages_on_device = [stage for stage in self._stages if stage.mode.on_device]
self._stages_owned = [stage.mode.on_device and not stage.is_tied_weight_copy for stage in self._stages]
# Last stage index on this device, used to free batch data after last forward step.
self._last_stage_on_device = max(i for i, stage in enumerate(self._stages) if stage.mode.on_device)

# Setup the streams
self._compute_stream = self._get_current_stream()
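For readers skimming the diff: `_last_stage_on_device` is simply the highest index whose stage lives on this device. A tiny self-contained sketch of the same computation, with plain booleans standing in for `stage.mode.on_device`:

```python
# Illustrative only; booleans stand in for `stage.mode.on_device`.
on_device = [True, True, False, True, False]
last_stage_on_device = max(i for i, flag in enumerate(on_device) if flag)
assert last_stage_on_device == 3  # index of the last on-device stage
```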
@@ -199,9 +201,39 @@ def run_step(
log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"Beginning of the schedule steps", str))

# Run the steps according to the schedule
for step in schedule:
_prev_mb = -1
Collaborator comment: Debug code?
_mb_start_alloc = 0.0
for i, step in enumerate(schedule):
if step.type_ == StepType.forward and step.index != _prev_mb:
if _prev_mb >= 0:
mb_peak = torch.cuda.max_memory_allocated() / (1024**3)
mb_end_alloc = torch.cuda.memory_allocated() / (1024**3)
finished_mb = _prev_mb
log_pipeline_parallel_main_rank(
lambda: logger.info(
f"MB {finished_mb} done: per_mb_peak={mb_peak:.2f} GiB "
f"end_alloc={mb_end_alloc:.2f} GiB start_alloc={_mb_start_alloc:.2f} GiB"
)
)
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
_mb_start_alloc = torch.cuda.memory_allocated() / (1024**3)
_prev_mb = step.index

self._train_step(context, step)

# Log final microbatch
if _prev_mb >= 0:
mb_peak = torch.cuda.max_memory_allocated() / (1024**3)
mb_end_alloc = torch.cuda.memory_allocated() / (1024**3)
finished_mb = _prev_mb
log_pipeline_parallel_main_rank(
lambda: logger.info(
f"MB {finished_mb} done: per_mb_peak={mb_peak:.2f} GiB "
f"end_alloc={mb_end_alloc:.2f} GiB start_alloc={_mb_start_alloc:.2f} GiB"
)
)

# Make sure we used all the data. This also ensures the generator terminates and prevents a memory leak.
try:
next(context.data_iterator)
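If it helps to review the instrumentation in isolation, here is a minimal sketch of the per-microbatch memory bookkeeping the loop above performs, assuming a CUDA device; the helper name is illustrative and not part of Fast-LLM:

```python
import torch

def log_microbatch_memory(mb_index: int, start_alloc_gib: float) -> float:
    """Report peak/end allocation for a finished microbatch, then reset the peak."""
    gib = 1024**3
    peak = torch.cuda.max_memory_allocated() / gib
    end_alloc = torch.cuda.memory_allocated() / gib
    print(
        f"MB {mb_index} done: per_mb_peak={peak:.2f} GiB "
        f"end_alloc={end_alloc:.2f} GiB start_alloc={start_alloc_gib:.2f} GiB"
    )
    # Return cached blocks to the driver and reset the peak counter so the
    # next microbatch measures its own high-water mark.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    return torch.cuda.memory_allocated() / gib  # start_alloc for the next microbatch
```

Note that `reset_peak_memory_stats()` is what makes `per_mb_peak` a per-microbatch number rather than a run-wide one.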
@@ -404,6 +436,9 @@ def _forward(self, context: BatchContext, step: Step) -> torch.Tensor | None:
)
if step.backward_step is not None:
context.contexts[step.backward_step.global_index] = grad_context
# Free batch data after the last on-device forward stage; no backward stage reads context.batch.
if step.stage == self._last_stage_on_device:
del context.batch[step.index]
self._record_compute(context, step)
return output
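The effect of `del context.batch[step.index]` is that the microbatch's input tensors become unreferenced and their memory returns to the caching allocator. A minimal illustration, assuming a CUDA device and that the dict holds the only remaining reference, as `context.batch` is expected to after the last on-device forward:

```python
import torch

batch = {0: torch.empty(1024, 1024, device="cuda")}  # 4 MiB of fp32
before = torch.cuda.memory_allocated()
del batch[0]  # analogous to `del context.batch[step.index]` above
after = torch.cuda.memory_allocated()
print(f"freed {(before - after) / 1024**2:.1f} MiB")  # ~4.0 MiB
```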

2 changes: 1 addition & 1 deletion fast_llm/layers/language_model/head.py
@@ -230,7 +230,7 @@ def _logits_loss_forward_backward(
all_reduce(total_loss, op=ReduceOp.SUM, group=self._parallel_dim.group)

if losses is not None:
losses[self._total_loss_name].append(total_loss)
losses[self._total_loss_name].append(total_loss.detach())

return total_loss, input_grad
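Why the `.detach()` matters: appending the live `total_loss` keeps its `grad_fn`, and therefore the whole autograd graph (activations included), reachable for as long as the `losses` list lives. Detaching stores the same value without that reference. A minimal sketch of the difference:

```python
import torch

x = torch.randn(1024, 1024, requires_grad=True)
losses: list[torch.Tensor] = []

loss = (x @ x).sum()
losses.append(loss)         # retains loss.grad_fn, hence the whole graph
assert losses[-1].grad_fn is not None

losses[-1] = loss.detach()  # same value, no graph reference
assert losses[-1].grad_fn is None
```

The identical one-line fix appears in `loss.py` further down.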

5 changes: 4 additions & 1 deletion fast_llm/layers/language_model/loss/grpo.py
@@ -7,6 +7,7 @@
from fast_llm.functional.config import TritonConfig
from fast_llm.functional.entropy_loss import fused_predicted_logits_from_labels, fused_softmax_base
from fast_llm.functional.utils import reduce_losses
from fast_llm.layers.language_model.config import LanguageModelKwargs
from fast_llm.layers.language_model.loss.config import LanguageModelGRPOLossConfig, LanguageModelLossKwargs
from fast_llm.layers.language_model.loss.loss import LanguageModelLoss

@@ -45,6 +46,8 @@ def _forward_backward(
divisor=self._get_label_count(kwargs),
)

if new_logprobs_mean is not None:
new_logprobs_mean = new_logprobs_mean / kwargs[LanguageModelKwargs.num_documents_in_batch]
self._register_loss(
self._logprob_metric_name, new_logprobs_mean, losses, reduce_op=torch.distributed.ReduceOp.SUM
)
@@ -56,7 +59,7 @@ def get_loss_definitions(self) -> list[LossDef]:
def get_preprocessing_config(
self,
) -> dict[str, typing.Any]:
return {"use_grpo_data": True, "return_label_counts": True}
return {"use_grpo_data": True, "return_label_counts": True, "return_document_count": True}

@functools.cached_property
def _logprob_metric_name(self) -> str:
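On the normalization: since `_register_loss` reduces with `ReduceOp.SUM`, dividing each rank's `new_logprobs_mean` by the batch-wide document count turns the summed metric into a per-document mean; the `return_document_count` flag added to the preprocessing config is presumably what populates `LanguageModelKwargs.num_documents_in_batch`. A plain-Python sketch of the arithmetic (values illustrative):

```python
local_values = [12.0, 8.0]      # one contribution per rank
num_documents_in_batch = 4      # batch-wide document count
reduced = sum(v / num_documents_in_batch for v in local_values)  # all_reduce(SUM)
assert reduced == sum(local_values) / num_documents_in_batch     # per-document mean
```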
2 changes: 1 addition & 1 deletion fast_llm/layers/language_model/loss/loss.py
@@ -83,7 +83,7 @@ def _register_loss(
if self._sequence_parallel:
# TODO: Async
torch.distributed.all_reduce(value, op=reduce_op, group=self._parallel_dim.group)
losses[name].append(value)
losses[name].append(value.detach())

@property
def name(self) -> str: