我正在尝试构建卷积变分自动编码器。为此,我需要展平和展平张量。但是,当我尝试在展平形状之前存储形状(以便在解码器部分将其重塑形状)时,遇到了以下错误
Fetch argument None has invalid type <type 'NoneType'\>
同时尝试重塑张量。我在下面包含了代码和错误。
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
%matplotlib inline
mnist = input_data.read_data_sets('./MNIST_data', one_hot=True)
input_shape = [28, 28]
def encoder(_input):
conv1 = tf.layers.conv2d(_input, 3, [5,5], strides=(2,2), activation=tf.nn.elu)
conv2 = tf.layers.conv2d(conv1, 6, [5,5], strides=(2, 2), activation=tf.nn.elu)
logits = tf.layers.conv2d(conv2, 2, [3, 3], activation=None)
flat = tf.layers.flatten(logits)
return flat
def decoder(flat):
shape = flat.get_shape().as_list()
return shape
x_train = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
enc = encoder(x_train)
dec = decoder(enc)
img = np.reshape(mnist.train.images[:2], (-1, 28, 28, 1))
img.shape
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
val = sess.run(dec, feed_dict={x_train: img})
TypeError Traceback (most recent call last)
<ipython-input-43-5f497519723a> in <module>()
3 with tf.Session() as sess:
4 sess.run(init)
----> 5 summary, val = sess.run([summary_op, dec], feed_dict={x_train: img})
6
/home/13mcpc11/anaconda2/envs/dl-py2-env/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
887 try:
888 result = self._run(None, fetches, feed_dict, options_ptr,
--> 889 run_metadata_ptr)
890 if run_metadata:
891 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
/home/13mcpc11/anaconda2/envs/dl-py2-env/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
1103 # Create a fetch handler to take care of the structure of fetches.
1104 fetch_handler = _FetchHandler(
-> 1105 self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
1106
1107 # Run request and get response.
/home/13mcpc11/anaconda2/envs/dl-py2-env/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in __init__(self, graph, fetches, feeds, feed_handles)
412 """
413 with graph.as_default():
--> 414 self._fetch_mapper = _FetchMapper.for_fetch(fetches)
415 self._fetches = []
416 self._targets = []
/home/13mcpc11/anaconda2/envs/dl-py2-env/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in for_fetch(fetch)
232 elif isinstance(fetch, (list, tuple)):
233 # NOTE(touts): This is also the code path for namedtuples.
--> 234 return _ListFetchMapper(fetch)
235 elif isinstance(fetch, dict):
236 return _DictFetchMapper(fetch)
/home/13mcpc11/anaconda2/envs/dl-py2-env/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in __init__(self, fetches)
339 """
340 self._fetch_type = type(fetches)
--> 341 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
342 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
343
/home/13mcpc11/anaconda2/envs/dl-py2-env/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in for_fetch(fetch)
229 if fetch is None:
230 raise TypeError('Fetch argument %r has invalid type %r' %
--> 231 (fetch, type(fetch)))
232 elif isinstance(fetch, (list, tuple)):
233 # NOTE(touts): This is also the code path for namedtuples.
TypeError: Fetch argument None has invalid type <type 'NoneType'>