torch_utils

Functions

all_gather

All-gathers the tensor with dimensions [d0 x d1 x ...] from every process, returning a tensor with dimensions [d0*world_size x d1 x ...].

:param tensor: this process's tensor
:return: the gathered tensor, with dimensions [d0*world_size x d1 x ...]
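The following is a minimal sketch of the shape contract described above. It assumes the helper wraps `torch.distributed.all_gather` and concatenates the per-process tensors along dimension 0 in rank order; the actual implementation may differ. To keep the example runnable in a single process, it creates a one-rank `gloo` group on an in-memory `HashStore`, which is a demonstration setup only.

```python
import torch
import torch.distributed as dist


def all_gather(tensor: torch.Tensor) -> torch.Tensor:
    """Gather `tensor` from every rank and concatenate along dim 0.

    Input on each rank: [d0 x d1 x ...]; output: [d0*world_size x d1 x ...].
    Assumes every rank passes a tensor of the same shape and dtype.
    """
    world_size = dist.get_world_size()
    # One receive buffer per rank, each shaped like the local tensor.
    buffers = [torch.empty_like(tensor) for _ in range(world_size)]
    # Fills buffers[i] with rank i's tensor.
    dist.all_gather(buffers, tensor)
    return torch.cat(buffers, dim=0)


if not dist.is_initialized():
    # Single-process "gloo" group on an in-memory store, for demonstration only.
    dist.init_process_group(backend="gloo", store=dist.HashStore(), rank=0, world_size=1)

local = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # d0=2, d1=3
result = all_gather(local)
print(result.shape)  # torch.Size([2, 3]) with world_size=1
```

With one rank the output equals the input. On N ranks, each holding a [2 x 3] tensor, the result would be [2N x 3] with rank 0's rows first. Because `torch.distributed.all_gather` expects equally sized tensors, ranks with different d0 would need to pad before calling it.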

reduce

All-reduces the tensor across all processes, modifying it in place.

:param tensor: the tensor to all-reduce
:param op: the reduction operation (for example, torch.distributed.ReduceOp.SUM)
:param check_id: an identifier for this call to all_reduce, used to check that there is no cross-talk between calls
:return:
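Below is a hedged sketch of this helper. The cross-talk check is an assumption and not taken from the source: each rank exchanges its `check_id` with `torch.distributed.all_gather_object` and raises if any differ, and only then calls `torch.distributed.all_reduce`. The default `op`, the optional `check_id`, and returning the tensor are also illustrative choices rather than the documented signature. As in a single-process demo, it sets up a one-rank `gloo` group.

```python
import torch
import torch.distributed as dist


def reduce(tensor: torch.Tensor, op=dist.ReduceOp.SUM, check_id=None) -> torch.Tensor:
    """All-reduce `tensor` in place across every rank in the default group."""
    if check_id is not None:
        # Assumed cross-talk check: every rank must be in the same logical call.
        ids = [None] * dist.get_world_size()
        dist.all_gather_object(ids, check_id)
        if any(other != check_id for other in ids):
            raise RuntimeError(f"mismatched reduce calls across ranks: {ids}")
    dist.all_reduce(tensor, op=op)  # the result is written back into `tensor`
    return tensor  # returning the tensor is an assumption of this sketch


if not dist.is_initialized():
    # Single-process "gloo" group on an in-memory store, for demonstration only.
    dist.init_process_group(backend="gloo", store=dist.HashStore(), rank=0, world_size=1)

t = torch.tensor([1.0, 2.0, 3.0])
out = reduce(t, op=dist.ReduceOp.SUM, check_id="epoch-0/loss")
print(out)  # tensor([1., 2., 3.]) with world_size=1
```

On N ranks that each hold [1., 2., 3.], SUM would leave [N, 2N, 3N] in every rank's tensor. Because every rank sees the full list of ids, a mismatch raises on all ranks rather than leaving some of them blocked in `all_reduce`.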

to_tensor