如何在TFlearn中使用两个网?

时间:2016-10-06 04:14:21

标签: python machine-learning tensorflow

我正在努力训练DQN玩Tic-Tac-Toe。我训练它玩X(而O动作是随机的)。经过12小时的训练,它可以播放,但不是完美的。现在我想同时训练两个网 - 一个用于X移动,一个用于O移动。 但是当我尝试在第二个网络上执行model.predict(state)时,我得到的错误如下:

ValueError: Cannot feed value of shape (9,) for Tensor 'InputData/X:0', which has shape '(?, 9)'

但我知道网络定义和数据维度是相同的。定义两个DNN有一些东西。

以下是一个通用示例:

import tflearn
import random

X = [[random.random(),random.random()] for x in range(1000)]
#reverse values order like [1,0] -> [0,1]
Y = [[x[1],x[0]] for x in X]

n = tflearn.input_data(shape=[None,2])
n = tflearn.fully_connected(n, 2)
n = tflearn.regression(n)
m = tflearn.DNN(n)

m.fit(X, Y, n_epoch = 20)
#should print like [0.1,0.9]
print(m.predict([[0.9,0.1]]))

n2 = tflearn.input_data(shape=[None,2])
n2 = tflearn.fully_connected(n2, 2)
n2 = tflearn.regression(n2)
m2 = tflearn.DNN(n2)

# set second element value to first e.g. [1,0] -> [1,1]
Y = [[x[0],x[0]] for x in X]

m2.fit(X, Y, n_epoch = 20)
#should print like [0.9,0.9]
print(m2.predict([[0.9,0.1]]))

错误将如下:

Traceback (most recent call last):
  File "2_dnn_test.py", line 25, in <module>
    m2.fit(X, Y, n_epoch = 20)
  File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tflearn/models/dnn.py", line 157, in fit
    self.targets)
  File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tflearn/utils.py", line 267, in feed_dict_builder
    feed_dict[net_inputs[i]] = x
IndexError: list index out of range

错误是不同的,因为在我的tic-tac-toe中,我比第一次调整()更早地调用第二DNN上的预测。如果我在我的示例中注释m2.fit(X, Y, n_epoch = 20),我会得到同样的错误:

Traceback (most recent call last):
  File "2_dnn_test.py", line 27, in <module>
    print(m2.predict([[0.9,0.1]]))
  File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tflearn/models/dnn.py", line 204, in predict
    return self.predictor.predict(feed_dict)
  File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tflearn/helpers/evaluator.py", line 69, in predict
    o_pred = self.session.run(output, feed_dict=feed_dict).tolist()
  File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 372, in run
    run_metadata_ptr)
  File "/home/cpro/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 625, in _run
    % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (2,) for Tensor 'InputData/X:0', which has shape '(?, 2)'

因此两个相同的网络不能同时工作。如何使它们都有效?

BTW示例未获得预期的预测结果:)

1 个答案:

答案 0 :(得分:0)

看起来我应该添加

with tf.Graph().as_default():
    #define model here

防止TFLearn将两个模型附加到默认图形。一切都有用。