我有一批图像作为尺寸[batch_size, w, h]
的张量。
我希望获得每列中值的直方图。
这就是我提出的(但它仅适用于批处理中的第一张图像,而且速度也非常慢):
global_hist = []
net = tf.squeeze(net)
for i in range(batch_size):
for j in range(1024):
hist = tf.histogram_fixed_width(tf.slice(net,[i,0,j],[1,1024,1]), [0.0, 0.2, 0.4, 0.6, 0.8, 1.0], nbins=10)
global_hist[i].append(hist)
有一种有效的方法吗?
答案 0 :(得分:0)
nbins=10
net = tf.squeeze(net)
for i in range(batch_size):
local_hist = tf.expand_dims(tf.histogram_fixed_width(tf.slice(net,[i,0,0],[1,1024,1]), [0.0, 1.0], nbins=nbins, dtype=tf.float32),[-1])
for j in range(1,1024):
hist = tf.histogram_fixed_width(tf.slice(net,[i,0,j],[1,1024,1]), [0.0, 1.0], nbins=nbins, dtype=tf.float32)
hist = tf.expand_dims(hist,[-1])
local_hist = tf.concat(1, [local_hist, hist])
if i==0:
global_hist = tf.expand_dims(local_hist, [0])
else:
global_hist = tf.concat(0, [global_hist, tf.expand_dims(local_hist,[0])])
此外,我发现此链接非常有用 https://stackoverflow.com/questions/41764199/row-wise-histogram/41768777#41768777