假设您有一个继承数组x
的类,添加了一些参数p
:
classdef test
properties
x
p
end
methods
function t=calculate(t)
[t.x,t.p]=calc(x,p);
end
function t=plus(t1,t2)
t.x=t1.x+t2.x;
end
end
end
众所周知如何重载二元运算符,例如plus
,mtimes
,minus
等。如何为任何二进制文件启用重载向量化运算符,或最终,任何一元运算符,如mean
,abs
,max
等,以便直接应用于向量x
?例如,如何S = mean(S);
等同于S.x = mean(S.x);
?
答案 0 :(得分:1)
如果我理解你的问题,听起来你希望你的新类test
简单地继承为属性x
的类定义的所有二进制和一元方法(并且对属性进行操作) x
调用它们时,你不必自己重新定义它们。
如果这是您想要的,那么我认为唯一可行的方法就是实际使用继承并使您的班级test
成为subclass的属性类x
。考虑到x
只是double
的简单情况,可以找到子类化内置double
类型的一个很好的示例here。将该示例改编为您的示例,这是您实现班级test
的一种方式:
classdef test < double
properties
p
end
methods
function obj = test(x, p)
if (nargin < 2)
p = 0;
if (nargin < 1)
x = 0;
end
end
obj@double(x);
obj.p = p;
end
function sref = subsref(obj, s)
switch s(1).type
case '.'
switch s(1).subs
case 'p'
sref = obj.p;
case 'x'
x = double(obj);
if (length(s) < 2)
sref = x;
elseif (length(s) > 1) && strcmp(s(2).type, '()')
sref = subsref(x, s(2:end));
end
otherwise
error('Not a supported indexing expression')
end
case '()'
x = double(obj);
newx = subsref(x, s(1:end));
sref = test(newx, obj.p);
case '{}'
error('Not a supported indexing expression')
end
end
function obj = subsasgn(obj, s, b)
switch s(1).type
case '.'
switch s(1).subs
case 'p'
obj.p = b;
case 'x'
if (length(s) < 2)
obj = test(b, obj.p);
elseif (length(s) > 1) && strcmp(s(2).type, '()')
x = double(obj);
newx = subsasgn(x, s(2:end), b);
obj = test(newx, obj.p);
end
otherwise
error('Not a supported indexing expression')
end
case '()'
x = double(obj);
newx = subsasgn(x, s(1), b);
obj = test(newx, obj.p);
case '{}'
error('Not a supported indexing expression')
end
end
function disp(obj)
fprintf('p:');
disp(obj.p);
fprintf('x:');
disp(double(obj));
end
end
end
有一点需要注意:使用double
类test
对象上的double
运算符和方法得到的结果将返回类test
的结果,而不会返回x
的结果想。要获得所需的行为,您每次都必须将结果重新分配给属性>> a = test(1:3, pi) % Create an object with p = pi, and x = [1 2 3]
a =
p: 3.141592653589793
x: 1 2 3
>> a.x = -a % Unary operation on a, and reassignment to x
a =
p: 3.141592653589793
x: -1 -2 -3
>> a.x = a+4 % Binary operation and reassignment
a =
p: 3.141592653589793
x: 3 2 1
>> a.x = mean(a) % Another unary operation and reassignment
a =
p: 3.141592653589793
x: 2
,如以下示例所示:
import tensorflow as tf
import numpy as np
X_batch = np.concatenate([np.arange(10).reshape(1, -1) for i in range(10)], axis=0)
#print(X_batch )
X = tf.placeholder(dtype=tf.int32, shape=[10, 10])
with tf.Session() as sess:
print(sess.run(tf.where(X > 5, tf.fill([10, 10], 0),
X), feed_dict={X: X_batch}))