I am using the cifar10_estimator example from GitHub to build a distributed TensorFlow model for semantic segmentation, and I am having trouble modifying the input_fn function in cifar10_main.py.
Currently the function looks like this (quoted from cifar10_main.py, which imports tensorflow as tf and the example's cifar10 module; note the Python 2-style xrange):

def input_fn(data_dir,
             subset,
             num_shards,
             batch_size,
             use_distortion_for_training=True):
  """Create input graph for model.

  Args:
    data_dir: Directory where TFRecords representing the dataset are located.
    subset: one of 'train', 'validate' and 'eval'.
    num_shards: num of towers participating in data-parallel training.
    batch_size: total batch size for training to be divided by the number of
      shards.
    use_distortion_for_training: True to use distortions.
  Returns:
    two lists of tensors for features and labels, each of num_shards length.
  """
  with tf.device('/cpu:0'):
    use_distortion = subset == 'train' and use_distortion_for_training
    dataset = cifar10.Cifar10DataSet(data_dir, subset, use_distortion)
    image_batch, label_batch = dataset.make_batch(batch_size)
    if num_shards <= 1:
      # No GPU available or only 1 GPU.
      return [image_batch], [label_batch]

    # Note that passing num=batch_size is safe here, even though
    # dataset.batch(batch_size) can, in some cases, return fewer than
    # batch_size examples. This is because it does so only when repeating
    # for a limited number of epochs, but our dataset repeats forever.
    image_batch = tf.unstack(image_batch, num=batch_size, axis=0)
    label_batch = tf.unstack(label_batch, num=batch_size, axis=0)
    feature_shards = [[] for i in range(num_shards)]
    label_shards = [[] for i in range(num_shards)]
    for i in xrange(batch_size):
      idx = i % num_shards
      feature_shards[idx].append(image_batch[i])
      label_shards[idx].append(label_batch[i])
    feature_shards = [tf.parallel_stack(x) for x in feature_shards]
    label_shards = [tf.parallel_stack(x) for x in label_shards]
    return feature_shards, label_shards
This returns two lists of tensors, which works in this case. In image_batch, label_batch = dataset.make_batch(batch_size), image_batch is a list of the sample images, e.g. HWC tensors, and label_batch contains, for each image, an image with the color-coded labels of the features in that image, i.e. an HW tensor. (H = height, W = width, C = channel)
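For illustration, here is a toy sketch of the tensor shapes just described, with made-up sizes:

import tensorflow as tf

# Toy stand-ins for what make_batch produces in this segmentation setup
# (all sizes invented): a batch of HWC images and a batch of HW label maps.
image_batch = tf.zeros([8, 32, 32, 3], tf.float32)  # N x H x W x C
label_batch = tf.zeros([8, 32, 32], tf.int32)       # N x H x W

# After tf.unstack along axis 0 (as in input_fn above), these become
# Python lists of 8 HWC image tensors and 8 HW label tensors.
images = tf.unstack(image_batch, num=8, axis=0)
labels = tf.unstack(label_batch, num=8, axis=0)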
For the loss function in _tower_fn I now need to provide class_weights in the call to tf.losses.sparse_softmax_cross_entropy. Its labels and weights arguments must have the same dimensions, i.e. the weights must also be an HW tensor. How do I modify input_fn and dataset.make_batch to provide a weight tensor as well? input_fn can only return two lists of tensors.
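To make the requirement concrete, here is a minimal sketch of how a per-pixel weight tensor with the same HW shape as a label map could be derived, assuming a hypothetical fixed class_weights vector with one entry per class (the values and class count below are invented); whether this should live in input_fn, dataset.make_batch, or later in _tower_fn is exactly what I am unsure about:

import tensorflow as tf

# Hypothetical per-class weights, one entry per segmentation class.
# The values and the class count are made up for this sketch.
class_weights = tf.constant([0.5, 2.0, 1.0], dtype=tf.float32)

# A label map as described above: an HW tensor of integer class ids
# (a dummy 32x32 map of zeros stands in for a real label image).
labels = tf.zeros([32, 32], dtype=tf.int32)

# tf.gather looks up each pixel's class id in class_weights, producing a
# float tensor with the same HW shape as labels -- the shape that the
# weights argument of tf.losses.sparse_softmax_cross_entropy expects.
weights = tf.gather(class_weights, labels)

If the weights are derived purely from the labels like this, they could in principle also be computed per shard after the labels have been sharded, since each weight shard would then line up with its label shard.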