I tried:
weights = {
    'wc1': tf.get_variable('wc1', shape=(8, 8, 4, 32), initializer=tf.contrib.layers.xavier_initializer()),
    'wc2': tf.get_variable('wc2', shape=(4, 4, 32, 64), initializer=tf.contrib.layers.xavier_initializer()),
    'wc3': tf.get_variable('wc3', shape=(3, 3, 64, 64), initializer=tf.contrib.layers.xavier_initializer()),
    'wd1': tf.get_variable('wd1', shape=(7744, 512), initializer=tf.contrib.layers.xavier_initializer()),
    'wd2': tf.get_variable('wd2', shape=(512, action_size), initializer=tf.contrib.layers.xavier_initializer()),
    'bc1': tf.get_variable('bc1', shape=(32,), initializer=tf.contrib.layers.xavier_initializer()),
    'bc2': tf.get_variable('bc2', shape=(64,), initializer=tf.contrib.layers.xavier_initializer()),
    'bc3': tf.get_variable('bc3', shape=(64,), initializer=tf.contrib.layers.xavier_initializer()),
    'bd1': tf.get_variable('bd1', shape=(512,), initializer=tf.contrib.layers.xavier_initializer()),
    'bd2': tf.get_variable('bd2', shape=(action_size,), initializer=tf.contrib.layers.xavier_initializer()),
}
and then:
weight_copies = [tf.identity(weights) for x in range(10)]
but I got the following error:
TypeError: Expected binary or unicode string, got {'wc1':, 'wc2':, 'wc3':, 'wd1':, 'wd2':, 'bc1':, 'bc2':, 'bc3':, 'bd1':, 'bd2':}
How can I make these copies without getting this error?
Answer 0 (score: 1)
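weights is a dictionary, but tf.identity expects a single tensor. Passing the whole dict makes TensorFlow try to convert it to a tensor, which fails with the error above. Iterate over the values instead: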
weight_copies = [tf.identity(v) for v in weights.values()]
If you also want the output to be a dictionary, iterate over the items:
weight_copies = {k: tf.identity(v) for k, v in weights.items()}
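The original attempt used range(10), which suggests the goal was ten copies of the whole dictionary. One way to get that is to nest the dict comprehension inside a list comprehension. The sketch below is a minimal illustration under that assumption. fake_identity and the placeholder string values are hypothetical stand-ins for tf.identity and the real variables, so the snippet runs without TensorFlow.

```python
# Minimal sketch, assuming the goal is a list of ten dicts with the same keys.
# `fake_identity` is a HYPOTHETICAL stand-in for tf.identity so this runs
# without TensorFlow; in real code call tf.identity(v) instead.


def fake_identity(value):
    """Hypothetical stand-in for tf.identity: wraps the value in a new tuple."""
    return ("identity", value)


# Placeholder strings standing in for the tf.Variable objects in `weights`.
weights = {"wc1": "var_wc1", "wc2": "var_wc2", "bd2": "var_bd2"}

num_copies = 10

# The outer list comprehension repeats the copy num_copies times; the inner
# dict comprehension applies identity to every value while keeping its key.
weight_copies = [
    {k: fake_identity(v) for k, v in weights.items()}
    for _ in range(num_copies)
]

print(len(weight_copies))       # 10
print(weight_copies[0]["wc1"])  # ('identity', 'var_wc1')
```

In real TF1 code, replace fake_identity(v) with tf.identity(v). Note that in graph mode tf.identity returns a tensor that reads the variable's current value whenever it is evaluated; it does not create a separate variable, so these copies will track the originals as they change.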