适用于单个样品的层至适用于批次的层

时间:2018-11-02 13:45:51

标签: python tensorflow

我目前有一个类似于以下的图层:

class MyLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(MyLayer, self).__init__()

    def call(self, img, params):
        tf.foo(img)
        tf.bar(img, params)
        return img

调用方法获取并输入形状为img的{​​{1}}和形状为(128, 128, 3)的{​​{1}}。

我必须更改什么才能使该图层可以批量操作?输入params的形状例如为(15),而img的形状为(32, 128, 128, 3)

所以问题基本上是:我必须如何编辑图层,使其具有与现在相同的功能,但对于批处理中的每个图像呢?

2 个答案:

答案 0 :(得分:0)

也许您只是想重塑图像张量的大小? 您想要做的事情将是x = np.reshape(x,(32,128,128,3))。 莱姆知道这是否有帮助。

答案 1 :(得分:0)

默认情况下,您的图层应识别来自上一层的张量,并能够确保其输入形状符合该规范: 即

(None, 128,128,3)

因此批次大小无关紧要。您的问题很可能是您没有build()函数,并且没有通过init及其上级传递** kwargs。

此外,call()应该仅以self和x作为参数,如果需要多个输入,则应在使用连接层之前将两个张量推在一起,或与它们一起列出