DNNRegressor培训出错

时间:2017-10-27 10:54:01

标签: python machine-learning tensorflow regression pipeline

我在估算员DNNRegressor的帮助下构建回归模型。 以下是代码

import tensorflow as tf

DATA_PATH = 'train_data/train_1.csv'
BATCH_SIZE = 5
N_FEATURES = 3963

def batch_generator(filenames):
    """ filenames is the list of files you want to read from. 
    In this case, it contains only heart.csv
    """
    filename_queue = tf.train.string_input_producer(filenames)
    reader = tf.TextLineReader(skip_header_lines=1) # skip the first line in the file
    _, value = reader.read(filename_queue)
    record_defaults = [[1.0] for _ in range(N_FEATURES)]

    # read in the rows of data
    content = tf.decode_csv(value, record_defaults=record_defaults) 

    # pack all features into a tensor
    features = tf.stack(content[:N_FEATURES])

    # assign the last column to label
    label = content[1]

    # minimum number elements in the queue after a dequeue, used to ensure 
    # that the samples are sufficiently mixed
    # I think 10 times the BATCH_SIZE is sufficient
    min_after_dequeue = 10 * BATCH_SIZE

    # the maximum number of elements in the queue
    capacity = 20 * BATCH_SIZE

    # shuffle the data to generate BATCH_SIZE sample pairs
    data_batch, label_batch = tf.train.shuffle_batch([features, label], batch_size=BATCH_SIZE, 
                                        capacity=capacity, min_after_dequeue=min_after_dequeue)

    return data_batch, label_batch

def generate_batches():
    regressor = tf.estimator.DNNRegressor(feature_columns=feature_cols,hidden_units=[10,10,10],model_dir='alg_model4')
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for _ in range(4): # generate 10 batches
            regressor.train(input_fn=sess.run(input_fn()),steps=2)
        coord.request_stop()
        coord.join(threads)

def main():
    generate_batches()


if __name__ == '__main__':
    main()

以下是流程: -

  • 首先,我从包含多个文件的目录中读取数据 前缀为" train _"。
  • 模式如火车_ * .csv
  • 总共包含3963列。
  • 第二列是因变量All都是整数类型
  • 我需要批量读取此数据集并将其输入DNNRegressor以训练模型
  

问题是它抛出以下输出并出现错误: -   信息:tensorflow:使用默认配置。信息:tensorflow:使用config:   {' _log_step_count_steps':100,' _keep_checkpoint_max':5,   ' _save_checkpoints_secs':600,' _tf_random_seed':1,   ' _save_summary_steps':100,' _model_dir':' alg_model4',   ' _save_checkpoints_steps':无,' _session_config':无,   ' _keep_checkpoint_every_n_hours':10000}   (TensorShape([Dimension(None),Dimension(3963)]),   TensorShape([Dimension(None)]))INFO:tensorflow:报告错误   协调员:,Dequeue   操作被取消[[Node:ReaderReadV2_7 =   ReaderReadV2 [_device =" /作业:本地主机/复制:0 /任务:0 / CPU:0"](TextLineReaderV2_7,   input_producer_7)]]   -------------------------------------------------- ------------------------- TypeError Traceback(最近一次调用   last)/ usr / lib / python3.5 /inspect.py in getfullargspec(func)1088
  skip_bound_arg =假    - > 1089 sigcls =签名)1090除了例外情况,例如:

     _signature_from_callable中的

/usr/lib/python3.5/inspect.py(obj,   follow_wrapper_chains,skip_bound_arg,sigcls)2155如果没有   可调用(OBJ):    - > 2156引发TypeError(' {!r}不是可调用对象' .format(obj))2157

     

TypeError :( array([[0.,1。,0.,...,0.,0.,0。],          [0.,1.,0.,...,0.,0.,0。],          [0.,1.,0.,...,0.,0.,0。],          [0.,1.,0.,...,0.,0.,0。],          [0.,1.,0.,...,1.,0.,0。]],dtype = float32),array([4261,2203,4120,4049,1414]))不是可调用对象

     

上述异常是导致以下异常的直接原因:

     

TypeError Traceback(最近一次调用   最后)in()         4         5如果名称 ==' 主要':   ----> 6 main()

     

在main()中         1 def main():   ----> 2 generate_batches()         3         4         5如果名称 ==' 主要':

     

在generate_batches()中         5个主题= tf.train.start_queue_runners(coord = coord)         6为_在范围内(4):#生成10批次   ----> 7 regressor.train(input_fn = sess.run(input_fn()),steps = 2)         8 coord.request_stop()         9 coord.join(线程)

     

/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py   在火车上(self,input_fn,hooks,steps,max_steps)       239 hooks.append(training.StopAtStepHook(steps,max_steps))       240    - > 241 loss = self._train_model(input_fn = input_fn,hooks = hooks)       242 logging.info('最后一步的损失:%s。',损失)       243回归自我

     

/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py   在_train_model中(self,input_fn,hooks)       626 global_step_tensor = self._create_and_assert_global_step(g)       627个功能,标签= self._get_features_and_labels_from_input_fn(    - > 628 input_fn,model_fn_lib.ModeKeys.TRAIN)       629 estimator_spec = self._call_model_fn(功能,标签,       630 model_fn_lib.ModeKeys.TRAIN)

     

/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py   in _get_features_and_labels_from_input_fn(self,input_fn,mode)       497       498 def _get_features_and_labels_from_input_fn(self,input_fn,mode):    - > 499 result = self._call_input_fn(input_fn,mode)       如果是实例500(结果,(列表,元组)):       501如果len(结果)!= 2:

     

/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py   在_call_input_fn中( 解析参数失败 )       576"""       577 del mode#unused    - > 578 input_fn_args = util.fn_args(input_fn)       579 kwargs = {}       580 if' params'在input_fn_args中:

     

/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/util.py   在fn_args(fn)中        55        56#处理功能。   ---> 57返回元组(tf_inspect.getargspec(fn).args)

     

/usr/local/lib/python3.5/dist-packages/tensorflow/python/util/tf_inspect.py   在getargspec(对象)        43个装饰器,target = tf_decorator.unwrap(object)        44返回下一个((d.decorator_argspec for d in decorators   ---> 45如果d.decorator_argspec不是None),_ inspect.getargspec(target))        46        47

     getpgspec(func)中的

/usr/lib/python3.5/inspect.py 1041   stacklevel = 2)1042 args,varargs,varkw,defaults,kwonlyargs,   kwonlydefaults,ann = \    - > 1043 getfullargspec(func)1044如果kwonlyargs或ann:1045引发ValueError(" Function只有关键字参数或者   注释"

     

/usr/lib/python3.5/inspect.py in getfullargspec(func)1093

     

别的。因此,为了完全向后兼容,我们在这里捕获所有1094个可能的异常,并重新引发TypeError。

     

- > 1095从ex 1096 1097引发TypeError('不支持的可调用')args = []

     

TypeError:不支持的可调用

1 个答案:

答案 0 :(得分:0)

不确定input_fn的来源,但sess.run(input_fn())看起来不正确。 Estimator将评估input_fn