primeqa.ir.dense.dpr_top.torch_util.transformer_optimize.TransformerOptimize

class primeqa.ir.dense.dpr_top.torch_util.transformer_optimize.TransformerOptimize(hypers: primeqa.ir.dense.dpr_top.torch_util.hypers_base.HypersBase, num_instances_to_train_over: int, model)

Bases: object

Collects the standard steps for training a transformer. Call step_loss after computing each loss.
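The calling pattern described above (check should_continue, compute a loss, hand it to step_loss) can be sketched without PrimeQA or PyTorch installed. The class below is a hypothetical stand-in that only borrows the documented method names; its internals (a step budget of num_instances_to_train_over // batch_size, one optimizer step per loss, floats instead of tensors) are assumptions for illustration, not the behavior of TransformerOptimize.

```python
class StandInOptimize:
    """Hypothetical stand-in using TransformerOptimize's method names.
    The internals are illustrative assumptions, not PrimeQA's logic."""

    def __init__(self, num_instances_to_train_over: int, batch_size: int):
        # Assumption: the step budget is derived from the instance count.
        self.total_steps = num_instances_to_train_over // batch_size
        self.global_step = 0

    def should_continue(self) -> bool:
        return self.global_step < self.total_steps

    def backward_on_loss(self, loss: float, **moving_averages) -> float:
        # A real trainer would call loss.backward() on a torch.Tensor here.
        return float(loss)

    def optimizer_step(self) -> None:
        # A real trainer would step the optimizer and scheduler here.
        self.global_step += 1

    def step_loss(self, loss: float, **moving_averages) -> float:
        loss_val = self.backward_on_loss(loss, **moving_averages)
        self.optimizer_step()
        return loss_val


def train(optimizer: StandInOptimize, batches) -> list:
    """Calls step_loss after computing each loss, stopping when the budget ends."""
    losses = []
    for batch in batches:
        if not optimizer.should_continue():
            break
        loss = sum(batch) / len(batch)  # placeholder for a model's batch loss
        losses.append(optimizer.step_loss(loss))
    return losses


opt = StandInOptimize(num_instances_to_train_over=8, batch_size=2)
history = train(opt, [[1.0, 3.0], [2.0, 2.0], [0.5, 1.5], [4.0, 0.0], [9.0, 9.0]])
```

Here the budget is 8 // 2 = 4 steps, so the fifth batch is never used. Keeping the stop condition in should_continue lets the caller's loop stay unaware of how the budget is computed.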

Methods

backward_on_loss

optimizer_report

optimizer_step

should_continue

step_loss


step_loss(loss: torch.Tensor, **moving_averages)
Parameters
  • loss (torch.Tensor) – the loss just computed for the current batch

  • **moving_averages – additional keyword arguments

Returns

The value of the loss, as a float.
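The **moving_averages keyword arguments suggest extra named scalars tracked alongside the loss. Assuming they are smoothed as exponential moving averages, the bookkeeping might look like the sketch below; the class name LossTracker, the method note, and the decay parameter are all hypothetical and not part of PrimeQA.

```python
from typing import Dict


class LossTracker:
    """Hypothetical tracker: exponential moving averages of the loss and of
    any extra named values passed as keyword arguments."""

    def __init__(self, decay: float = 0.9):
        self.decay = decay
        self.averages: Dict[str, float] = {}

    def note(self, loss: float, **moving_averages: float) -> float:
        values = {"loss": loss, **moving_averages}
        for name, value in values.items():
            if name not in self.averages:
                # The first observation seeds the average.
                self.averages[name] = value
            else:
                self.averages[name] = (
                    self.decay * self.averages[name] + (1 - self.decay) * value
                )
        # Like step_loss, return the raw loss value rather than the average.
        return loss


tracker = LossTracker(decay=0.5)
tracker.note(4.0, accuracy=0.0)
returned = tracker.note(2.0, accuracy=1.0)
```

With decay 0.5, the loss average moves from 4.0 to 3.0 and the accuracy average from 0.0 to 0.5, while the returned value is the raw loss of the latest call.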