我正在尝试对批量大小为4的小批量图像进行数据扩充。(仅用于测试目的)
sess = tf.Session()
#Create dataset
dataset = get_dataset()
#Set seed placeholder
seedin = tf.placeholder(tf.int64,shape=())
#Get iterator
iterator = create_next_batch_iterator(dataset,seedin)
#Initialize the Iterator
sess.run(iterator.initializer,feed_dict={seedin:6})
#Get next batch
next_batch = get_next_batch(iterator)
print next_batch
print 'without augmentation:',next_batch['labels'].eval(session=sess)
我得到的输出是
{'images': <tf.Tensor 'IteratorGetNext:0' shape=(4, 96, 96, 3) dtype=uint8>,
'labels': <tf.Tensor 'IteratorGetNext:1' shape=(4,) dtype=uint8>}
without augmentation: [6 1 7 6]
但如果我用下一行代码替换最后一行
next_batch = augment_data(get_next_batch(iterator),sess)
print next_batch
问题从这里开始......
augment_data
功能的代码
def augment_data(batch,sess,naug=5):
labels_tensor = batch['labels'].eval(session=sess)
print labels_tensor
labels_array = np.array(batch['labels'].eval(session=sess))
print labels_array
我为labels_tensor
和labels_array
[6 1 7 6]
[9 3 8 4]
这些值与执行
时获得的值相同next_batch = get_next_batch(iterator)
print next_batch
连续两次。
看来,每当我试图从batch
获取张量时,get_next_batch()
函数就会被执行,这就是我得到一组新值的原因。
为什么? 我该如何解决这个问题?