在提供/训练模型之前,我要处理一些数据。对于此示例,我想做一个最大池2d。我写了一个简短的函数来使用tensorflow做到这一点。
import tensorflow
import tensorflow.nn as nn
def _tfMaxPool(arr, pool=(4,4), sess=None):
op = nn.max_pool(arr, (1, 1, pool[0], 1), (1, 1, pool[0], 1 ), padding="VALID")
op = nn.max_pool(op, (1, 1, 1, pool[1]), (1, 1, 1, pool[1]), padding="VALID")
if sess is None:
sess = tensorflow.Session();
return sess.run(op)
问题在于,这每次都可以将节点添加到我的图形中,这似乎使我的会话混乱。一种替代方法是创建模型。
import keras
seq = keras.Sequential([
keras.layers.InputLayer((1, 512, 512)),
keras.layers.MaxPool2D((4, 4), (4, 4), data_format="channels_first")
])
def _tfMaxPool2(arr, pool=(4,4), sess=None):
swapped = arr.swapaxes(0,1)
return seq.predict(swapped).swapaxes(0,1)
该模型几乎完全符合我的需求,但是我认为我缺少一些基本知识。
答案 0 :(得分:1)
为什么不通过各种输入重复使用图形?在以下代码中,tfMaxpool仅定义一次。
def _tfMaxPool(arr, pool=(4,4)):
op = nn.max_pool(arr, (1, 1, pool[0], 1), (1, 1, pool[0], 1 ), padding="VALID")
op = nn.max_pool(op, (1, 1, 1, pool[1]), (1, 1, 1, pool[1]), padding="VALID")
return op
input = tf.placeholder()
output = _tfMaxPool(input)
with tf.Session() as sess:
sess.run(output, feed_dict={input:arr1})
sess.run(output, feed_dict={input:arr2})