How do I create an `input_fn` to pass as the argument to `.fit()`?

Posted: 2017-06-27 11:21:31

Tags: python tensorflow

I'm trying to train a CNN model on some images I have labeled. I'm new to TensorFlow. Here is what I did:

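import tensorflow as tf

def read_labeled_image_list(image_list_file):
    """Read "<filename> <label>" lines; return filenames and one-hot labels."""
    filenames = []
    labels = []
    with open(image_list_file, 'r') as f:  # closes the file when done
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines, e.g. at the end of the file
            filename, label = line.split(' ')
            filenames.append(filename)
            # One-hot encode the binary label: 0 -> [1, 0], 1 -> [0, 1]
            index0 = 1 if int(label) == 0 else 0
            index1 = 1 if int(label) == 1 else 0
            labels.append([index0, index1])
    return filenames, labels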

def read_images_from_disk(input_queue):
    """Take one [filename, label] pair from the queue and decode the image."""
    label = input_queue[1]
    file_contents = tf.read_file(input_queue[0])
    example = tf.image.decode_jpeg(file_contents, channels=1)  # grayscale
    return example, label
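The list file read by `read_labeled_image_list` is assumed to hold one `<path> <label>` pair per line, separated by a single space. A minimal pure-Python sketch of the same parsing, generalized from the hard-coded two classes to a hypothetical `num_classes` parameter, makes the one-hot layout explicit (`parse_label_lines` and the sample paths are illustrative, not part of the original code):

```python
def parse_label_lines(lines, num_classes=2):
    """Parse "<filename> <label>" lines into filenames and one-hot labels.

    Hypothetical helper mirroring read_labeled_image_list, but taking any
    iterable of lines and an explicit number of classes.
    """
    filenames, labels = [], []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # ignore blank lines such as a trailing newline
        filename, label = line.split(' ')
        one_hot = [0] * num_classes
        one_hot[int(label)] = 1  # class k -> 1 at position k
        filenames.append(filename)
        labels.append(one_hot)
    return filenames, labels


if __name__ == "__main__":
    sample = ["img/cat_001.jpg 0\n", "img/dog_001.jpg 1\n", "\n"]
    names, onehots = parse_label_lines(sample)
    print(names)    # ['img/cat_001.jpg', 'img/dog_001.jpg']
    print(onehots)  # [[1, 0], [0, 1]]
```

Unlike the original, an out-of-range label here raises `IndexError` instead of silently producing an all-zero row, which is usually easier to debug.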

I then use `read_images_from_disk` as my `input_fn`:

image_list, label_list = read_labeled_image_list(
    "./images_training/training_list.txt")

images = tf.constant(image_list, dtype=tf.string)
labels = tf.constant(label_list, dtype=tf.int32)

# Makes an input queue
input_queue = tf.train.slice_input_producer([images, labels],
                                            num_epochs=30,
                                            shuffle=True)

image, label = read_images_from_disk(input_queue)

# Train the model
graph_classifier.fit(
    input_fn=read_images_from_disk(input_queue),
    steps=20000,
    monitors=[logging_hook])

I get the following error:

features, labels = input_fn()
TypeError: 'tuple' object is not callable
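The traceback comes from `fit` calling `input_fn()` itself. A framework-free sketch reproduces it; `fake_fit` and `read_pair` are hypothetical stand-ins that only mimic that one line of `fit` and the tuple returned by `read_images_from_disk(input_queue)`:

```python
def fake_fit(input_fn):
    """Hypothetical stand-in for fit(): it calls input_fn itself."""
    features, labels = input_fn()
    return features, labels


def read_pair():
    # Stands in for read_images_from_disk(input_queue): returns a tuple.
    return "image_tensor", "label_tensor"


# Wrong: read_pair() has already run, so fake_fit receives a tuple.
try:
    fake_fit(read_pair())
except TypeError as err:
    print(err)  # 'tuple' object is not callable

# Right: pass the function object; fake_fit calls it.
print(fake_fit(read_pair))  # ('image_tensor', 'label_tensor')
```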

1 Answer:

Answer 0 (score: 0)

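The error occurs because the `input_fn` argument of `fit` must be a callable: `fit` calls `input_fn()` itself to obtain `(features, labels)`, but `read_images_from_disk(input_queue)` has already been called and returns a tuple. Wrap the call in a function and pass that function instead. You can try: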

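def read_images_from_disk(input_queue):
    label = input_queue[1]
    file_contents = tf.read_file(input_queue[0])
    example = tf.image.decode_jpeg(file_contents, channels=1)
    return example, label

def my_input_func():
    # fit() calls input_fn inside a fresh graph of its own, so build the
    # whole input pipeline here instead of at module level.
    image_list, label_list = read_labeled_image_list(
        "./images_training/training_list.txt")
    images = tf.constant(image_list, dtype=tf.string)
    labels = tf.constant(label_list, dtype=tf.int32)
    input_queue = tf.train.slice_input_producer([images, labels],
                                                num_epochs=30,
                                                shuffle=True)
    return read_images_from_disk(input_queue)

# Pass the function object itself (no parentheses); fit() calls it
# to obtain the (features, labels) pair.
graph_classifier.fit(
    input_fn=my_input_func,
    steps=20000,
    monitors=[logging_hook])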

I'd also recommend reading the official doc on input functions (`input_fn`) carefully.
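Since `fit` only needs a zero-argument callable, a named wrapper function is not the only option. A minimal sketch, again using a hypothetical `fake_fit` in place of the real `fit`, shows three equivalent ways to turn a function that needs an argument into such a callable: a named function, a `lambda`, and `functools.partial`:

```python
import functools


def fake_fit(input_fn):
    """Hypothetical stand-in for fit(): only calls input_fn with no arguments."""
    return input_fn()


def read_images(queue):
    # Hypothetical function that needs an argument, like read_images_from_disk.
    return queue[0], queue[1]


queue = ("image", "label")


def named_input_fn():
    # Looks up `queue` when called, not when defined.
    return read_images(queue)


results = [
    fake_fit(named_input_fn),
    fake_fit(lambda: read_images(queue)),
    fake_fit(functools.partial(read_images, queue)),  # binds queue now
]
print(results)  # [('image', 'label'), ('image', 'label'), ('image', 'label')]
```

The wrappers only change when `read_images` runs; any object they capture, such as `input_queue` in the question, was still created beforehand, wherever the surrounding code created it.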