I am trying to train a NN using the code from https://github.com/RobRomijnders/LSTM_tsc.
After making a lot of changes to fit my data, I get this error:

TypeError: Fetch argument None has invalid type

I searched for an answer, as I have for previous errors, and found this: Tensorflow TypeError on session.run arguments/output, but I don't think my code has that problem (at least I can't find it):
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='3'
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)
import matplotlib.pyplot as plt
import sys
sys.path.append('E:\\Descargas\\recommendersystems\\train')
from tsc_model import Model,sample_batch,load_data,check_test
config = {} #Put all configuration information into the dict
config['num_layers'] = 3 #number of layers of stacked RNN's
config['hidden_size'] = 120 #memory cells in a layer
config['max_grad_norm'] = 5 #maximum gradient norm during training
config['batch_size'] = batch_size = 30
config['learning_rate'] = .005
config['num_classes'] = 2
max_iterations = 3000
dropout = 0.8
ratio = np.array([0.8,0.9]) #Ratios where to split the training and validation set
direc = 'E:\\Descargas\\recommendersystems\\ets_challenge\\train\\train3.csv'
X_train,X_val,X_test,y_train,y_val,y_test = load_data(direc,ratio)
N,sl = X_train.shape
config['sl'] = sl = X_train.shape[1]
config['num_classes'] = num_classes = len(np.unique(y_train))
# Collect the costs in a numpy fashion
epochs = np.floor(batch_size*max_iterations / N)
print('Train %.0f samples in approximately %d epochs' %(N,epochs))
perf_collect = np.zeros((4,int(np.floor(max_iterations /100))))
#Instantiate a model
model = Model(config)
sess = tf.Session() #Depending on your use, do not forget to close the session
writer = tf.summary.FileWriter("/home/rob/Dropbox/ml_projects/LSTM/log_tb", sess.graph) #writer for Tensorboard
sess.run(model.init_op)
step = 0
cost_train_ma = -np.log(1/float(num_classes)+1e-9) #Moving average training cost
acc_train_ma = 0.0
try:
    for i in range(max_iterations):
        X_batch, y_batch = sample_batch(X_train,y_train,batch_size)
        #Next line does the actual training
        cost_train, acc_train,_ = sess.run([model.cost,model.accuracy, model.train_op],feed_dict = {model.input: X_batch,model.labels: y_batch,model.keep_prob:dropout})
        cost_train_ma = cost_train_ma*0.99 + cost_train*0.01
        acc_train_ma = acc_train_ma*0.99 + acc_train*0.01
        if i%100 == 0:
            #Evaluate training performance
            perf_collect[0,step] = cost_train
            perf_collect[1,step] = acc_train
            #Evaluate validation performance
            X_batch, y_batch = sample_batch(X_val,y_val,batch_size)
            cost_val, summ,acc_val = sess.run([model.cost,model.merged,model.accuracy],feed_dict = {model.input: X_batch, model.labels: y_batch, model.keep_prob:1.0})
            perf_collect[1,step] = cost_val
            perf_collect[2,step] = acc_val
            print('At %5.0f/%5.0f: COST %5.3f/%5.3f(%5.3f) -- Acc %5.3f/%5.3f(%5.3f)' %(i,max_iterations,cost_train,cost_val,cost_train_ma,acc_train,acc_val,acc_train_ma))
            #Write information to TensorBoard
            writer.add_summary(summ, i)
            writer.flush()
            step += 1
except KeyboardInterrupt:
    #Pressing ctrl-c will end training. This try-except ensures we still plot the performance
    pass
acc_test,cost_test = check_test(model,sess,X_test,y_test)
epoch = float(i)*batch_size/N
print('After training %.1f epochs, test accuracy is %5.3f and test cost is %5.3f'%(epoch,acc_test,cost_test))
plt.plot(perf_collect[0],label='Train')
plt.plot(perf_collect[1],label = 'Valid')
plt.plot(perf_collect[2],label = 'Valid accuracy')
plt.axis([0, step, 0, np.max(perf_collect)])
plt.legend()
plt.show()
Here is the error trace:
TypeError Traceback (most recent call last)
<ipython-input-1-200bd46aa85d> in <module>()
64 #Evaluate validation performance
65 X_batch, y_batch = sample_batch(X_val,y_val,batch_size)
---> 66 cost_val, summ,acc_val = sess.run([model.cost,model.merged,model.accuracy],feed_dict = {model.input: X_batch, model.labels: y_batch, model.keep_prob:1.0})
67 perf_collect[1,step] = cost_val
68 perf_collect[2,step] = acc_val
\lib\site-packages\tensorflow\python\client\session.py in run(self, fetches, feed_dict, options, run_metadata)
765 try:
766 result = self._run(None, fetches, feed_dict, options_ptr,
--> 767 run_metadata_ptr)
768 if run_metadata:
769 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
\lib\site-packages\tensorflow\python\client\session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
950
951 # Create a fetch handler to take care of the structure of fetches.
--> 952 fetch_handler = _FetchHandler(self._graph, fetches, feed_dict_string)
953
954 # Run request and get response.
\lib\site-packages\tensorflow\python\client\session.py in __init__(self, graph, fetches, feeds)
406 """
407 with graph.as_default():
--> 408 self._fetch_mapper = _FetchMapper.for_fetch(fetches)
409 self._fetches = []
410 self._targets = []
\lib\site-packages\tensorflow\python\client\session.py in for_fetch(fetch)
228 elif isinstance(fetch, (list, tuple)):
229 # NOTE(touts): This is also the code path for namedtuples.
--> 230 return _ListFetchMapper(fetch)
231 elif isinstance(fetch, dict):
232 return _DictFetchMapper(fetch)
\lib\site-packages\tensorflow\python\client\session.py in __init__(self, fetches)
335 """
336 self._fetch_type = type(fetches)
--> 337 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
338 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
339
\lib\site-packages\tensorflow\python\client\session.py in <listcomp>(.0)
335 """
336 self._fetch_type = type(fetches)
--> 337 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
338 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
339
\lib\site-packages\tensorflow\python\client\session.py in for_fetch(fetch)
225 if fetch is None:
226 raise TypeError('Fetch argument %r has invalid type %r' %
--> 227 (fetch, type(fetch)))
228 elif isinstance(fetch, (list, tuple)):
229 # NOTE(touts): This is also the code path for namedtuples.
TypeError: Fetch argument None has invalid type <class 'NoneType'>
Any help is appreciated!
Answer 0 (score: 0)
From the error message, it appears that one of model.cost, model.merged, or model.accuracy is None. Looking at the original source, all of these attributes of the Model object seem to be assigned in Model.__init__(): self.cost is assigned here, self.merged is assigned here, and self.accuracy is assigned here.
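For reference, sess.run() raises this exact error as soon as any entry in the fetch list is None. A minimal sketch (TF 1.x Session API, with a made-up variable standing in for the unassigned attribute):

import tensorflow as tf

a = tf.constant(1.0)
maybe_none = None  # stands in for e.g. model.merged if it was never assigned in Model.__init__()

with tf.Session() as sess:
    # Raises: TypeError: Fetch argument None has invalid type <class 'NoneType'>
    sess.run([a, maybe_none])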
One of the changes you made to the code (to fit your data) has most likely altered the control flow in that constructor so that one of those assignments no longer happens. I would start by working out which attribute is None, and then trace back through Model.__init__() to find out why it is not being set.
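As a quick way to narrow it down, something like the following hypothetical check, placed just before the failing sess.run call, will tell you which attribute is missing:

# Hypothetical diagnostic: report any of the expected attributes that are still None
for name in ('cost', 'merged', 'accuracy'):
    if getattr(model, name, None) is None:
        print('model.%s is None -- trace its assignment in Model.__init__()' % name)

In this particular trace, model.merged is the most likely culprit: the training call at the same iteration already fetched model.cost and model.accuracy successfully, and the failing validation call is the only one that fetches model.merged.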