Tensorflow - 自定义操作 - 将参数作为可变对象(如通过引用传递)

时间:2018-03-19 22:01:29

标签: python c++ tensorflow pass-by-reference

我尝试在张量流中开发一个自定义操作,在此期间我更改属于输入张量的值,并且在操作完成后仍然应该反映这些更改。

挑战在于,python通过赋值传递函数参数(在我们的例子中是== immutable类型),因此阻止我访问我实际想要在op中访问的数据。

最小工作示例:

test_script.py

img = imread("some image path");
img = ((img_gt[:, : , 0:3]).flatten()).astype('float32')

img[0] = 50
print "First time (start value): ", img[0]

Custom_Loss_Module = tf.load_op_library('some shared library')

with tf.Session(''):
    Custom_Loss_Module.custom_loss(img)
    print "Fourth time (after custom op): ", img[0]

custom_op_main.cpp

REGISTER_OP("CustomLoss")
    .Input("mat: float")


class CustomLossOp : public OpKernel
{
    public:
        explicit CustomLossOp(OpKernelConstruction* context) : OpKernel(context) {}


        void Compute(OpKernelContext* context) override
        {
            Tensor & tensor_img = const_cast<Tensor &>(context->input(0));

            float* pointer_img = reinterpret_cast<float*>(const_cast<char*>((tensor_img.tensor_data()).data()));

            std::cout << "Second time (in op): " << pointer_img[0] << "\n" << std::flush;

            pointer_img[0] = 100;

            std::cout << "Third time (in op): " << pointer_img[0] << "\n" << std::flush;
        } 
};

输出:

First time (start value):       50
Second time (in op):            50
Third time (in op):             100
Fourth time (after custom op):  50   <-- I want to see a 100 here :(

现在如上所述,我知道这种行为来自哪里,我知道这是可以预期的。但是,我想找到一种方法,以便从50改为100(在自定义op中完成)是可变的,因此一旦自定义op被finsihed,img-tensor就会反映出来。在C ++中 - 术语:我想通过引用而不是赋值来传递函数参数。

到目前为止我尝试了什么:

我尝试使用标准的python技巧并将参数作为列表传递,因为列表是可变的。例如:

test_script.py

img = imread("some image path");
img = ((img_gt[:, : , 0:3]).flatten()).astype('float32')

img[0] = 50
print "First time (start value): ", img[0]

Custom_Loss_Module = tf.load_op_library('some shared library')

with tf.Session(''):
    Custom_Loss_Module.custom_loss([img])
    print "Fourth time (after custom op): ", img[0]

但是,我没有让它工作,并且在讨论之后:https://github.com/tensorflow/tensorflow/issues/9334,我也相信自定义操作不能将列表作为输入参数(即使这样文档提到它可以... 。:S。请参阅此链接中的“Attr类型”:https://www.tensorflow.org/extend/adding_an_op

所以有人可以帮助我解决这个问题吗? - 提前Thx:)

1 个答案:

答案 0 :(得分:0)

框架不允许突变张量流张量,因此您会得到奇怪的未定义行为。不要那样做;张量流张量是不变的。

相反,制作一个使用输入张量并返回输出张量的op,该输出张量是输入的副本,但有些变化。或使用变量(但这会丢失计算的可区分性)。