primeqa.ir.dense.dpr_top.torch_util.transformer_optimize.TransformerOptimize
- class primeqa.ir.dense.dpr_top.torch_util.transformer_optimize.TransformerOptimize(hypers: primeqa.ir.dense.dpr_top.torch_util.hypers_base.HypersBase, num_instances_to_train_over: int, model)
Bases: object
Collects the standard steps for training a transformer. Call step_loss after computing each loss.
Methods
backward_on_loss
optimizer_report
optimizer_step
should_continue
step_loss
- step_loss(loss: torch.Tensor, **moving_averages)
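- Parameters
loss – the loss tensor computed for the current training step.
moving_averages – additional named values, passed as keyword arguments.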
- Returns
the value of the loss (float)
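A minimal sketch of the call pattern described above (compute a loss, then pass it to step_loss), written in plain Python so it runs without PyTorch or PrimeQA. Everything here is hypothetical: StepLossSketch, ToyLoss, ToyParam, the learning rate and the max_steps budget are illustrative stand-ins, and the bodies of backward_on_loss, optimizer_step and should_continue are assumptions inferred only from the method names, not PrimeQA's implementation. The toy loss (w - 3)^2 stands in for a transformer's loss tensor.

```python
class ToyParam:
    """Hypothetical stand-in for a model parameter: a value plus an accumulated gradient."""
    def __init__(self, value):
        self.value = value
        self.grad = 0.0


class ToyLoss:
    """Hypothetical stand-in for a scalar torch.Tensor loss (only .item() and .backward())."""
    def __init__(self, value, grad_fn):
        self.value = value
        self._grad_fn = grad_fn

    def item(self):
        return self.value

    def backward(self):
        # Accumulate gradients into the parameters, like Tensor.backward() would.
        self._grad_fn()


class StepLossSketch:
    """Hypothetical re-creation of the step_loss workflow (assumed behavior, not PrimeQA code)."""
    def __init__(self, params, lr, max_steps):
        self.params = params
        self.lr = lr
        self.max_steps = max_steps  # assumption: a fixed step budget decides when to stop
        self.global_step = 0
        self.history = []  # this sketch just records (loss, moving_averages) pairs

    def should_continue(self):
        return self.global_step < self.max_steps

    def backward_on_loss(self, loss, **moving_averages):
        loss.backward()
        loss_val = loss.item()
        self.history.append((loss_val, moving_averages))
        return loss_val

    def optimizer_step(self):
        # Plain gradient descent, then clear the accumulated gradients.
        for p in self.params:
            p.value -= self.lr * p.grad
            p.grad = 0.0
        self.global_step += 1

    def step_loss(self, loss, **moving_averages):
        loss_val = self.backward_on_loss(loss, **moving_averages)
        if self.should_continue():
            self.optimizer_step()
        return loss_val  # the value of the loss as a float


# Usage: minimize (w - 3)^2, calling step_loss after computing each loss.
w = ToyParam(0.0)

def make_loss():
    def grad_fn():
        w.grad += 2.0 * (w.value - 3.0)  # d/dw of (w - 3)^2
    return ToyLoss((w.value - 3.0) ** 2, grad_fn)

opt = StepLossSketch([w], lr=0.1, max_steps=50)
losses = []
while opt.should_continue():
    losses.append(opt.step_loss(make_loss()))
```

In real use the loop would construct TransformerOptimize(hypers, num_instances_to_train_over, model) and pass the torch.Tensor loss from each batch to step_loss; the stand-in classes only mimic that flow so the control structure is easy to follow.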