以下是代码:
import tensorflow as tf
import numpy as np
import time
index1 = tf.Variable(-1, dtype=tf.int32, trainable=False)
index2 = tf.Variable(0, dtype=tf.int32, trainable=False)
starting_point1 = tf.Variable(tf.constant(0, dtype=tf.int64), trainable=False)
starting_point2 = tf.Variable(tf.constant(0, dtype=tf.int64), trainable=False)
starting_point3 = tf.Variable(tf.constant(0, dtype=tf.int64), trainable=False)
starting_point4 = tf.Variable(tf.constant(0, dtype=tf.int64), trainable=False)
starting_point5 = tf.Variable(tf.constant(0, dtype=tf.int64), trainable=False)
def starting_point_add_1_0(starting_point1, n):
assignment1 = tf.assign_add(starting_point1, n)
assignment2 = tf.assign_add(starting_point1, 0)
return assignment1, assignment2
mod_op1 = tf.mod(index1, 5)
mod_op1 = tf.Print(input_=mod_op1, data=[mod_op1], message="mod_op1")
mod_op2 = tf.mod(index2, 5)
mod_op2 = tf.Print(input_=mod_op2, data=[mod_op2], message="mod_op2")
condition = tf.logical_and(tf.equal(mod_op1, 0), tf.equal(mod_op2, 1))
condition = tf.Print(input_=condition, data=[condition], message="condition")
ass1_1, ass1_2 = tf.cond(condition,
lambda: starting_point_add_1_0(starting_point1, 1),
lambda: starting_point_add_1_0(starting_point1, 0))
# ass2 = tf.cond(tf.cast(index1 % 5 == 1 or index2 % 5 == 2, dtype=tf.bool),
# lambda: tf.assign_add(starting_point2, 1),
# lambda: tf.assign_add(starting_point2, 0))
# ass3 = tf.cond(index1 % 5 == 2 or index2 % 5 == 3,
# lambda: tf.assign_add(starting_point3, 1),
# lambda: tf.assign_add(starting_point3, 0))
# ass4 = tf.cond(index1 % 5 == 3 or index2 % 5 == 4,
# lambda: tf.assign_add(starting_point4, 1),
# lambda: tf.assign_add(starting_point4, 0))
# ass5 = tf.cond(index1 % 5 == 4 or index2 % 5 == 0,
# lambda: tf.assign_add(starting_point5, 1),
# lambda: tf.assign_add(starting_point5, 0))
data1 = tf.data.Dataset.range(1, 20).skip(starting_point1)
data2 = tf.data.Dataset.range(21, 40).skip(starting_point2)
data3 = tf.data.Dataset.range(41, 60).skip(starting_point3)
data4 = tf.data.Dataset.range(61, 80).skip(starting_point4)
data5 = tf.data.Dataset.range(81, 100).skip(starting_point5)
iterator1 = data1.make_initializable_iterator()
iterator2 = data2.make_initializable_iterator()
iterator3 = data3.make_initializable_iterator()
iterator4 = data4.make_initializable_iterator()
iterator5 = data5.make_initializable_iterator()
d1 = iterator1.get_next()
d2 = iterator2.get_next()
d3 = iterator3.get_next()
d4 = iterator4.get_next()
d5 = iterator5.get_next()
data_ = tf.stack((d1, d2, d3, d4, d5), axis=0)
ass6 = tf.assign_add(index1, 1)
ass7 = tf.assign_add(index2, 1)
with tf.control_dependencies([ass6, ass7]):
data = tf.gather_nd(data_, indices=[[index1 % 5], [index2 % 5]])
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
sess.run(iterator1.initializer)
sess.run(iterator2.initializer)
sess.run(iterator3.initializer)
sess.run(iterator4.initializer)
sess.run(iterator5.initializer)
try:
for i in range(20):
t1, t2, t3, s1 = sess.run([data, index1, index2, starting_point1])
print(t1, t2, t3, ".....", s1)
sess.run([ass1_1, ass1_2])
except tf.errors.OutOfRangeError:
print("error")
因此,此代码的主要目的是能够使用tf.data.Dataset.range
函数迭代我创建的5个不同数据集。我想从第一个数据集中获取1个元素,从第二个数据集中获取1个元素。那么,我想考虑data2
和data3
,那么data3
和data4
,那么data4
和data5
,那么data5
和data1
等等。
我已经考虑过这种方式了,这是输出:
[ 1 21] 0 1 ..... 0
[22 42] 1 2 ..... 1
[43 63] 2 3 ..... 1
[64 84] 3 4 ..... 1
[85 5] 4 5 ..... 1
[ 6 26] 5 6 ..... 1
[27 47] 6 7 ..... 2
[48 68] 7 8 ..... 2
[69 89] 8 9 ..... 2
[90 10] 9 10 ..... 2
[11 31] 10 11 ..... 2
[32 52] 11 12 ..... 3
[53 73] 12 13 ..... 3
[74 94] 13 14 ..... 3
[95 15] 14 15 ..... 3
[16 36] 15 16 ..... 3
[37 57] 16 17 ..... 4
[58 78] 17 18 ..... 4
[79 99] 18 19 ..... 4
error
现在我只测试第一个数据集的起点,看看它是否正常工作,但我得到了上述结果。我以为我会得到一个2
而不是5
,但事实并非如此。好像跳过几乎没有效果。
我想如果有人可以帮助我解决这个问题。
最后,我正在研究这个,因为我有一个不适合GPU内存的自定义数据集,因此我希望实现类似的东西。
我为skip()
函数尝试了以下测试,结果如下:
starting_point1 = tf.Variable(tf.constant(2, dtype=tf.int64), trainable=False)
data1 = tf.data.Dataset.range(1, 20).skip(starting_point1)
iterator1 = data1.make_initializable_iterator()
d1 = iterator1.get_next()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(iterator1.initializer)
try:
for i in range(10):
print(sess.run(d1))
except tf.errors.OutOfRangeError:
print("error")
输出:
3
4
5
6
7
8
9
10
11
12
因此,如何强制迭代器在每个训练步骤中跳过?
非常感谢任何帮助!!