如何在使用Keras的计算图中将张量馈入预训练模型?

时间:2019-03-05 11:35:59

标签: python tensorflow keras

我想使用GAN在生成器的末尾训练具有确定性约束的特定条件Keras,为此,我需要首先使用{{1]计算生成器输出的嵌入}}。

我正在使用VGG-16 pre-trained model

在我的计算图中,我想将生成器的输出python 3.6馈送到预先训练的VGG-16模型中,以获取嵌入。

由于我在计算图中,因此我的img就是形状的张量(无,224,224,3)。问题是,如果我编译以下内容,则会出现错误

  

当将符号张量馈送到模型时,我们期望张量能够   具有静态批次大小。得到了张量形状:(None,224,224,3)

img

很显然,我无法沿第一个轴循环,因为它是None索引。我试图使用tensorflow函数'tf.map_fn'将函数应用于此'img'张量,如下所示:

self.vgg = self.build_vgg()

def build_vgg(self):
    vgg16_model = keras.applications.vgg16.VGG16()
    return Model(inputs=vgg16_model.input,outputs=vgg16_model.get_layer('fc2').output)

        #-------------------------------
    # Construct Computational Graph
    #         for Generator
    #-------------------------------

    # For the generator we freeze the critic's layers
    self.critic.trainable = False
    self.generator.trainable = True
    self.vgg.trainable = False


    # Sampled noise for input to generator
    noise = Input(shape=(self.latent_dim,))

    # Input Embedding:
    embedding = Input(shape=(self.embedding,))


    # Generate images based of noise

    img = self.generator([noise,embedding])

    # Discriminator determines validity

    valid = self.critic(img)

    # Get the embeddings from vgg-16:
    X = self.vgg.predict(img)

但是我收到以下错误消息:

  

ValueError:设置具有序列的数组元素。

回顾一下,我想在 def Embedding(self,img): fn = lambda x: self.vgg.predict(preprocess_input(np.expand_dims(x, axis=0))).flatten() embedding = tf.map_fn(fn,img,dtype=tf.float32) return embedding #------------------------------- # Construct Computational Graph # for Generator #------------------------------- # For the generator we freeze the critic's layers self.critic.trainable = False self.generator.trainable = True self.vgg.trainable = False # Sampled noise for input to generator noise = Input(shape=(self.latent_dim,)) # Input Embedding: embedding = Input(shape=(self.embedding,)) # Generate images based of noise img = self.generator([noise,embedding]) # Discriminator determines validity valid = self.critic(img) # Get the embeddings from VGG16 X = self.Embedding(img) 的计算图中沿Batch_Size轴(0)形状为(None,224,224,3)的pre-trained VGG-16 model上应用tensor。我之前向您解释的是我已经尝试过的...

有人对此有任何建议吗?

0 个答案:

没有答案