Inference from a frozen model in TensorFlow using the Dataset API

Time: 2018-09-21 12:29:41

Tags: python tensorflow tensor tensorflow-serving tensorflow-datasets

I am trying to run inference from a frozen graph of a TensorFlow wide-and-deep model (generated with freeze_graph.py). I use the Dataset API to parse a test.csv file. Since feed_dict only accepts numpy arrays and the contents of a Dataset are tensors, I currently fetch a batch with next_element = iterator.get_next() and batch = sess.run(next_element) to get the numpy values, and then feed those into the placeholders through feed_dict. On a large dataset this does not give me good throughput, because evaluating the tensors into numpy arrays and then feeding them back through placeholders is not efficient. Is there an efficient way to do this?
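
For context, graph, input_tensor and output_tensor in the snippets below come from loading the frozen graph, roughly like this (a minimal sketch; the file path and tensor names are placeholders, not the actual names in my model):

import tensorflow as tf

# Read the frozen GraphDef produced by freeze_graph.py.
with tf.gfile.GFile('frozen_model.pb', 'rb') as f:  # placeholder path
  graph_def = tf.GraphDef()
  graph_def.ParseFromString(f.read())

# Import it into a fresh graph.
with tf.Graph().as_default() as graph:
  tf.import_graph_def(graph_def, name='')

# One placeholder per feature column; the names are hypothetical and
# depend on how the model was exported.
input_tensor = [graph.get_tensor_by_name(name + ':0')
                for name in TRAIN_DATA_COLUMNS if name not in LABEL_COLUMN]
output_tensor = graph.get_tensor_by_name('head/predictions/logistic:0')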

def input_fn(data_file, num_epochs, shuffle, batch_size):
  """Generate an input function for the Estimator."""
  assert tf.gfile.Exists(data_file), (
      '%s not found. Please make sure you have run census_dataset.py and '
      'set the --data_dir argument to the correct path.' % data_file)

  def parse_csv(value):
    tf.logging.info('Parsing {}'.format(data_file))
    # 13 continuous feature columns, 26 categorical ones, plus the label.
    cont_defaults = [[0.0] for i in range(1, 14)]
    cate_defaults = [[" "] for i in range(1, 27)]
    label_defaults = [[0]]
    column_headers = TRAIN_DATA_COLUMNS
    record_defaults = label_defaults + cont_defaults + cate_defaults
    columns = tf.decode_csv(value, record_defaults=record_defaults)
    all_columns = dict(zip(column_headers, columns))
    labels = all_columns.pop(LABEL_COLUMN[0])
    features = all_columns
    return features, labels

  # Extract lines from input files using the Dataset API.
  dataset = tf.data.TextLineDataset(data_file)

  if shuffle:
    dataset = dataset.shuffle(buffer_size=2000)

  dataset = dataset.map(parse_csv, num_parallel_calls=8)

  # We call repeat after shuffling, rather than before, to prevent separate
  # epochs from blending together.
  dataset = dataset.repeat(num_epochs)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(56)
  return dataset



with tf.Session(graph=graph) as sess:
  res_dataset = input_fn(predictioninputfile, 1, False, batch_size)
  iterator = res_dataset.make_one_shot_iterator()
  next_element = iterator.get_next()
  inference_start = time.time()
  for i in range(no_of_batches):
    # Materialize the batch as numpy arrays, then feed them back in
    # through the placeholders -- this round trip is the bottleneck.
    batch = sess.run(next_element)
    features, actual_label = batch[0], batch[1]
    logistic = sess.run(output_tensor,
                        dict(zip(input_tensor, list(features.values()))))
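
Ideally I would like to wire the iterator output straight into the imported graph so the data never leaves the TensorFlow runtime. A rough sketch of the kind of thing I am after, using the input_map argument of tf.import_graph_def (again, the placeholder names are hypothetical):

# Build the dataset pipeline in the graph that will receive the import.
dataset = input_fn(predictioninputfile, 1, False, batch_size)
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()

# Map the frozen graph's input placeholders directly onto the dataset
# tensors; this assumes the placeholder names match the feature column
# names, which may not hold for a real export.
input_map = {name + ':0': tensor for name, tensor in features.items()}
output_tensor, = tf.import_graph_def(
    graph_def, input_map=input_map,
    return_elements=['head/predictions/logistic:0'], name='')

with tf.Session() as sess:
  for i in range(no_of_batches):
    # Each run pulls the next batch internally; no feed_dict needed.
    logistic = sess.run(output_tensor)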

0 Answers:

There are no answers.