我正在写一个类似于他们的MNIST LSTM示例代码的张量流程序。我正在构建我的数据文件,而我无法构建对象实例。
本质上,流程是:定义一个empyty实例data_sets = DataSet()
,然后构建对象data_sets.train = DataSet(arg1, arg2...)
和data_sets.test = DataSet(arg1, arg2...)
,依此类推
我在尝试构建data_sets.train = DataSet(arg1, arg2...)
MNIST代码如下所示:
class DataSet(object):
def __init__(self, images, labels, fake_data=False, one_hot=False,
dtype=tf.float32):
"""Construct a DataSet.
one_hot arg is used only if fake_data is true. `dtype` can be either
`uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
`[0, 1]`.
"""
dtype = tf.as_dtype(dtype).base_dtype
#pdb.set_trace()
if dtype not in (tf.uint8, tf.float32):
raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
dtype)
if fake_data:
self._num_examples = 10000
self.one_hot = one_hot
else:
pdb.set_trace()
assert images.shape[0] == labels.shape[0], (
'images.shape: %s labels.shape: %s' % (images.shape,
labels.shape))
self._num_examples = images.shape[0]
# Convert shape from [num examples, rows, columns, depth]
# to [num examples, rows*columns] (assuming depth == 1)
assert images.shape[3] == 1
images = images.reshape(images.shape[0],
images.shape[1] * images.shape[2])
if dtype == tf.float32:
# Convert from [0, 255] -> [0.0, 1.0].
images = images.astype(numpy.float32)
images = numpy.multiply(images, 1.0 / 255.0)
self._images = images
self._labels = labels
self._epochs_completed = 0
self._index_in_epoch = 0
@property
def images(self):
return self._images
@property
def labels(self):
return self._labels
@property
def num_examples(self):
return self._num_examples
@property
def epochs_completed(self):
return self._epochs_completed
然后在同一个文件中,他们有一个函数定义一个没有参数的实例(在pass
之后),构建数据集(我已经离开了那个部分),然后构建了对象{{1} } data_set
- 每次他们再次调用类构造函数时,这次它们包含参数。如下图所示
data_set.train, data_set.validation, and data_set.test
我基本上构建了完全相同的东西,但使用不同的数据集
这是我的班级定义(忽略标签,复制和粘贴混乱的标识 - 我不认为缩进是问题)
def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32):
class DataSets(object):
pass
pdb.set_trace()
data_sets = DataSets()
...(build dataset)...
data_sets.train = DataSet(train_images, train_labels, dtype=dtype)
data_sets.validation = DataSet(validation_images, validation_labels,
dtype=dtype)
data_sets.test = DataSet(test_images, test_labels, dtype=dtype)
pdb.set_trace()
return data_sets
然后我用以下方法构建对象:
class ScrollData(object):
def __init__(self, images, labels, dtype=tf.float32):
dtype = tf.as_dtype(dtype).base_dtype
if dtype not in (tf.float64, tf.float32):
raise TypeError('Invalid image dtype %r, expected float64 or float32' %
dtype)
assert images.shape[0] == labels.shape[0], (
'images.shape: %s labels.shape: %s' % (images.shape,
labels.shape))
self._num_examples = images.shape[0]
pdb.set_trace()
assert images.shape[3] == 1
images = images.reshape(images.shape[0],
images.shape[1] * images.shape[2])
if dtype == tf.float32:
# Convert from [0, 255] -> [0.0, 1.0].
images = images.astype(numpy.float32)
images = numpy.multiply(images, 1.0 / 255.0)
self._images = images
self._labels = labels
self._epochs_completed = 0
self._index_in_epoch = 0
@property
def images(self):
return self._images
@property
def labels(self):
return self._labels
@property
def num_examples(self):
return self._num_examples
@property
def epochs_completed(self):
return self._epochs_completed
我收到以下错误:
def read_data(data_dir):
dtype=tf.float32
VALIDATION_SIZE = 1
TEST_SIZE = 1
class ScrollData(object):
pass
data_sets = ScrollData()
...(build dataset)...
data_sets.train = ScrollData(train_images, train_labels, dtype=tf.float32)
data_sets.validation = ScrollData(validation_images, validation_labels, dtype=tf.float32)
data_sets.test = ScrollData(testtest_images, test_labels, dtype=tf.float32)
return data_sets
答案 0 :(得分:2)
使用不带构造函数参数的类覆盖ScrollData
函数中的read_data
。
保留此重新定义,并为第一个调用添加参数或在构造函数中定义标准值