使用函数生成TF学习DNN对象

时间:2017-02-22 02:20:37

标签: python tensorflow tflearn

我正在使用Jupyter笔记本试图为一些暴力试验和错误测试生成一长串TF Learn DNN对象(我知道这不是最有效的方法,只是试图展示一个例子) 。数据遵循泰坦尼克号快速入门教程。

我有一个函数,给定一堆参数,应该返回一个tflearn.DNN()对象:

def make_fully_connected(input_shape, output_shape, activation, layers, nodes, dropout, optimizer, loss):
    tflearn.init_graph()
    net = tflearn.input_data(shape=[None, input_shape])
    for l in range(layers):
        net = tflearn.fully_connected(net, nodes)
        if (dropout != 0) and (l%2==1):
            net = tflearn.dropout(net, dropout)
    net = tflearn.fully_connected(net, output_shape, activation=activation)
    net = tflearn.regression(net, optimizer=optimizer, loss=loss)
    return tflearn.DNN(net)

然后我使用该函数生成特定模型:

model = make_fully_connected(6, 2, 'softmax', 2, 32, 0, 'adam', 'categorical_crossentropy')
model.fit(data, labels, n_epoch=10, batch_size=16, show_metric=True)
score = model.evaluate(data, labels)

但我收到一条可爱的错误消息,将我带入TF学习代码,我很快就迷路了:

IndexError                                Traceback (most recent call last)
<ipython-input-15-79e1d2acc8bf> in <module>()
      1 model = make_fully_connected(6, 2, 'softmax', 2, 32, 0, 'adam', 'categorical_crossentropy')
----> 2 model.fit(data, labels, n_epoch=10, batch_size=16, show_metric=True)
      3 score = model.evaluate(data, labels)
      4 print('| Score: %.4f' % score, end='')

/usr/local/lib/python3.5/dist-packages/tflearn/models/dnn.py in fit(self, X_inputs, Y_targets, n_epoch, validation_set, show_metric, batch_size, shuffle, snapshot_epoch, snapshot_step, excl_trainops, validation_batch_size, run_id, callbacks)
    181         # TODO: check memory impact for large data and multiple optimizers
    182         feed_dict = feed_dict_builder(X_inputs, Y_targets, self.inputs,
--> 183                                       self.targets)
    184         feed_dicts = [feed_dict for i in self.train_ops]
    185         val_feed_dicts = None

/usr/local/lib/python3.5/dist-packages/tflearn/utils.py in feed_dict_builder(X, Y, net_inputs, net_targets)
    287                 X = [X]
    288             for i, x in enumerate(X):
--> 289                 feed_dict[net_inputs[i]] = x
    290         else:
    291             # If a dict is provided

IndexError: list index out of range

回归模型的功能超出了TF Learn的范围吗?或者还有其他障碍吗?

1 个答案:

答案 0 :(得分:0)

尝试重新启动内核并清除输出并再次运行。我也遇到了这个问题,这个解决方案对我有用。这是因为你多次运行模型并且它已经崩溃了。