运行带有提取列表的张量流不起作用

时间:2016-07-22 12:58:43

标签: python tensorflow

我正在使用Tensorflow并实现了k均值聚类算法。一切都运行良好,但如果我想在list中使用几个提取来运行会话,我总是会收到错误,list无法转换为Tensor或{ {1}}。

documentation明确表示,我可以使用列表调用Operation。我做错了吗?

以下是源代码:

Session.run()

以下是错误消息:

import tensorflow as tf
import numpy as np

def tf_k_means(k, data, eps_=0.1):
    eps = tf.constant(eps_)

    cluster_means = tf.placeholder(tf.float32, [None, 2])
    tf_data = tf.placeholder(tf.float32, [None, 2], name='data')

    model = tf.initialize_all_variables()

    expanded_data = tf.expand_dims(tf_data, 0)
    expanded_means = tf.expand_dims(cluster_means, 1)
    distances = tf.reduce_sum(tf.square(tf.sub(expanded_means, expanded_data)), 2)
    mins = tf.to_int32(tf.argmin(distances, 0))

    clusters = tf.dynamic_partition(tf_data, mins, k)
    old_cluster_means = tf.identity(cluster_means)
    new_means = tf.concat(0, [tf.expand_dims(tf.reduce_mean(cluster, 0), 0) for cluster in clusters])

    clusters_moved = tf.reduce_sum(tf.square(tf.sub(old_cluster_means, new_means)), 1)
    converged = tf.reduce_all(tf.less(clusters_moved, eps))

    cms = data[np.random.randint(data.shape[0],size=k), :]

    with tf.Session() as sess:
        sess.run(model)
        conv = False
        while not conv:
            #####################################
            # THE FOLLOWING LINE DOES NOT WORK: #
            #####################################
            (cs, cms, conv) = sess.run([clusters, new_means, converged], 
                                        feed_dict={tf_data: data, cluster_means: cms})    

    return cs, cms

1 个答案:

答案 0 :(得分:2)

tf.dynamic_partition会返回list of Tensors,因此clusters本身就是一个列表。

clusters = tf.dynamic_partition(tf_data, mins, k)

当你将该列表提供给另一个列表中的sess.run时,我认为这就是你遇到问题的地方。你可以尝试一下:

sess.run(clusters + [new_means, converged], ...