如何删除TensorFlow自定义操作实例?

时间:2018-06-18 20:22:34

标签: python c++ tensorflow

我正在使用TensorFlow中的custom op tutorial

我在C ++中实现了自定义op并创建了一个共享库。我正在使用python tf.load_op_library函数调用加载它。然后我使用session.run()调用自定义操作。

自定义操作正常。但我无法弄清楚何时调用自定义的析构函数。

尽管我在自定义操作系统的析构函数中有一个print语句,但它永远不会打印出来。似乎自定义操作实例永远不会被破坏。

这是预期的行为吗?如果是这种情况,有没有办法通知tensorflow我使用自定义操作完成了?

请注意,仅当我们使用tf.placeholders时才会注意到此行为。如果我使输入矩阵成为常数,即

  

x = zero_out_module.zero_out([[1,2],[3,4]])

然后调用自定义op的析构函数。

C ++中的自定义op实现

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {
    printf("Constructor: ZeroOutOp\n"); fflush(stdout);
  }

  ~ZeroOutOp() override {
   printf("Destructor: ZeroOutOp\n"); fflush(stdout);
  }

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int32>();

    // Set all but the first element of the output tensor to 0.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value if possible.
    if (N > 0) output_flat(0) = input(0);
  }
};

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

用于测试自定义操作的Python代码

import tensorflow as tf
import numpy as np

zero_out_module = tf.load_op_library('./lib/zero_out/zero_out.so')

inmat=np.array([[1,2],[3,4]])

with tf.device("/cpu:0"):
    input_mat = tf.placeholder(tf.int32)
    x=zero_out_module.zero_out(to_zero=input_mat)


sess = tf.Session(config=tf.ConfigProto(log_device_placement=True, allow_soft_placement=False))
print(sess.run(x, feed_dict={input_mat:inmat}))

0 个答案:

没有答案