答案 0 :(得分:2)
经过一番谷歌搜索后,我发现了尼古拉斯·伊万诺夫this code。诀窍是通过扩展RNNCell
抽象类来创建自己的单元包装类。
这是我对此的看法:
import tensorflow as tf
class DeviceCellWrapper(tf.nn.rnn_cell.RNNCell):
def __init__(self, cell, device):
self._cell = cell
self._device = device
@property
def state_size(self):
return self._cell.state_size
@property
def output_size(self):
return self._cell.output_size
def __call__(self, inputs, state, scope=None):
with tf.device(self._device):
return self._cell(inputs, state, scope)
然后你可以像所有其他包装一样使用这个包装器:
n_inputs = 5
n_outputs = 100
devices = ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3", "/gpu:4"]
n_steps = 20
X = tf.placeholder(tf.float32, shape=[None, n_steps, n_inputs])
lstm_cells = [DeviceCellWrapper(device, tf.nn.rnn_cell.BasicLSTMCell(
num_units=n_outputs, state_is_tuple=True))
for device in devices]
multi_layer_cell = tf.nn.rnn_cell.MultiRNNCell(lstm_cells, state_is_tuple=True)
outputs, states = tf.nn.dynamic_rnn(multi_layer_cell, X, dtype=tf.float32)
答案 1 :(得分:1)
我们通常会看到两种不同的方法:或者像MiniQuark指出的那样包装BasicLSTMCell,或者使用不同的MultiRNNCell实现。包装BasicLSTMCell可能是您用例的更好选择。