在TensorFlow KMean类中解压缩的值太多

时间:2018-02-15 19:09:10

标签: python tensorflow

我目前正在使用KMeans模块中的tensorflow.contrib.factorization类。我的输入是(假设所有变量都已定义):

kmeans = KMeans(inputs=X, num_clusters=k, distance_metric='cosine', use_mini_batch=True)

我正在关注https://www.tensorflow.org/api_docs/python/tf/contrib/factorization/KMeans处的文档以解压缩值,如:

(all_scores, cluster_idx, scores, cluster_centers_initialized, init_op, train_op) = kmeans.training_graph()

我收到错误:

----> (all_scores, cluster_idx, scores, cluster_centers_initialized, init_op, train_op) = kmeans.training_graph()    
ValueError: too many values to unpack

我强烈猜测上述链接中的文档未更新,因为kmeans.training_graph()的输出为:

((<tf.Tensor 'sub_14:0' shape=(?, ?) dtype=float32>,),
 (<tf.Tensor 'Squeeze_7:0' shape=<unknown> dtype=int64>,),
 (<tf.Tensor 'Squeeze_6:0' shape=<unknown> dtype=float32>,),
 <tf.Variable 'initialized_3:0' shape=() dtype=bool_ref>,
 <tf.Variable 'clusters_3:0' shape=<unknown> dtype=float32_ref>,
 tf.Tensor 'cond_3/Merge:0' shape=() dtype=bool>,
 <tf.Operation 'group_deps_3' type=NoOp>)

请通过阅读文档告诉我我不了解的额外退货价值。

2 个答案:

答案 0 :(得分:1)

来自history in the repository

  

KMeans.training_graph()现在返回一个当前未使用的附加值。

如果你click on the link它会带你到源头并向你显示额外的退货项目。

return (all_scores, cluster_idx, scores, cluster_centers_initialized,
        init_op, training_op)
        cluster_centers_var, init_op, training_op)

cluster_centers_var是新项目。

答案 1 :(得分:1)

更新:从history of the file clustering_ops.py (master branch)看来,附加值(cluster_centers_vars)已在following commit from 6.Oct中删除,即在引入之后不久。

这意味着您的初始代码应该可以与TF的最新版本完美配合,即

(all_scores, cluster_idx, scores, cluster_centers_initialized, init_op, train_op) = kmeans.training_graph()

现在应该可以了。

但是,结果是您无法通过kmeans.training_graph()函数获得群集中心。

如果要获取群集中心,有两种解决方案。

第一个解决方案很简单,即使用KMeansClustering估计器,该估计器在文件kmeans.py中定义。更具体地说,您可以使用方法KMeansClustering.cluster_centers()

第二种解决方案是一种解决方法。如果您不使用KMeansClustering估计器,而仅使用文件clustering_ops.py中定义的KMeans图构造函数,那么仍然可以通过读取全局TF变量´clusters获得聚类中心: 0´:

tf_vble_cluster_centers = tf.global_variables('clusters:0')[0] # get the global TF variable 'clusters:0'
cluster_centers = sess.run(tf_vble_cluster_centers) # evaluate its contents
print(cluster_centers.shape) # nr. rows = nr. of clusters, nr. columns = nr. dimensions
print(cluster_centers[0]) # print cluster centers for the first cluster