我目前正在使用KMeans
模块中的tensorflow.contrib.factorization
类。我的输入是(假设所有变量都已定义):
kmeans = KMeans(inputs=X, num_clusters=k, distance_metric='cosine', use_mini_batch=True)
我正在关注https://www.tensorflow.org/api_docs/python/tf/contrib/factorization/KMeans处的文档以解压缩值,如:
(all_scores, cluster_idx, scores, cluster_centers_initialized, init_op, train_op) = kmeans.training_graph()
我收到错误:
----> (all_scores, cluster_idx, scores, cluster_centers_initialized, init_op, train_op) = kmeans.training_graph()
ValueError: too many values to unpack
我强烈猜测上述链接中的文档未更新,因为kmeans.training_graph()
的输出为:
((<tf.Tensor 'sub_14:0' shape=(?, ?) dtype=float32>,),
(<tf.Tensor 'Squeeze_7:0' shape=<unknown> dtype=int64>,),
(<tf.Tensor 'Squeeze_6:0' shape=<unknown> dtype=float32>,),
<tf.Variable 'initialized_3:0' shape=() dtype=bool_ref>,
<tf.Variable 'clusters_3:0' shape=<unknown> dtype=float32_ref>,
tf.Tensor 'cond_3/Merge:0' shape=() dtype=bool>,
<tf.Operation 'group_deps_3' type=NoOp>)
请通过阅读文档告诉我我不了解的额外退货价值。
答案 0 :(得分:1)
KMeans.training_graph()现在返回一个当前未使用的附加值。
如果你click on the link它会带你到源头并向你显示额外的退货项目。
return (all_scores, cluster_idx, scores, cluster_centers_initialized,
init_op, training_op)
cluster_centers_var, init_op, training_op)
cluster_centers_var
是新项目。
答案 1 :(得分:1)
更新:从history of the file clustering_ops.py (master branch)看来,附加值(cluster_centers_vars
)已在following commit from 6.Oct中删除,即在引入之后不久。
这意味着您的初始代码应该可以与TF的最新版本完美配合,即
(all_scores, cluster_idx, scores, cluster_centers_initialized, init_op, train_op) = kmeans.training_graph()
现在应该可以了。
但是,结果是您无法通过kmeans.training_graph()
函数获得群集中心。
如果要获取群集中心,有两种解决方案。
第一个解决方案很简单,即使用KMeansClustering
估计器,该估计器在文件kmeans.py中定义。更具体地说,您可以使用方法KMeansClustering.cluster_centers()
。
第二种解决方案是一种解决方法。如果您不使用KMeansClustering
估计器,而仅使用文件clustering_ops.py中定义的KMeans
图构造函数,那么仍然可以通过读取全局TF变量´clusters获得聚类中心: 0´:
tf_vble_cluster_centers = tf.global_variables('clusters:0')[0] # get the global TF variable 'clusters:0'
cluster_centers = sess.run(tf_vble_cluster_centers) # evaluate its contents
print(cluster_centers.shape) # nr. rows = nr. of clusters, nr. columns = nr. dimensions
print(cluster_centers[0]) # print cluster centers for the first cluster