我正在尝试添加一个自定义的 Resize(调整大小)层:该层不使用固定的目标尺寸,而是从另一个输入层获取目标尺寸。
我找到了下面这个问题,但其中使用的是固定的目标尺寸:Add a resizing layer to a keras sequential model
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
import tensorflow.keras.backend as K
class Resize(Layer):
    """Keras layer that resizes images with ``tf.image.resize``.

    The target size is not fixed at construction time; it is supplied on
    every call through the ``size`` argument.
    """

    def __init__(self, **kwargs):
        # This method was spelled ``init`` and so never acted as the
        # constructor. Forwarding **kwargs lets standard Layer options such
        # as ``name`` pass through to the base class.
        super().__init__(**kwargs)

    def build(self, input_shape):
        # The layer owns no weights; defer to the base implementation.
        super().build(input_shape)

    def call(self, x, size):
        """Resize ``x`` to ``size``.

        Args:
            x: Image tensor handed to ``tf.image.resize``.
            size: Forwarded unchanged as the ``size`` argument of
                ``tf.image.resize``, which requires a 1-D tensor (or
                sequence) of 2 elements: new_height, new_width.

        Returns:
            The resized image tensor.
        """
        return tf.image.resize(x, size=size)

    def compute_output_shape(self, input_shape):
        """Return the static output shape for an input of ``input_shape``.

        Leading (batch) axes and the channel axis are carried over from the
        input; height and width are unknown until ``size`` is evaluated.
        """
        dims = tf.TensorShape(input_shape).as_list()
        return tf.TensorShape(dims[:-3] + [None, None] + dims[-1:])

    def get_output_shape_for(self, input_shape):
        """Keras 1.x name of ``compute_output_shape``, kept for old callers."""
        return self.compute_output_shape(input_shape)
# Two model inputs: the image batch and an integer size (not used yet).
inp = Input(shape=(10, 10, 3))
size = Input(shape=(1,), dtype='int32')

# A constant target size works; size=(size, size) is the variant that fails.
resize_layer = Resize()
out = resize_layer(inp, size=(100, 100))

model = Model(inputs=[inp, size], outputs=out)
model.summary()
当我尝试这样做时:
# Failing variant: the target size now comes from the `size` input tensor.
inp = Input((10,10,3))
# Declared with per-sample shape (1,), so `size` also carries a batch axis.
size = Input((1,), dtype='int32')
# NOTE(review): a tuple of two such tensors is presumably not a 1-D tensor of
# 2 elements, which is what the ValueError in the traceback below reports.
out = Resize()(inp, size=(size,size))
model = Model([inp,size], out)
model.summary()
错误:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/keras/api/_v1/keras/models/__init__.py in <module>()
2 size = Input((1,), dtype='int32')
3
----> 4 out = Resize()(inp, size=(size,size)) #(inp, size=(size,size))
5
6 model = Model([inp,size], out)
~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
634 outputs = base_layer_utils.mark_as_return(outputs, acd)
635 else:
--> 636 outputs = call_fn(inputs, *args, **kwargs)
637
638 except TypeError as e:
~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
147 except Exception as e: # pylint:disable=broad-except
148 if hasattr(e, 'ag_error_metadata'):
--> 149 raise e.ag_error_metadata.to_exception(type(e))
150 else:
151 raise
ValueError: in converted code:
<ipython-input-1-ab7021ffbc7d>:14 call *
out = tf.image.resize(x,size=size)
/home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/ops/image_ops_impl.py:1182 resize_images
skip_resize_if_same=True)
/home/ec2-user/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/ops/image_ops_impl.py:1045 _resize_images_common
raise ValueError('\'size\' must be a 1-D Tensor of 2 elements: '
ValueError: 'size' must be a 1-D Tensor of 2 elements: new_height, new_width
答案 0 :(得分:0)
一种解决方法是把尺寸输入绑定到一个后端变量上:`size = Input(tensor=K.variable([2, 2], dtype=tf.int32))`,之后在预测前用 `K.set_value` 修改该变量即可改变输出尺寸。
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
import tensorflow.keras.backend as K
class Resize(Layer):
    """Keras layer that resizes images to a size taken from a second input.

    The layer is called on a two-element list ``[images, size]`` and
    forwards both to ``tf.image.resize``.
    """

    def __init__(self, **kwargs):
        # This method was spelled ``init`` and so never acted as the
        # constructor. Forwarding **kwargs lets standard Layer options such
        # as ``name`` pass through to the base class.
        super().__init__(**kwargs)

    def build(self, input_shape):
        # The layer owns no weights; defer to the base implementation.
        super().build(input_shape)

    def call(self, inputs):
        """Resize ``inputs[0]`` to the size given by ``inputs[1]``.

        Args:
            inputs: Two-element sequence ``[x, size]``. ``x`` is the image
                tensor; ``size`` is forwarded unchanged as the ``size``
                argument of ``tf.image.resize`` (new_height, new_width).

        Returns:
            The resized image tensor.
        """
        x, size = inputs
        return tf.image.resize(x, size=size)

    def compute_output_shape(self, input_shape):
        """Return the static output shape for ``[image_shape, size_shape]``.

        Leading (batch) axes and the channel axis are carried over from the
        image shape; height and width are unknown until ``size`` is
        evaluated.
        """
        dims = tf.TensorShape(input_shape[0]).as_list()
        return tf.TensorShape(dims[:-3] + [None, None] + dims[-1:])

    def get_output_shape_for(self, input_shape):
        """Keras 1.x name of ``compute_output_shape``, kept for old callers."""
        return self.compute_output_shape(input_shape)
inp = Input(shape=(10, 10, 3))

# Back the size input with a backend variable so its value can be changed
# later with K.set_value without rebuilding the model.
var_size = K.variable([2, 2], dtype=tf.int32)
size = Input(tensor=var_size, name='size')

out = Resize()([inp, size])
model = Model(inputs=[inp, size], outputs=out)
model.summary()
# Model: "model"
# __________________________________________________________________________________________________
# Layer (type) Output Shape Param # Connected to
# ==================================================================================================
# input_1 (InputLayer) [(None, 10, 10, 3)] 0
# __________________________________________________________________________________________________
# input_2 (InputLayer) [(2,)] 0
# __________________________________________________________________________________________________
# resize (Resize) (None, None, None, 3 0 input_1[0][0]
# input_2[0][0]
# ==================================================================================================
# Total params: 0
# Trainable params: 0
# Non-trainable params: 0
input_mat = np.random.randn(100, 10, 10, 3)

# `size` is backed by `var_size`, so the image batch is the only array to
# feed. It is passed positionally: the previous dict key 'x' matched no input
# layer name (the image input is auto-named, e.g. input_1 in the summary).
K.set_value(var_size, [5, 5])
res = model.predict(input_mat)
# res.shape (100,5,5,3)

# Changing the variable changes the output size on the next predict call.
K.set_value(var_size, [3, 3])
res = model.predict(input_mat)
# res.shape (100,3,3,3)