我想在一些时间序列数据上运行GRU单元格,以根据最后一层中的激活对它们进行聚类。我对GRU单元实现做了一个小改动
def __call__(self, inputs, state, scope=None):
"""Gated recurrent unit (GRU) with nunits cells."""
with vs.variable_scope(scope or type(self).__name__): # "GRUCell"
with vs.variable_scope("Gates"): # Reset gate and update gate.
# We start with bias of 1.0 to not reset and not update.
r, u = array_ops.split(1, 2, linear([inputs, state], 2 * self._num_units, True, 1.0))
r, u = sigmoid(r), sigmoid(u)
with vs.variable_scope("Candidate"):
c = tanh(linear([inputs, r * state], self._num_units, True))
new_h = u * state + (1 - u) * c
# store the activations, everything else is the same
self.activations = [r,u,c]
return new_h, new_h
在此之后,我按照以下方式连接激活,然后在调用此GRU单元格的脚本中返回它们
@property
def activations(self):
return self._activations
@activations.setter
def activations(self, activations_array):
print "PRINT THIS"
concactivations = tf.concat(concat_dim=0, values=activations_array, name='concat_activations')
self._activations = tf.reshape(tensor=concactivations, shape=[-1], name='flatten_activations')
我以下列方式调用GRU单元
outputs, state = rnn.rnn(cell=cell, inputs=x, initial_state=initial_state, sequence_length=s)
其中s
是一个批处理长度数组,其中包含输入批处理的每个元素中的时间戳数。
最后我使用
获取fetched = sess.run(fetches=cell.activations, feed_dict=feed_dict)
执行时我收到以下错误
追踪(最近一次通话): 文件“xxx.py”,第162行,in fetched = sess.run(fetches = cell.activations,feed_dict = feed_dict) 文件“/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第315行,在运行中 return self._run(None,fetches,feed_dict) 在_run中输入文件“/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第511行 feed_dict_string) 在_do_run中输入文件“/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第564行 target_list) 在_do_call中输入文件“/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第588行 six.reraise(e_type,e_value,e_traceback) 在_do_call中输入文件“/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第571行 return fn(* args) 在_run_fn中输入文件“/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第555行
返回tf_session.TF_Run(session,feed_dict,fetch_list,target_list) tensorflow.python.pywrap_tensorflow.StatusNotOK:无效参数:为RNN / cond_396 / ClusterableGRUCell / flatten_activations返回的张量:0无效。
有人可以通过传递可变长度序列来了解如何在最后一步从GRU单元格获取激活吗?感谢。
答案 0 :(得分:0)
要从最后一步获取激活,您希望激活成为您所在州的一部分,由tf.rnn返回。