在pytorch中使用torch.nn.parallel.DistributedDataParallel进行分布式训练时,需要使用torch.distributed.init_process_group()初始化torch.nn.parallel.DistributedDataParallel包。

1 torch.distributed.init_process_group

1. 函数形式

torch.distributed.init_process_group(backend, init_method=None, timeout=datetime.timedelta(seconds=1800), world_size=-1, rank=-1, store=None, group_name='')

2. 函数功能

初始化默认的分布式进程组,同时初始化分布式包。

3. 函数参数

  • backend:类型为str或者Backend,必需参数。所使用的Backend。可选值为mpi,gloo,nccl。我们可以通过小写的字符串如"gloo"指定该字段的值,也可以通过Backend.GLOO设置该字段的值;
  • init_method:类型为str,可选参数。指定如何初始化进程组的url。如果没有指定参数init_method或者store,则使用默认值env://
  • world_size:类型为int,可选参数。参与分布式训练的进程数。如果指定了参数store,则为必需参数。默认值为-1;
  • rank:类型为int,可选参数。当前进程的序号(范围为0到world_size-1)之间。如果指定了参数store,则为必需参数。默认值为-1;
  • store:类型为Store,可选参数。所有的工作进程都可以访问的键/值store,用于交换连接/地址信息;
  • group_name:类型为str,可选参数。进程组名称;
  • timeout:类型为timedelta,可选参数。对进程组指定的操作超时,默认值为30分钟;