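I have a panel dataset on which I would like to train a long short-term memory (LSTM) network. The dataset comes from a PostgreSQL database. My data is structured similar to the following: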
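So my number of time steps is 4. This is a many-to-many LSTM: both my input and my output are sequences. The input will have shape [Batch_size, 4, 23], and the output will have shape [Batch_size, 4, 2] (I one-hot encode the labels).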
I am using a Python generator to fetch the rows. I fetch them with fetchmany, where number_of_records is 4, because that corresponds to one particular person.
class it_try:
    import passwords_and_paths
    import psycopg2

    def __init__(self, sql, number_of_records):
        self.sql = sql
        self.number_of_records = number_of_records
        self.pgConnectString = "host='/var/run/postgresql' port='{}' dbname='{}' user='{}' password='{}'".format(it_try.passwords_and_paths.database['port'],
                                                                                                           it_try.passwords_and_paths.database['name'],
                                                                                                           it_try.passwords_and_paths.database['user'],
                                                                                                           it_try.passwords_and_paths.database['pass'])
        self.pgConnection=psycopg2.connect(self.pgConnectString)
        self.pgCursor = self.pgConnection.cursor(name='fetch_large_result')
        self.pgCursor.execute(self.sql)

    def __iter__(self):
        return self

    def __next__(self):
        row = self.pgCursor.fetchmany(self.number_of_records)
        current_obs = []
        for i in row:
            current_obs.append(i)
        features = np.array(current_obs)[:,3:26]
        labels = np.array(current_obs)[:,-1].astype(int)
        return features, labels

    def __del__(self):
        self.pgCursor.close()
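To illustrate the chunking that the class above relies on, here is a minimal sketch of how fetchmany hands back one person's rows at a time. It uses the standard-library sqlite3 module as a stand-in for psycopg2, purely so it runs without a database server; the table layout and values are hypothetical and much smaller than the real 23-feature table.

```python
import sqlite3

# Stand-in for the PostgreSQL table: 3 hypothetical customers x 4 quarters.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE basetable (year INT, customer_id INT, quarter INT, label INT)"
)
rows = [(2017, c, q, (c + q) % 2) for c in range(3) for q in range(1, 5)]
conn.executemany("INSERT INTO basetable VALUES (?, ?, ?, ?)", rows)

cursor = conn.execute(
    "SELECT * FROM basetable ORDER BY year, customer_id, quarter"
)

chunks = []
while True:
    chunk = cursor.fetchmany(4)  # one customer's four quarters
    if not chunk:                # empty list: the result set is exhausted
        break
    chunks.append(chunk)

print(len(chunks))               # number of customers fetched
print([len(c) for c in chunks])  # rows per chunk
print(cursor.fetchmany(4))       # an exhausted cursor keeps returning []
```

Note that once the cursor runs out of rows, fetchmany returns an empty list rather than raising, so the caller has to check for that case itself.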
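The features have shape [4, 23] and the labels have shape [4,]. I then use TensorFlow's tf.data.Dataset.from_generator() function to initialize a dataset from the generator. The shapes and data types are defined correctly; here I one-hot encode the labels and batch three persons per call.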
generator = it_try(sql = 'SELECT * FROM public.basetable order by year, customer_id, quarter', number_of_records = 4)
train_dataset = tf.data.Dataset.from_generator(lambda: generator, (tf.float32, tf.int32), (tf.TensorShape([4,23]), tf.TensorShape([4,])))
train_dataset=train_dataset.map(lambda *x:(x[0], tf.cast(tf.one_hot(x[1],2),tf.int32)))
train_dataset = train_dataset.batch(3)
The output is <BatchDataset shapes: ((?, 4, 23), (?, 4, 2)), types: (tf.float32, tf.int32)>. So far so good.
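For reference, the label shape bookkeeping can be sketched in plain Python, without TensorFlow: each person's four labels go from shape [4] to [4, 2] after one-hot encoding, and three persons stacked together give [3, 4, 2]. The helper one_hot and the label values below are hypothetical, chosen only for illustration.

```python
def one_hot(label, depth=2):
    """Return a list of length `depth` with a 1 at index `label`."""
    return [1 if i == label else 0 for i in range(depth)]

# Hypothetical labels for three persons, four time steps each.
labels_per_person = [[0, 1, 1, 0], [1, 1, 0, 0], [0, 0, 0, 1]]

# Shape [4] -> [4, 2] per person, then three persons stacked: [3, 4, 2].
batch = [[one_hot(label) for label in person] for person in labels_per_person]

print(len(batch), len(batch[0]), len(batch[0][0]))
print(batch[0])
```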
I create the iterator and initialize it, and I can successfully print batches (2 batches in this example).
iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
                                           train_dataset.output_shapes)
X, y = iterator.get_next()
training_init_op = iterator.make_initializer(train_dataset)
with tf.Session() as sess:
    sess.run(training_init_op)
    for batch in range(2):
        print(sess.run([X,y]))
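However, when I want to make multiple passes over the training data (in this example the number of epochs is 2), I get an error, which is of course because I cannot reset the Python and TensorFlow iterators.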
with tf.Session() as sess:
    for epoch in range(2):
        sess.run(training_init_op)
        for batch in range(2):
            print(sess.run([X,y]))
It prints the first epoch fine, but on the second epoch I get this error:
---------------------------------------------------------------------------
UnknownError Traceback (most recent call last)
/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
1321 try:
-> 1322 return fn(*args)
1323 except errors.OpError as e:
/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata)
1306 return self._call_tf_sessionrun(
-> 1307 options, feed_dict, fetch_list, target_list, run_metadata)
1308
/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, run_metadata)
1408 self._session, options, feed_dict, fetch_list, target_list,
-> 1409 run_metadata)
1410 else:
UnknownError: IndexError: too many indices for array
Traceback (most recent call last):
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/script_ops.py", line 158, in __call__
ret = func(*args)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 410, in generator_py_func
values = next(generator_state.get_iterator(iterator_id))
File "<ipython-input-64-e6c5163f3adc>", line 26, in __next__
features = np.array(current_obs)[:,3:26]
IndexError: too many indices for array
[[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_FLOAT, DT_INT32], token="pyfunc_46"](arg0)]]
[[Node: IteratorGetNext_23 = IteratorGetNext[output_shapes=[[?,4,23], [?,4,2]], output_types=[DT_FLOAT, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator_23)]]
During handling of the above exception, another exception occurred:
UnknownError Traceback (most recent call last)
<ipython-input-67-213eeaa1c283> in <module>()
7 sess.run(training_init_op)
8 for i in range(2):
----> 9 print(sess.run([X,y]))
/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
898 try:
899 result = self._run(None, fetches, feed_dict, options_ptr,
--> 900 run_metadata_ptr)
901 if run_metadata:
902 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1133 if final_fetches or final_targets or (handle and feed_dict_tensor):
1134 results = self._do_run(handle, final_targets, final_fetches,
-> 1135 feed_dict_tensor, options, run_metadata)
1136 else:
1137 results = []
/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
1314 if handle is None:
1315 return self._do_call(_run_fn, feeds, fetches, targets, options,
-> 1316 run_metadata)
1317 else:
1318 return self._do_call(_prun_fn, handle, feeds, fetches)
/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
1333 except KeyError:
1334 pass
-> 1335 raise type(e)(node_def, op, message)
1336
1337 def _extend_graph(self):
UnknownError: IndexError: too many indices for array
Traceback (most recent call last):
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/script_ops.py", line 158, in __call__
ret = func(*args)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 410, in generator_py_func
values = next(generator_state.get_iterator(iterator_id))
File "<ipython-input-64-e6c5163f3adc>", line 26, in __next__
features = np.array(current_obs)[:,3:26]
IndexError: too many indices for array
[[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_FLOAT, DT_INT32], token="pyfunc_46"](arg0)]]
[[Node: IteratorGetNext_23 = IteratorGetNext[output_shapes=[[?,4,23], [?,4,2]], output_types=[DT_FLOAT, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator_23)]]
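The IndexError in the traceback is consistent with fetchmany returning an empty list once the cursor has no rows left: np.array([]) is one-dimensional, so [:,3:26] asks for too many indices. Because lambda: generator hands back the same, already consumed object on every initialization, the second epoch would continue from wherever the cursor stopped. This reading assumes the table holds only about six persons, which is not stated here. Under that assumption, a restartable generator might look like the sketch below. It again uses sqlite3 in place of psycopg2, and make_connection and person_chunks are hypothetical names.

```python
import sqlite3

def make_connection():
    """Hypothetical stand-in for psycopg2.connect(...): 6 customers x 4 quarters."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE basetable (customer_id INT, quarter INT, label INT)")
    conn.executemany(
        "INSERT INTO basetable VALUES (?, ?, ?)",
        [(c, q, c % 2) for c in range(6) for q in range(1, 5)],
    )
    return conn

def person_chunks(conn, sql, number_of_records=4):
    """Yield one chunk of rows per person; every call opens a fresh cursor."""
    cursor = conn.cursor()
    try:
        cursor.execute(sql)
        while True:
            rows = cursor.fetchmany(number_of_records)
            if not rows:  # exhausted: end the generator instead of yielding []
                return
            yield rows
    finally:
        cursor.close()

conn = make_connection()
sql = "SELECT * FROM basetable ORDER BY customer_id, quarter"

for epoch in range(2):
    # Calling the function again re-runs the query, so each epoch sees all rows.
    n = sum(1 for _ in person_chunks(conn, sql))
    print(epoch, n)
```

With tf.data, the analogous change would be to pass a callable that builds a new generator each time, for example lambda: person_chunks(conn, sql), since from_generator calls its argument whenever the iterator is initialized. That part is untested here, and the feature/label split from __next__ would still have to be applied to each chunk.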
I tried .repeat(2), to no avail.
Can anyone help me? How can I run multiple epochs when I am using a Python iterator (with the data coming from a database)?