Theano共享更新python中的最后一个元素

时间:2015-08-22 17:26:45

标签: python theano

我有一个共享变量persistent_vis_chain,它正在被一个theano函数更新,它从theano.scan中获取它的功能,但这不仅仅是故事背后的问题。

我的共享变量看起来像D = [image1,...,imageN],其中每个图像都是[x1,x2,...,x784]。

我想要做的是拍摄所有图像的平均值并将它们放入最后一张图像中。那就是我想要除了最后1之外的每个图像中的所有值,这将导致[s1,s2,...,s784]然后我想设置imageN = [s1 / len(D),s2 / len (d),... S784 / LEN(d)]

所以我的问题是我不知道如何使用theano.shared做这个,可能是我对theano函数的理解并用符号变量进行这个计算。任何帮助将不胜感激。

1 个答案:

答案 0 :(得分:0)

如果您有N张图片,那么每个形状28x28=784可能是您的共享变量的形状(N,28,28)(N,784)?这种方法应该适用于任何一种形状。

鉴于D是包含图像数据的共享变量。如果您想获得平均图像,那么D.mean(keepdims=True)将象征性地为您提供。

目前还不清楚是否要将最终图像更改为等于平均图像(听起来像是一件奇怪的事情),或者如果您想再添加N+1'图像到共享变量。对于前者,你可以这样做:

D = theano.shared(load_D_data())
D_update_expression = do_something_with_scan_to_get_D_update_expression(D)
updates = [(D, T.concatenate(D_update_expression[:-1],
                             D_update_expression.mean(keepdims=True)))]
f = theano.function(..., updates=updates)

如果您想要执行后者(添加其他图片),请更改updates行,如下所示:

updates = [(D, T.concatenate(D_update_expression,
                             D_update_expression.mean(keepdims=True)))]

请注意,此代码仅供参考。它可能无法正常工作(例如,您可能需要弄乱axis=命令中的T.concatenate参数。

重点是你需要构造一个符号表达式来解释D的新值是什么样的。您希望它是扫描更新加上这个额外的平均值的组合。 T.concatenate允许您将这两个部分组合在一起。