如何在Tensorflow中将2张图像发送到1个网络并计算对比度损失?

时间:2018-08-01 04:57:11

标签: python tensorflow machine-learning loss-function

我目前需要做与https://github.com/tensorflow/tensorflow/issues/6776

相似的事情

例如,我有一批图像A,B,C ....,我将为它们生成增强的批次,分别为a,b,c .....

然后我将A(B或C)发送到Inception网络中以获取输出张量“ output_1”,并且我需要将A(B或C)发送到同一Inception网络中以获取输出张量“ output_2” ,我将使用||“ output1- output2” ||作为对比损失。

目前,我不确定人们通常如何在Tensorflow中处理此类操作。我在网上搜索,但没有找到答案(尽管我想这与网络的“重用”有关)。

这是我的源代码的样子(对不起,因为我可以在此处粘贴简化版本):

class MyModel:
    ......
    def define_my_net(self):
        self.inputs_from_bloader = tf.placeholder(...)
        self.input = self.inputs_from_bloader
        self.output = slim.conv2d(self.input,...)
    ......
    def update(sess, inputs):
        feed_dict = utility.build_feed_dict(self.inputs_from_bloader, inputs)
        sess.run([my_op_list], feed_dict = feed_dict)
     ......

def train():
    data = importlib.import_module('some.datasets.reader')
    data = data.DataReader()
    model = importlib.import_module('MyModel')
    model.MyModel()
    model.define_my_net()      ### This is where network is defined
    batch = data.get_batch()    ### This is where A,B,C and a,b,c are generated.  
    model.update(sess, batch)    ### This is where training is done

我认为我可以从“ batch = data.get_batch”输出类似AaBbCc的批处理,或将其更改为“ batch1,batch2 = data.get_batch”,但是我不知道如何将batch1和batch2传递到定义的网络中,因为它可能需要对该框架进行一些架构上的修改。

如果您认为上面的源代码太乱了,那么任何更简单的示例也将不胜感激。

1 个答案:

答案 0 :(得分:0)

您可以实例化网络两次。这些实例通常称为“塔”。两个塔将使用相同的变量,但具有不同的输入和操作。

根据您使用的高级API,您应该查找一些用于控制变量reuse的标志,以便在构建第二个塔时,它不会创建新变量。例如,在https://www.tensorflow.org/guide/variables处搜索reuse