我下面有两个代码集:
其中一个标题为“工作代码”。 另一个标题为“非工作代码”。
在工作代码中,我使用普通函数调用来获取图像大小(在代码中被注释为“普通编码-END:普通编码”)。
在非工作代码中,我将其替换为tensorflow迭代器(注释为“在代码中使用Use Iterator-END:Use Iterator)。
但是,使用tensorflow迭代器后,出现以下错误消息:
arg = tf.shape(image)[0] .eval() .........在_eval_using_default_session中 引发ValueError(“无法使用默认会话评估张量:” ValueError:无法使用默认会话评估张量:张量图与会话图不同。将显式会话传递给eval(session = sess)。
换句话说,使用tensorflow迭代器后,get_patches()中代码'arg = tf.shape(image)[0] .eval()'中的tensorflow eval()不再起作用。使用迭代器后,似乎没有正确初始化get_patches()。有没有办法解决这个问题?
### Working Code
import tensorflow as tf
import numpy as np
from glob import glob
p = print
class denoiser(object):
def __init__(self, sess):
self.sess = sess
self.y = tf.placeholder("int32", None)
self.z = self.y + 30
init_op = tf.global_variables_initializer()
self.dataset = dataset(self.sess)
init = tf.global_variables_initializer()
self.sess.run(init)
def train(self, sess):
my_size = self.dataset.obtain_size()
size = sess.run(self.z, feed_dict={self.y: 4}) + my_size + 3000
p('size:', size)
class dataset(object):
def __init__(self, sess):
self.sess = sess
filename = tf.convert_to_tensor('../Data/SIDD_Medium_Srgb/Data/train_clean/0001_GT_SRGB_011.PNG', 'string')
## Plain Coding
image = im_read(filename)
self.my_size = get_patches(image)
## END: Plain Coding
def obtain_size(self):
size = self.sess.run(self.my_size)
return size
def im_read(filename):
image_string = tf.read_file(filename)
image_decoded = tf.image.decode_png(image_string, channels=3)
image = tf.image.convert_image_dtype(image_decoded, tf.float32)
return image
def get_patches(image):
arg = tf.shape(image)[0].eval()
size = 30 + arg
size = tf.convert_to_tensor(size)
return size
def main():
with tf.Session() as sess:
model = denoiser(sess)
denoiser.train(model, sess)
main()
### Non-working Code
import tensorflow as tf
import numpy as np
from glob import glob
p = print
class denoiser(object):
def __init__(self, sess):
self.sess = sess
self.y = tf.placeholder("int32", None)
self.z = self.y + 30
init_op = tf.global_variables_initializer()
self.dataset = dataset(self.sess)
init = tf.global_variables_initializer()
self.sess.run(init)
def train(self, sess):
my_size = self.dataset.obtain_size()
size = sess.run(self.z, feed_dict={self.y: 4}) + my_price + 3000
p('size:', size)
class dataset(object):
def __init__(self, sess):
self.sess = sess
filename = [tf.convert_to_tensor('../Data/SIDD_Medium_Srgb/Data/train_clean/0001_GT_SRGB_011.PNG', 'string')]
## Use Iterator
get_patches_fn = lambda image: get_patches(image)
data = (
tf.data.Dataset.from_tensor_slices(filename)
.map(im_read, num_parallel_calls=1)
.map(get_patches_fn, num_parallel_calls=1)
.batch(1)
.prefetch(1)
)
iterator = data.make_one_shot_iterator()
self.iter = iterator.get_next()
## END: Use Iterator
def obtain_size(self):
size= self.sess.run(self.iter)
return size
def im_read(filename):
image_string = tf.read_file(filename)
image_decoded = tf.image.decode_png(image_string, channels=3)
image = tf.image.convert_image_dtype(image_decoded, tf.float32)
return image
def get_patches(image):
arg = tf.shape(image)[0].eval()
size = 30 + arg
size = tf.convert_to_tensor(size)
return size
def main():
with tf.Session() as sess:
model = denoiser(sess)
denoiser.train(model, sess)
main()