tf.cond表现不尽如人意

时间:2016-05-22 12:47:01

标签: c++ tensorflow

我有一个与tensorflow如何评估tf.cond语句中的表达式有关的问题。

我使用自定义操作,为我提供A类或B类数据。它的签名看起来像

REGISTER_OP("GetData")
.Attr("path: string")
.Output("data: int32")
.Output("type: int32")
.SetIsStateful()

类型A的type输出为1,类型B的输出为2.根据类型,我想运行不同的(自定义)操作,例如opAopB。当然,它们的输出具有相同的类型。为表达此数据流,我使用tf.cond,如下所示:

(data, type) = get_data(path = "...")

opA = op_a(data)
opB = op_b(dat

def perfromA():  return opA
def perfromB():  return opB

joined_op = tf.cond(tf.equal(type, tf.constant(1, dtype=tf.int32)), perfromA, perfromB)

我向getDataopAopB添加了一些调试语句。对于数据序列 ABA ,我希望得到

getData: returning type A
opA:     got some data
getData: returning type B
opB:     got some data   
getData: returning type A
opA:     got some data

然而,我确实得到了

getData: returning type A
opA:     got some data
opB:     got some data
getData: returning type B
opA:     got some data
opB:     got some data   
getData: returning type A
opA:     got some data
opB:     got some data

这是预期的行为吗? joined_op的结果确实正确考虑了if-else,但仍然会计算两个操作。这不仅是计算负担,而且如果opAopA执行影响变量的操作(例如优化步骤),也会失败。

@Yaroslav指出

正确的解决方案

def perfromA():  return op_a(data)
def perfromB():  return op_b(data)

有关更详细的示例,请参阅以下三个操作的虚拟实现:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"

using namespace std;
using namespace tensorflow;


REGISTER_OP("GetData")
.Attr("path: string")
.Output("data: int32")
.Output("type: int32")
.SetIsStateful();


class GetData : public OpKernel {
    public:
        explicit GetData(OpKernelConstruction* ctx) : OpKernel(ctx) {}

        void Compute(OpKernelContext* ctx) override
        {
            Tensor data(DT_INT32, TensorShape({}));
            Tensor type(DT_INT32, TensorShape({}));
            if(rng_.Uniform(2) == 0){
                type.scalar<int32>()() = 1;
                data.scalar<int32>()() = 100;
                LOG(INFO) << "returning type A";
            }
            else{
                type.scalar<int32>()() = 2;
                data.scalar<int32>()() = 200;
                LOG(INFO) << "returning type B";
            }

            ctx->set_output(0, data);
            ctx->set_output(1, type);
        }

    private:
        random::PhiloxRandom philox_ =  random::PhiloxRandom(10) ;
        random::SimplePhilox rng_ = random::SimplePhilox(&philox_);

};
REGISTER_KERNEL_BUILDER(Name("GetData").Device(DEVICE_CPU), GetData);


REGISTER_OP("OpA")
.Input("input: int32")
.Output("output: int32");

class OpA : public OpKernel {
    public:
        explicit OpA(OpKernelConstruction* ctx) : OpKernel(ctx) {}

        void Compute(OpKernelContext* ctx) override
        {
              const Tensor& data = ctx->input(0);
              LOG(INFO) << "A: got some data";
              ctx->set_output(0, data);
        }
};
REGISTER_KERNEL_BUILDER(Name("OpA").Device(DEVICE_CPU), OpA);

REGISTER_OP("OpB")
.Input("input: int32")
.Output("output: int32");

class OpB : public OpKernel {
    public:
        explicit OpB(OpKernelConstruction* ctx) : OpKernel(ctx) {}

        void Compute(OpKernelContext* ctx) override
        {
              const Tensor& data = ctx->input(0);
              LOG(INFO) << "B: got some data";
              ctx->set_output(0, data);
        }
};
REGISTER_KERNEL_BUILDER(Name("OpB").Device(DEVICE_CPU), OpB);

1 个答案:

答案 0 :(得分:0)

你必须在tf.cond

中创建opA / opB

参见讨论here