TypeError: scatter() missing 1 required positional argument: 'scatter_list'

Hi everyone, I have recently been experimenting with PyTorch's MPI distributed training, but I keep hitting an error about the arguments of the scatter function. The program is as follows:
 
import os
import socket

import torch
import torch.distributed as dist
from multiprocessing import Process

def run(rank, size, hostname):
    print("I am {} of {} in {}".format(rank, size, hostname))
    tensor = torch.zeros(1)
    group = dist.new_group([0, 1, 2])
    if rank == 0:
        scatter_list = [torch.zeros(1) for _ in range(3)]
        dist.scatter(tensor=tensor, src=0, scatter_list=scatter_list, group=group)
        print("Master has completed Scatter")
    else:
        tensor += 1
        dist.scatter(tensor=tensor, src=0, group=group)
        print("worker has completed scatter")
    print('Rank', rank, 'has data', tensor[0])

def init_process(rank, size, hostname, fn, backend='tcp'):
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size, hostname)

if __name__ == "__main__":
    world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    world_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    hostname = socket.gethostname()
    p = Process(target=init_process,
                args=(world_rank, world_size, hostname, run, 'mpi'))
    p.start()
    p.join()
I then launch it with mpirun -np 3 python test.py, and it keeps failing with the following error:
 
  File "mpi_test.py", line 17, in run
dist.scatter(tensor= tensor, src=0, group=group)
TypeError: scatter() missing 1 required positional argument: 'scatter_list'

But PyTorch does not seem to allow passing scatter_list on processes with rank > 0, which is strange. And when I do add a scatter_list argument on those ranks, I get a different error:

non-empty can be given only to scatter source

So now I am confused and don't know what to do. Does anyone understand what is causing this error?
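
For reference, the torch.distributed documentation I am reading describes the signature as scatter(tensor, scatter_list=None, src=0, group=None, async_op=False), where only the source rank supplies scatter_list and the other ranks leave it as None. Below is a minimal sketch of that pattern as I understand it; the helper do_scatter is only for illustration, and this assumes the documented signature, which may not match the version installed on my cluster:

# Minimal sketch of the documented scatter pattern; assumes the signature
# scatter(tensor, scatter_list=None, src=0, group=None, async_op=False).
# do_scatter is a hypothetical helper, not part of my actual program.
import torch
import torch.distributed as dist

def do_scatter(rank, group):
    recv = torch.zeros(1)                                         # every rank receives one tensor
    if rank == 0:
        chunks = [torch.full((1,), float(i)) for i in range(3)]   # one chunk per rank
        dist.scatter(recv, scatter_list=chunks, src=0, group=group)
    else:
        dist.scatter(recv, scatter_list=None, src=0, group=group)  # non-src ranks pass no list
    print('Rank', rank, 'received', recv[0])

In my environment, though, the rank > 0 call without a scatter_list fails with the TypeError shown above, which is exactly what I don't understand.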