pytorch如何异步更新参数?

背景是A3C算法,其中许多工作线程共享一个通用的网络参数,并共享一个通用的rmsprop状态,每个线程持有自己的分级参数。周期性地,每个工作线程都以无锁、异步方式使用通用的rmsprop状态来更新公共参数。
以前在Torch 7中,使用线程和optim库很容易做到这一点:
-- in main thread: shared parameters
params, _ = sharedNet:getParameters()

-- in worker thread: its own gradParameters
tNet = sharedNet:clone()
_, gradParams = tNet:getParameters()

-- in worker thread: stuff

-- in worker thread: updating shared parameters with its own gradParameters
function feval() return nil, gradParams end
optim.rmsprop(feval, params, sharedStates)
但是我没有看到一个很明显的方法来做同样的事情,因为现在参数和等级参数在nn.Parameter下被绑定在一起。有没有什么好的方法?
已邀请:
匿名用户

匿名用户

赞同来自:

已经在pytorch中实现了A3C,而且效果很好。当你得到一个副本时,在子过程中共享的所有内容都是为了打破梯度共享,并且像你正常的那样使用优化器:
for param in model.parameters():
param.grad.data = param.grad.data.clone()
这也包括在notes84中,你可以看一下。

要回复问题请先登录注册