如何为类重载任意二元或一元运算符?

时间:2017-10-07 21:47:13

标签: matlab class operator-overloading

假设您有一个继承数组x的类,添加了一些参数p

classdef test
    properties
       x
       p
    end
    methods
       function t=calculate(t)
           [t.x,t.p]=calc(x,p);
       end
       function t=plus(t1,t2)
           t.x=t1.x+t2.x;
       end
   end
end

众所周知如何重载二元运算符,例如plusmtimesminus等。如何为任何二进制文件启用重载向量化运算符,或最终,任何一元运算符,如meanabsmax等,以便直接应用于向量x?例如,如何S = mean(S);等同于S.x = mean(S.x);

1 个答案:

答案 0 :(得分:1)

如果我理解你的问题,听起来你希望你的新类test简单地继承为属性x的类定义的所有二进制和一元方法(并且对属性进行操作) x调用它们时,你不必自己重新定义它们。

如果这是您想要的,那么我认为唯一可行的方法就是实际使用继承并使您的班级test成为subclass的属性类x。考虑到x只是double的简单情况,可以找到子类化内置double类型的一个很好的示例here。将该示例改编为您的示例,这是您实现班级test的一种方式:

classdef test < double

  properties
    p
  end

  methods
    function obj = test(x, p)
      if (nargin < 2)
        p = 0;
        if (nargin < 1)
          x = 0;
        end
      end
      obj@double(x);
      obj.p = p;
    end

    function sref = subsref(obj, s)
      switch s(1).type
        case '.'
          switch s(1).subs
            case 'p'
              sref = obj.p;
            case 'x'
              x = double(obj);
              if (length(s) < 2)
                sref = x;
              elseif (length(s) > 1) && strcmp(s(2).type, '()')
                sref = subsref(x, s(2:end));
              end
            otherwise
              error('Not a supported indexing expression')
          end
        case '()'
          x = double(obj);
          newx = subsref(x, s(1:end));
          sref = test(newx, obj.p);
        case '{}'
          error('Not a supported indexing expression')
      end
    end

    function obj = subsasgn(obj, s, b)
      switch s(1).type
        case '.'
          switch s(1).subs
            case 'p'
              obj.p = b;
            case 'x'
              if (length(s) < 2)
                obj = test(b, obj.p);
              elseif (length(s) > 1) && strcmp(s(2).type, '()')
                x = double(obj);
                newx = subsasgn(x, s(2:end), b);
                obj = test(newx, obj.p);
              end
            otherwise
              error('Not a supported indexing expression')
          end
        case '()'
          x = double(obj);
          newx = subsasgn(x, s(1), b);
          obj = test(newx, obj.p);
        case '{}'
          error('Not a supported indexing expression')
      end
    end

    function disp(obj)
      fprintf('p:');
      disp(obj.p);
      fprintf('x:');
      disp(double(obj));
    end
  end

end

有一点需要注意:使用doubletest对象上的double运算符和方法得到的结果将返回类test的结果,而不会返回x的结果想。要获得所需的行为,您每次都必须将结果重新分配给属性>> a = test(1:3, pi) % Create an object with p = pi, and x = [1 2 3] a = p: 3.141592653589793 x: 1 2 3 >> a.x = -a % Unary operation on a, and reassignment to x a = p: 3.141592653589793 x: -1 -2 -3 >> a.x = a+4 % Binary operation and reassignment a = p: 3.141592653589793 x: 3 2 1 >> a.x = mean(a) % Another unary operation and reassignment a = p: 3.141592653589793 x: 2 ,如以下示例所示:

import tensorflow as tf
import numpy as np


X_batch = np.concatenate([np.arange(10).reshape(1, -1) for i in range(10)], axis=0)

#print(X_batch )

X = tf.placeholder(dtype=tf.int32, shape=[10, 10])

with tf.Session() as sess:
    print(sess.run(tf.where(X > 5, tf.fill([10, 10], 0),
                            X), feed_dict={X: X_batch}))