Converting a tensor to a numpy array in a custom loss function

Time: 2020-06-04 23:21:25

Tags: python numpy tensorflow keras tensor

This is my custom loss function:

import cmath

import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K

epsylon=np.finfo(float).eps

def to_array(tensor):
    return tf.make_ndarray(tensor)


def addError(test,range_min,range_max,result):
    err = abs(log(range_max/max(range_min,epsylon)))
    if range_min <= test <= range_max:
        result.append(err)
    else:
        e1=abs(log(test/max(range_min,epsylon)))
        e2=abs(log(test/max(range_max,epsylon)))
        result.append( min(e1,e2) / max(err,epsylon) *100 + err)


def rangeLoss(yTrue,yPred):
    #print(type(yPred))
    a_pred=to_array(yPred)
    a_true=to_array(yTrue)

    result=[]

    for i in range(a_true.size):
        range_min=abs(a_pred[i*2])
        range_max=abs(a_pred[i*2+1])
        test= abs(a_true[i])

        addError(test,range_min,range_max,result)


    return tf.constant(result)

When I train the model, it fails with:

/home/ubuntu/.local/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py:591 MakeNdarray
        shape = [d.size for d in tensor.tensor_shape.dim]

    AttributeError: 'Tensor' object has no attribute 'tensor_shape'

When I modify to_array to go through a tensor proto first:

def to_array(tensor):
    proto_tensor = tf.make_tensor_proto(tensor)
    return tf.make_ndarray(proto_tensor)

I get the following error:

    /home/ubuntu/.local/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py:451 make_tensor_proto
        _AssertCompatible(values, dtype)
    /home/ubuntu/.local/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py:328 _AssertCompatible
        raise TypeError("Expected any non-tensor type, got a tensor instead.")

    TypeError: Expected any non-tensor type, got a tensor instead.

Another option I tried was tensor.numpy(), which produced the following error:

    <ipython-input-20-0a8051a4a034>:8 to_array
        return tensor.numpy()

    AttributeError: 'Tensor' object has no attribute 'numpy'

And of course I also tried tensor.eval(session=tf.compat.v1.Session()), which failed as well.

What should I do?

1 Answer:

Answer 0: (score: 0)

I solved this by slicing the original tensors. Here is the code:
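A minimal sketch of what the slicing approach can look like, not necessarily the exact code from this answer. It assumes yPred has shape (batch, 2*n) with interleaved (range_min, range_max) pairs and yTrue has shape (batch, n). The names range_loss and EPS are illustrative, and unlike the question's code it also clamps range_max and the true value to EPS so the logarithm never sees zero. Everything stays a TensorFlow op, so no conversion to numpy is needed.

```python
import tensorflow as tf

# Assumption: a float32-friendly stand-in for np.finfo(float).eps.
EPS = 1e-7


def range_loss(y_true, y_pred):
    """Per-element range loss built only from tensor ops (hypothetical sketch)."""
    y_true = tf.cast(y_true, y_pred.dtype)

    # Strided slices replace the Python loop over a_pred[i*2] and a_pred[i*2+1].
    range_min = tf.maximum(tf.abs(y_pred[..., 0::2]), EPS)
    range_max = tf.maximum(tf.abs(y_pred[..., 1::2]), EPS)
    test = tf.maximum(tf.abs(y_true), EPS)

    # Width of the predicted range in log space.
    err = tf.abs(tf.math.log(range_max / range_min))

    # Log distance from the true value to the nearer end of the range.
    e1 = tf.abs(tf.math.log(test / range_min))
    e2 = tf.abs(tf.math.log(test / range_max))
    outside = tf.minimum(e1, e2) / tf.maximum(err, EPS) * 100.0 + err

    # tf.where replaces the if/else: inside the range the loss is just err.
    inside = tf.logical_and(range_min <= test, test <= range_max)
    return tf.where(inside, err, outside)
```

The function can then be passed directly as loss=range_loss to model.compile; Keras reduces the per-element values to a scalar on its own.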
