我正在使用带有conda和python 3的TF版本1.12。 我的问题与tf.contrib.factorization.KMeansClustering的model_dir值有关:如何对model_dir值使用字符串占位符?
这里是上下文:我已经在不同情况下对KMeans进行了预训练,检查点在不同的model_dir中。
我想根据每种情况在图形中使用这些经过预训练的模型的预测,因此我需要在该图形中放置KMeansClustering,它可以接受不同的model_dirs。
在我定义的图中:
ckpt_ph = tf.placeholder(tf.string)
...
kmeans = KMeansClustering(5, model_dir=ckpt_ph,distance_metric='cosine')
def input_fn():
return tf.train.limit_epochs(tf.convert_to_tensor(x, dtype=tf.float32), num_epochs=1)
centers_idx = list(kmeans.predict(input_fn,predict_keys='cluster_index',checkpoint_path=ckpt_ph,yield_single_examples=False))[0]['cluster_index']
centers_val = kmeans.cluster_centers()
...
然后运行:
...
for ind in range(nb_cases):
...
sess.run([...], feed_dict={..., ckpt_ph: km_ckpt[ind]})
...
km_ckpt是我要在每种情况下使用的预先训练的KMeansClustering检查点路径的列表。
我得到的错误是:
Traceback (most recent call last):
File "main.py", line 28, in <module>
tf.app.run()
File "C:\Users\Denis\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\platform\app.py", line 125, in run
_sys.exit(main(argv))
File "main.py", line 23, in main
launch_training()
File "main.py", line 14, in launch_training
train_mnist.train_model()
File "C:\Users\Denis\ML\ScatteringReconstruction\src\model\train_mnist.py", line 355, in train_model
X_r = SR(X_tensor)
File "C:\Users\Denis\ML\ScatteringReconstruction\src\model\train_mnist.py", line 316, in __call__
kmeans = KMeansClustering(FLAGS.km_k, model_dir=ckpt_ph, distance_metric='cosine')
File "C:\Users\Denis\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\factorization\python\ops\kmeans.py", line 423, in __init__
config=config)
File "C:\Users\Denis\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 189, in __init__
model_dir)
File "C:\Users\Denis\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 1665, in maybe_overwrite_model_dir_and_session_config
if model_dir:
File "C:\Users\Denis\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 671, in __bool__
raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.
我认为问题在于,在KMeansClustering和KMeansClustering.predict中,model_dir期望使用Python bool或字符串,并且给了他一个Tensor,但随后我看不到在图形内使用预先训练的KMeansClustering的方法。 。 预先感谢您的帮助!