How to pack features for a TensorFlow Estimator in input_fn

Posted: 2018-07-24 21:13:56

Tags: python tensorflow


Currently I have:

import tensorflow as tf
import cifar10  # assumed: the local cifar10.py module (provides Cifar10DataSet)


def input_fn(data_dir,
             subset,
             num_shards,
             batch_size,
             use_distortion_for_training=True):
    """Create input graph for model.

    Args:
      data_dir: Directory where TFRecords representing the dataset are located.
      subset: one of 'train', 'validate' and 'eval'.
      num_shards: num of towers participating in data-parallel training.
      batch_size: total batch size for training to be divided by the number of
          shards.
      use_distortion_for_training: True to use distortions.
    Returns:
      two lists of tensors for features and labels, each of num_shards length.
    """
    with tf.device('/cpu:0'):
        use_distortion = subset == 'train' and use_distortion_for_training
        dataset = cifar10.Cifar10DataSet(data_dir, subset, use_distortion)
        image_batch, label_batch = dataset.make_batch(batch_size)
        if num_shards <= 1:
            # No GPU available or only 1 GPU.
            return [image_batch], [label_batch]

        # Note that passing num=batch_size is safe here, even though
        # dataset.batch(batch_size) can, in some cases, return fewer than batch_size
        # examples. This is because it does so only when repeating for a limited
        # number of epochs, but our dataset repeats forever.
        image_batch = tf.unstack(image_batch, num=batch_size, axis=0)
        label_batch = tf.unstack(label_batch, num=batch_size, axis=0)
        feature_shards = [[] for _ in range(num_shards)]
        label_shards = [[] for _ in range(num_shards)]
        # Round-robin: example i goes to shard i % num_shards.
        for i in range(batch_size):
            idx = i % num_shards
            feature_shards[idx].append(image_batch[i])
            label_shards[idx].append(label_batch[i])
        # Re-stack each shard's examples into one tensor per tower.
        feature_shards = [tf.parallel_stack(x) for x in feature_shards]
        label_shards = [tf.parallel_stack(x) for x in label_shards]
        return feature_shards, label_shards

This returns two lists of tensors, which works fine in this case. In image_batch, label_batch = dataset.make_batch(batch_size), image_batch is a batch of example images as HWC tensors, and label_batch holds, for each image, an image of color-coded labels for the features in that image, i.e. HW tensors (H = height, W = width, C = channels).
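The loop at the end of input_fn deals examples out round-robin: example i goes to shard i % num_shards, so when batch_size is not divisible by num_shards the first shards each get one extra example. A minimal, framework-free sketch of just that distribution step, with plain Python lists standing in for the unstacked tensors (shard_round_robin is a hypothetical helper, not part of the tutorial code):

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard_round_robin(examples: Sequence[T], num_shards: int) -> List[List[T]]:
    """Deal examples into num_shards lists; example i lands in shard i % num_shards."""
    if num_shards < 1:
        raise ValueError("num_shards must be >= 1")
    shards: List[List[T]] = [[] for _ in range(num_shards)]
    for i, example in enumerate(examples):
        shards[i % num_shards].append(example)
    return shards


if __name__ == "__main__":
    # A batch of 8 examples split across 3 towers.
    print(shard_round_robin(list(range(8)), 3))  # prints [[0, 3, 6], [1, 4, 7], [2, 5]]
```

In input_fn, each of these per-shard lists is then collapsed back into a single tensor per tower with tf.parallel_stack.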


How can I modify dataset.make_batch so that it also provides a weights tensor? input_fn can only return two lists of tensors.
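Because tf.estimator allows input_fn to return its features as either a single tensor or a dict of tensors, one possible approach (an assumption here, not something the question confirms works with this tutorial's model_fn) is to keep returning two lists but make each feature shard a dict that carries the weights next to the images. The sketch below shows only the sharding logic in plain Python; shard_with_weights and the keys "image" and "weight" are hypothetical, and it assumes dataset.make_batch has been extended to return a third, per-example weight batch.

```python
from typing import Any, Dict, List, Sequence, Tuple


def shard_with_weights(
    images: Sequence[Any],
    labels: Sequence[Any],
    weights: Sequence[Any],
    num_shards: int,
) -> Tuple[List[Dict[str, List[Any]]], List[List[Any]]]:
    """Round-robin shard images, labels and weights.

    Hypothetical helper: weights travel inside each feature-shard dict, so the
    return value is still just two lists (features, labels).
    """
    if not (len(images) == len(labels) == len(weights)):
        raise ValueError("images, labels and weights must have the same length")
    feature_shards: List[Dict[str, List[Any]]] = [
        {"image": [], "weight": []} for _ in range(num_shards)
    ]
    label_shards: List[List[Any]] = [[] for _ in range(num_shards)]
    for i in range(len(images)):
        idx = i % num_shards  # same assignment rule as the original loop
        feature_shards[idx]["image"].append(images[i])
        feature_shards[idx]["weight"].append(weights[i])
        label_shards[idx].append(labels[i])
    return feature_shards, label_shards


if __name__ == "__main__":
    features, labels = shard_with_weights(
        ["img0", "img1", "img2"], ["lbl0", "lbl1", "lbl2"], [1.0, 0.5, 2.0], 2
    )
    print(features)  # prints [{'image': ['img0', 'img2'], 'weight': [1.0, 2.0]}, {'image': ['img1'], 'weight': [0.5]}]
    print(labels)    # prints [['lbl0', 'lbl2'], ['lbl1']]
```

In a real input_fn, each dict value would be stacked per key (for example {k: tf.parallel_stack(v) for k, v in shard.items()}), and the custom model_fn would have to be changed to read the image and weight tensors from each tower's dict, e.g. passing the weights to the weights argument of tf.losses.sparse_softmax_cross_entropy.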

0 answers
