我想在预处理阶段使用“地图”平行旋转图像。
问题是每个图像都朝相同方向旋转(生成一个随机数后)。但是我希望每个图像都有不同的旋转度。
这是我的代码:
import tensorflow_addons as tfa
import math
import random
def rotate_tensor(image, label):
degree = random.random()*360
image = tfa.image.rotate(image, degree * math.pi / 180, interpolation='BILINEAR')
return image, label
rotated_test_set = rps_test_raw.map(rotate_tensor).batch(batch_size).prefetch(1)
我试图在每次调用该函数时更改种子:
import tensorflow_addons as tfa
import math
import random
seed_num = 0
def rotate_tensor(image, label):
seed_num += 1
random.seed(seed_num)
degree = random.random()*360
image = tfa.image.rotate(image, degree * math.pi / 180, interpolation='BILINEAR')
return image, label
rotated_test_set = rps_test_raw.map(rotate_tensor).batch(batch_size).prefetch(1)
但是我得到了
UnboundLocalError: local variable 'seed_num' referenced before assignment
我使用的是tf2,但我认为这没什么大不了的(除了旋转图像的代码外)。
编辑:我尝试了@Mehraban的建议,但是似乎rotate_tensor函数仅被调用一次:
import tensorflow_addons as tfa
import math
import random
num_seed = 1
def rotate_tensor(image, label):
global num_seed
num_seed += 1
print(num_seed) #<---- print num_seed
random.seed(num_seed)
degree = random.random()*360
image = tfa.image.rotate(image, degree * math.pi / 180, interpolation='BILINEAR')
return image, label
rotated_test_set = rps_test_raw.map(rotate_tensor).batch(batch_size).prefetch(1)
但是它只打印一次“ 2”。所以我认为rotate_tensor被调用了一次。
编辑2-这是显示旋转图像的功能:
plt.figure(figsize=(12, 10))
for X_batch, y_batch in rotated_test_set.take(1):
for index in range(9):
plt.subplot(3, 3, index + 1)
plt.imshow(X_batch[index])
plt.title("Predict: {} | Actual: {}".format(class_names[y_test_proba_max_index[index]], class_names[y_batch[index]]))
plt.axis("off")
plt.show()
答案 0 :(得分:1)
问题在于如何生成随机数。尽管在处理张量流时应该使用random
,但是您依赖tf.random
模块。
这里展示了当您从tf中获得随机数时事物如何变化:
import tensorflow as tf
import random
def gen():
for i in range(10):
yield [1.]
ds = tf.data.Dataset.from_generator(gen, (float))
def m1(d):
return d*random.random()
def m2(d):
return d*tf.random.normal([])
[d for d in ds.map(m2)]
[0.17368042,
1.5629852,
1.2372143,
1.8170034,
1.7040217,
-0.16738933,
-0.11567844,
-0.17949782,
-0.67811996,
-0.5391556]
[d for d in ds.map(m1)]
[0.8369798,
0.8369798,
0.8369798,
0.8369798,
0.8369798,
0.8369798,
0.8369798,
0.8369798,
0.8369798,
0.8369798]