如何在tf.contrib.factorization.KMeansClustering中为model_dir使用字符串占位符?

时间:2019-02-05 17:46:59

标签: tensorflow k-means

我正在使用带有conda和python 3的TF版本1.12。 我的问题与tf.contrib.factorization.KMeansClustering的model_dir值有关:如何对model_dir值使用字符串占位符?

这里是上下文:我已经在不同情况下对KMeans进行了预训练,检查点在不同的model_dir中。

我想根据每种情况在图形中使用这些经过预训练的模型的预测,因此我需要在该图形中放置KMeansClustering,它可以接受不同的model_dirs。

在我定义的图中:

ckpt_ph = tf.placeholder(tf.string)
...
kmeans = KMeansClustering(5, model_dir=ckpt_ph,distance_metric='cosine')
def input_fn():
    return tf.train.limit_epochs(tf.convert_to_tensor(x, dtype=tf.float32), num_epochs=1)
centers_idx = list(kmeans.predict(input_fn,predict_keys='cluster_index',checkpoint_path=ckpt_ph,yield_single_examples=False))[0]['cluster_index']
centers_val = kmeans.cluster_centers()
...

然后运行:

...
for ind in range(nb_cases):
    ...
    sess.run([...], feed_dict={..., ckpt_ph: km_ckpt[ind]})
...

km_ckpt是我要在每种情况下使用的预先训练的KMeansClustering检查点路径的列表。

我得到的错误是:

Traceback (most recent call last):
  File "main.py", line 28, in <module>
    tf.app.run()
  File "C:\Users\Denis\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\platform\app.py", line 125, in run
    _sys.exit(main(argv))
  File "main.py", line 23, in main
    launch_training()
  File "main.py", line 14, in launch_training
    train_mnist.train_model()
  File "C:\Users\Denis\ML\ScatteringReconstruction\src\model\train_mnist.py", line 355, in train_model
    X_r = SR(X_tensor)
  File "C:\Users\Denis\ML\ScatteringReconstruction\src\model\train_mnist.py", line 316, in __call__
    kmeans = KMeansClustering(FLAGS.km_k, model_dir=ckpt_ph, distance_metric='cosine')
  File "C:\Users\Denis\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\factorization\python\ops\kmeans.py", line 423, in __init__
    config=config)
  File "C:\Users\Denis\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 189, in __init__
    model_dir)
  File "C:\Users\Denis\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 1665, in maybe_overwrite_model_dir_and_session_config
    if model_dir:
  File "C:\Users\Denis\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 671, in __bool__
    raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.

我认为问题在于,在KMeansClustering和KMeansClustering.predict中,model_dir期望使用Python bool或字符串,并且给了他一个Tensor,但随后我看不到在图形内使用预先训练的KMeansClustering的方法。 。 预先感谢您的帮助!

0 个答案:

没有答案