将tf.truncated_normal()列表或字典列表馈入到Tensorflow模型中

时间:2018-07-09 10:25:54

标签: python-3.x tensorflow

我是tensorflow的新手,我正在尝试学习如何有效地使用该工具。

我在下面的问题上进行扩展,但这是tldr:

我想知道哪种最佳方法是使用feed_dict将以下权重和偏差输入到我的模型中:

def generate_initial_population(my_population_size):
    my_weights = []
    my_biases = []
    for _ in range(my_population_size):
        my_weights.append({
            'h1': tf.Variable(tf.truncated_normal([n_inputs, n_hidden_1])),
            'h2': tf.Variable(tf.truncated_normal([n_hidden_1, n_hidden_2])),
            'out': tf.Variable(tf.truncated_normal([n_hidden_2, n_class]))
        })

        my_biases.append({
            'b1': tf.Variable(tf.truncated_normal([n_hidden_1])),
            'b2': tf.Variable(tf.truncated_normal([n_hidden_2])),
            'out': tf.Variable(tf.truncated_normal([n_class]))
        })
    return my_weights, my_biases


weights, biases = generate_initial_population(population_size)

我不能简单地使用feed_dict={weights_ph: weights},因为它会产生错误。我不知道如何有效地解决这个问题

最后检查代码可能有助于理解我在说什么。

我想知道是否有什么办法可以向我的模型提供一个包含tf.truncated_normals的列表。 我收到ValueError: setting an array element with a sequence.错误,因为我认为它正在尝试转换为np.array,但尺寸存在问题

我找到了一种简单的锻炼方法,在该方法中,我首先使用会话运行计算出所有张量的值,然后将其输入到我的模型中。 如果这是正确的解决方案,我会感到困惑,因为我倾向于认为它会比较慢,因为您必须执行两次会话?

但是,如果我的原始列表形状不完美,此解决方案也不起作用 像[ [1, [1,2]]]或我的截尾法线形状不一样

我当时想我只是将我怪异的形状列表输入到模型中,然后使用tf.gather来获取要处理的特定索引。

由于我无法做到这一点,因此这是解决此问题的正确方法……只需先计算出truncated_normals,然后将其输入模型即可。然后根据需要在模型内部重塑列表?

我也有一个非常相似的问题,因为我也想将字典列表也输入到模型中。处理该问题的正确方法是从字典中提取数据,然后仅分别输入每个键的每个值。

我正在尝试学习,却无法在其他地方找到此信息

这是我设计的代码片段,目的是无法解释我的意思

import tensorflow as tf

list_ph = tf.placeholder(dtype=tf.float32)
index_ph = tf.placeholder(dtype=tf.int32)


def model(my_list, index):
    value = tf.gather(my_list, index, axis=0)
    return value



my_model = model(list_ph, index_ph)

with tf.Session() as sess:
    var_list = []

    truncated_normal = tf.Variable(tf.truncated_normal(shape=[5, 3]))

    for i in range(4):
        var_list.append(truncated_normal)
    # for i in range(4):
    #     var_list.append({i: i*2})


    sess.run(tf.global_variables_initializer())
    #will work but will not work for dictionaries
    val = sess.run(var_list)
    # will not work, but will work if you feed val
    var = sess.run(my_model, feed_dict={list_ph: var_list, index_ph: 1}) 

0 个答案:

没有答案