Tensorflow,循环/喂食/训练的麻烦

时间:2017-11-12 22:02:46

标签: python-3.x tensorflow neural-network deep-learning

所以我知道代码有点乱,我希望每个人都能原谅我。 我试图定义一个非常简单的模型,我对模型的输入很有形 [1,9],但从输入管道获取的批量大小是[batch_size,9],我熟悉切片机制,但我真的不知道实际的循环在哪里必须去(我明白我可能想根据迭代切片)。像tf.slice(input_batch,[i,0],[i,10])之类的东西可以解决这个问题,但我仍然对于将输入批次传递到何处以及如何定义方式感到困惑我的图表以特殊方式处理批次。

我不认为我这样做,因为我提供给图表的张量应该以某种方式最终出现在定义的占位符中。

任何帮助都将不胜感激,也有任何关于如何解决这个问题的建议,以获得更强大和可行的解决方案。

#!/usr/local/bin/python3

from __future__ import print_function
import numpy as np
import tensorflow as tf
import math as math
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('dataset')
args = parser.parse_args()

def file_len(fname):
    with open(fname) as f:
        for i, l in enumerate(f):
            pass
    return i + 1

def read_from_csv(filename_queue):
  reader = tf.TextLineReader(skip_header_lines=1)
  _, csv_row = reader.read(filename_queue)
  record_defaults = [[5.0] for col in range(11)]
  cols = []
  cols = tf.decode_csv(csv_row, record_defaults=record_defaults, field_delim=',')
  features = tf.stack(cols[1:10])
  #tf.reshape(features, shape=[1,8])
  label = tf.stack(cols[10])
  return features, label

def input_pipeline(batch_size, num_epochs=None):
  filename_queue = tf.train.string_input_producer([args.dataset], num_epochs=num_epochs, shuffle=False)
  print(args.dataset)
  example, label = read_from_csv(filename_queue)
  print("***************************************************************")
  print(example)
  print(label)
  print("***************************************************************")
  min_after_dequeue = 10000
  capacity = min_after_dequeue + 3 * batch_size
  example_batch, label_batch = tf.train.shuffle_batch(
      [example, label], batch_size=batch_size, capacity=capacity,
      min_after_dequeue=min_after_dequeue)
  return example_batch, label_batch

def def_graph():
  with tf.Graph().as_default():
    input_placeholder = tf.placeholder(shape=[1, 9], dtype=tf.float32, name="input_placeholder")
    target_placeholder = tf.placeholder(shape=[1], dtype=tf.float32, name="target_placeholder")
    sess = tf.Session()
    w1 = tf.Variable(tf.random_uniform(shape=[9, 30], dtype=tf.float32), name="weights1")
    b1 = tf.Variable(tf.random_uniform(shape=[1, 30], dtype=tf.float32), name="bias1")

    h1 = tf.sigmoid(tf.matmul(input_placeholder, w1) + b1, name="hidden1")
    w2 = tf.Variable(tf.random_uniform(shape=[30, 1], dtype=tf.float32), name="weigts2")
    output = tf.matmul(h1, w2, name="output")
    loss = tf.reduce_mean(tf.squared_difference(target_placeholder, output), name="loss")

    optimizer = tf.train.AdagradOptimizer(learning_rate=0.01)
    train_op = optimizer.minimize(loss, name="train_op")

    init = tf.global_variables_initializer()
    sess.run(init)
  return sess, train_op

batch_size=20

sess, train_op = def_graph()

with sess as sess:
  init=tf.global_variables_initializer()
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord)
  goal=tf.constant([0.2])

  # start populating filename queue

  try:
    while not coord.should_stop():
      print("now training batch")
      input_batch, target_batch = input_pipeline(batch_size)
      tf.Print(input_batch, [input_batch])
      print(input_batch.eval())
      for step in range(batch_size):
        instance = tf.get_session_handle(tf.slice(input_batch, [step, 0], [step, 10]))
        instance_target = tf.get_session_handle(tf.slice(target_batch, [step], [1]))
        instance=sess.run(instance)
        instance_target=sess.run(instance_target)
        feed_dict = {tf.get_variable(name="input_placeholder", shape=[1,9]): instance,
              tf.get_variable(name="target_placeholder", shape=[1]): instance_target}
        sess.run(train_op, feed_dict=feed_dict)
        if (goal > tf.get_variable(name="loss", shape=[1])):
          print("*************************************")
          print("\n\n")
          print("convergence reached!")
          print("\n\n")
          print("*************************************")
          break

  except tf.errors.OutOfRangeError:
    print('Done training, epoch reached')
  finally:
    coord.request_stop()

  coord.join(threads)

0 个答案:

没有答案