使用tensorflow StagingArea和数据集api

时间:2017-12-07 09:38:45

标签: tensorflow tensorflow-datasets

尝试将tensorflow临时区域与数据集api结合使用。

compute_stage_put_op = compute_stage.put(iterator.get_next())
if compute_stage_put_op.type == 'Stage':
   compute_stage_ops.append(compute_stage_put_op)

完成几个步骤后得到以下错误。

ValueError: Fetch argument <tf.Operation 'group_deps' type=NoOp> 
            cannot be interpreted as a Tensor. (Operation name: 
            "group_deps" op: "NoOp")

堆栈追踪:

Traceback (most recent call last):

文件&#34; timit_trainer.py&#34;,第5行,in     timit_trainer.train()   火车上的文件&#34; /mnt/sdc/nlp/workspace/hci/nlp/mapc/core/model/model.py" ;,第43行     hparams = self.hyper_params#HParams   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/learn_runner.py" ;,第218行,在运行中     return _execute_schedule(实验,日程安排)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/learn_runner.py",第46行,在_execute_schedule中     返回任务()   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/experiment.py" ;,第625行,在train_and_evaluate中     self.train(delay_secs = 0)   火车上的文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/experiment.py" ;,第367行     hooks = self._train_monitors + extra_hooks)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/experiment.py" ;,第807行,在_call_train中     钩=钩)   火车上的文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py" ;, 302行     loss = self._train_model(input_fn,hooks,saving_listeners)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py",第783行,在_train_model中     _,loss = mon_sess.run([estimator_spec.train_op,estimator_spec.loss])   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/training/monitored_session.py",第521行,在运行中     run_metadata = run_metadata)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/training/monitored_session.py" ;,第892行,在运行中     run_metadata = run_metadata)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/training/monitored_session.py" ;,第967行,在运行中     提高six.reraise(* original_exc_info)   文件&#34; /usr/local/lib/python3.5/dist-packages/six.py" ;,第693行,重新加入     提高价值   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/training/monitored_session.py",第952行,在运行中     return self._sess.run(* args,** kwargs)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/training/monitored_session.py" ;,第1032行,在运行中     run_metadata = run_metadata))   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/monitors.py",第1196行,在after_run中     indu_stop = m.step_end(self._last_step,result)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/monitors.py" ;,第356行,in step_end     return self.every_n_step_end(step,output)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/monitors.py" ;,行694,in every_n_step_end     validation_outputs = self._evaluate_estimator()   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/monitors.py",第665行,在_evaluate_estimator中     名称= self.name)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py",第355行,在评估中     名称=名称)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py",第839行,在_evaluate_model中     配置= self._session_config)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/training/evaluation.py",第206行,在_evaluate_once中     session.run(eval_ops,feed_dict)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/training/monitored_session.py",第521行,在运行中     run_metadata = run_metadata)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/training/monitored_session.py" ;,第892行,在运行中     run_metadata = run_metadata)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/training/monitored_session.py" ;,第967行,在运行中     提高six.reraise(* original_exc_info)   文件&#34; /usr/local/lib/python3.5/dist-packages/six.py" ;,第693行,重新加入     提高价值   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/training/monitored_session.py",第952行,在运行中     return self._sess.run(* args,** kwargs)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/training/monitored_session.py" ;,第1024行,在运行中     run_metadata = run_metadata)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/training/monitored_session.py" ;,第827行,在运行中     return self._sess.run(* args,** kwargs)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py" ;,第889行,在运行中     run_metadata_ptr)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py" ;,第1105行,在_run中     self._graph,fetches,feed_dict_tensor,feed_handles = feed_handles)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py" ;,第414行, init     self._fetch_mapper = _FetchMapper.for_fetch(fetches)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py" ;,第236行,in_fetch     return _DictFetchMapper(fetch)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py",第374行, init     用于获取fetches.values()]   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py" ;,第374行,在     用于获取fetches.values()]   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py" ;,第234行,in_fetch     return _ListFetchMapper(fetch)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py",第341行, init     self._mappers = [_FetchMapper.for_fetch(fetch)for fetches in fetches]   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py" ;,第341行,在     self._mappers = [_FetchMapper.for_fetch(fetch)for fetches in fetches]   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py" ;,第234行,in_fetch     return _ListFetchMapper(fetch)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py",第341行, init     self._mappers = [_FetchMapper.for_fetch(fetch)for fetches in fetches]   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py" ;,第341行,在     self._mappers = [_FetchMapper.for_fetch(fetch)for fetches in fetches]   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py" ;,第242行,in_fetch     return _ElementFetchMapper(fetches,contraction_fn)   文件&#34; /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py",第278行, init     &#39;张量。 (%S)&#39; %(fetch,str(e))) ValueError:Fetch参数不能解释为Tensor。 (操作名称:&#34; StagingArea_put&#34;

代码:

def读取(self,category:DatasetCategory,devices:list,proc_device:str,shuffle = False):

    batch_size = ds.BATCH_SIZE

    record_store_exists, record_store = self.__get_store_info(store_path=fu.join_path(self.store_dir, self.store_name + '_' + category.name + '.tfrecord'), create_new=False)
    logger.info('Reading records. category: {}, store_exists:{}, store;{}'.format(category.name, str(record_store_exists), record_store))
    if not record_store_exists:
        self.__process()
        record_store_exists, record_store = self.__get_store_info(store_path=fu.join_path(self.store_dir, self.store_name + '_' + category.name + '.tfrecord'), create_new=False)

    iterator_init_hook = SessionRunHook()
    map_fn = self.__parse_function

    gpu_copy_stage_ops = []
    gpu_compute_stage_ops = []

    def input_fn():

        file_names = tf.placeholder(dtype=tf.string, shape=[None], name='data_store')
        dataset = tf.data.TFRecordDataset(filenames=file_names, buffer_size=2000000000)  # 2.0GB
        dataset = dataset.map(map_func=map_fn, num_parallel_calls=tf.constant(value=20000, dtype=tf.int32))
        if shuffle:
            dataset = dataset.shuffle(buffer_size=tf.constant(value=1000 * batch_size, dtype=tf.int64))

        dataset = dataset.repeat(None)  # Infinite iterations
        dataset = dataset.padded_batch(batch_size=tf.constant(value=batch_size, dtype=tf.int64), padded_shapes=([None, ds.NUM_INPUT_FEATURES], [], [None], []))

        iterator = dataset.make_initializable_iterator()
        iterator_init_hook.run_func = lambda session: session.run(iterator.initializer, feed_dict={file_names: [record_store]})

        gpu_copy_stage = StagingArea(dtypes=[tf.float32, tf.int32, tf.int32, tf.int32],
                                     shapes=[[batch_size, None, ds.NUM_INPUT_FEATURES], [batch_size], [batch_size, None], [batch_size]])

        gpu_compute_stage = StagingArea(dtypes=[tf.float32, tf.int32, tf.int32, tf.int32],
                                        shapes=[[batch_size, None, ds.NUM_INPUT_FEATURES], [batch_size], [batch_size, None], [batch_size]])

        features_dict = {}
        labels_dict = {}
        for index, device in enumerate(devices):
            with tf.device(proc_device):
                gpu_copy_stage_ops.append(gpu_copy_stage.put(iterator.get_next()))
            with tf.device(device):
                gpu_compute_stage_ops.append(gpu_compute_stage.put(gpu_copy_stage.get()))
                source, source_len, target, target_len = gpu_compute_stage.get()
                if ds.USE_WARP_CTC:
                    targets = []
                    for bi in range(batch_size):
                        targets.append(target[bi])
                    target = tf.concat(targets, axis=0)

                features_dict[device] = {'source': source, 'source_len': source_len}
                labels_dict[device] = {'target': target, 'target_len': target_len}

        return features_dict, labels_dict

    copy_stage_hook = StepOpsRunHook(ops=[gpu_copy_stage_ops], every_n_secs=1)
    compute_stage_hook = StepOpsRunHook(ops=[gpu_compute_stage_ops], every_n_steps=1)

    return input_fn, [iterator_init_hook, copy_stage_hook, compute_stage_hook]

0 个答案:

没有答案