如何创建可重用的张量流操作

时间:2018-10-29 14:05:30

标签: python-3.x tensorflow keras

在提供/训练模型之前,我要处理一些数据。对于此示例,我想做一个最大池2d。我写了一个简短的函数来使用tensorflow做到这一点。

import tensorflow
import tensorflow.nn as nn

def _tfMaxPool(arr, pool=(4,4), sess=None):
    op = nn.max_pool(arr, (1, 1, pool[0], 1), (1, 1, pool[0], 1 ), padding="VALID")
    op = nn.max_pool(op, (1, 1, 1, pool[1]), (1, 1, 1, pool[1]), padding="VALID")
    if sess is None:
        sess = tensorflow.Session();

    return sess.run(op)

问题在于,这每次都可以将节点添加到我的图形中,这似乎使我的会话混乱。一种替代方法是创建模型。

import keras

seq = keras.Sequential([ 
            keras.layers.InputLayer((1, 512, 512)), 
            keras.layers.MaxPool2D((4, 4), (4, 4), data_format="channels_first")
                     ])
def _tfMaxPool2(arr, pool=(4,4), sess=None):
    swapped = arr.swapaxes(0,1)
    return seq.predict(swapped).swapaxes(0,1)

该模型几乎完全符合我的需求,但是我认为我缺少一些基本知识。

1 个答案:

答案 0 :(得分:1)

为什么不通过各种输入重复使用图形?在以下代码中,tfMaxpool仅定义一次。

def _tfMaxPool(arr, pool=(4,4)):
    op = nn.max_pool(arr, (1, 1, pool[0], 1), (1, 1, pool[0], 1 ), padding="VALID")
    op = nn.max_pool(op, (1, 1, 1, pool[1]), (1, 1, 1, pool[1]), padding="VALID")

    return op

input = tf.placeholder()
output = _tfMaxPool(input)

with tf.Session() as sess:
    sess.run(output, feed_dict={input:arr1})
    sess.run(output, feed_dict={input:arr2})