Tensorflow - 应用超过1D张量的函数

时间:2016-11-28 13:12:37

标签: tensorflow

我有一个函数dice

def dice(yPred,yTruth,thresh):
    smooth = tf.constant(1.0)
    threshold = tf.constant(thresh)
    yPredThresh = tf.to_float(tf.greater_equal(yPred,threshold))
    mul = tf.mul(yPredThresh,yTruth)
    intersection = 2*tf.reduce_sum(mul) + smooth
    union = tf.reduce_sum(yPredThresh) + tf.reduce_sum(yTruth) + smooth
    dice = intersection/union
    return dice, yPredThresh

哪个有效。这里给出一个例子

with tf.Session() as sess:

    thresh = 0.5 
    print("Dice example")
    yPred = tf.constant([0.1,0.9,0.7,0.3,0.1,0.1,0.9,0.9,0.1],shape=[3,3])
    yTruth = tf.constant([0.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,1.0],shape=[3,3])
    diceScore, yPredThresh= dice(yPred=yPred,yTruth=yTruth,thresh= thresh)

    diceScore_ , yPredThresh_ , yPred_, yTruth_ = sess.run([diceScore,yPredThresh,yPred, yTruth])
    print("\nScore = {0}".format(diceScore_))

>>> Score = 0.899999976158

我希望能够绕过骰子的第三个争论,打谷。我不知道这样做的最好方法,我可以从图中提取它。以下内容......

def diceROC(yPred,yTruth,thresholds=np.linspace(0.1,0.9,20)):
    thresholds = thresholds.astype(np.float32)
    nThreshs = thresholds.size
    diceScores = tf.zeros(shape=nThreshs)

    for i in xrange(nThreshs):
        score,_ = dice(yPred,yTruth,thresholds[i])
        diceScores[i] = score
    return diceScores

评估diceScoreROC会产生错误'Tensor' object does not support item assignment,因为我无法循环并切片显示tf张量。

1 个答案:

答案 0 :(得分:1)

我鼓励你使用张量流的广播能力,而不是循环。如果您将dice重新定义为:

def dice(yPred,yTruth,thresh):
    smooth = tf.constant(1.0)
    yPredThresh = tf.to_float(tf.greater_equal(yPred,thresh))
    mul = tf.mul(yPredThresh,yTruth)
    intersection = 2*tf.reduce_sum(mul, [0, 1]) + smooth
    union = tf.reduce_sum(yPredThresh, [0, 1]) + tf.reduce_sum(yTruth, [0, 1]) + smooth
    dice = intersection/union
    return dice, yPredThresh

您将能够传递三维yPredyTruth(假设张量将在最后一个维度重复)和一维thresh

with tf.Session() as sess:

    thresh = [0.1,0.9,20, 0.5]
    print("Dice example")
    yPred = tf.constant([0.1,0.9,0.7,0.3,0.1,0.1,0.9,0.9,0.1],shape=[3,3,1])
    ypred_tiled = tf.tile(yPred, [1,1,4])
    yTruth = tf.constant([0.0,1.0,1.0,0.0,0.0,0.0,1.0,1.0,1.0],shape=[3,3,1])
    ytruth_tiled = tf.tile(yTruth, [1,1,4])
    diceScore, yPredThresh= dice(yPred=ypred_tiled,yTruth=ytruth_tiled,thresh= thresh)

    diceScore_ = sess.run(diceScore)
    print("\nScore = {0}".format(diceScore_))

你会得到:

Score = [ 0.73333335  0.77777779  0.16666667  0.89999998]