我应该如何控制估计器的输入大小?我得到一个数据框67 000行,30列。我想使用每600行30列的数据来预测仅一行的每个输出。这是一个LStm回归问题。但是出于其他考虑。我使用了估算器。 DNNLinear CombinedRegressor模型。但是,如何控制600行数据预测一个数字还不知道。另外,估计器的线性部分是LSTM还是rnn形式?由于使用estimator程序包,我不知道其中的详细结构。我使用tfrecord和数据集输入数据。
例如,前600行用于预测第600行的结果。第二低至601行的数字用于预测第601行的结果
我想用每600行30列的数据来预测每一行的输出这是一个lstm回归的问题。但由于其他方面的考虑。我使用了estimator.DNNLinearCombinedRegressor模型。数据预测1个数字的预测。另外estimator的线性部分,是否是lstm或者rnn的形式??由于estimator封装,我并不清楚里面的详细结构 说白了就是想用前600行数据预测第600行的结果2到601行预测第601行的结果,怎么才能做到?
def read_data_input(tf_dir,batch_size): 打印(tf_dir) raw_dataset = tf.data.TFRecordDataset(tf_dir) def _parse_function(记录): feature_description = { 'volume':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'Quote_asset_volume':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'Number_of_trades':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'Taker_buy_base_asset_volume':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'Taker_buy_quote_asset_volume':tf.FixedLenFeature([],tf.float32,default_value = 0.0), #'candle_begin_time_GMT8':_float_feature(df ['pred']。loc [i]), '5day_press':tf.FixedLenFeature([],tf.float32,default_value = 0.0), '30day_press':tf.FixedLenFeature([],tf.float32,default_value = 0.0), '60day_press':tf.FixedLenFeature([],tf.float32,default_value = 0.0), '更改':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'change_day':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'day_ofyear':tf.FixedLenFeature([],tf.float32,default_value = 0.0), '系列':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'num_updowm':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'a':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'up_shaow_day':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'down_shaow_day':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'up_shaow_15min':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'down_shaow_15min':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'szzs':tf.FixedLenFeature([],tf.float32,default_value = 0.0), #'sema':_ float_feature(df ['sema']。loc [i]), #'lema':_float_feature(df ['lema']。loc [i]), 'dif':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'dea':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'macd':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'K':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'D':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'J':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'KDJ_gold':tf.FixedLenFeature([],tf.string,default_value =''), 'low_rate':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'high_rate':tf.FixedLenFeature([],tf.float32,default_value = 0.0), '时间':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'time_next':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'RSI':tf.FixedLenFeature([],tf.float32,default_value = 0.0), 'pred':tf.FixedLenFeature((),tf.float32,)#这里是标签 } 解析= tf.parse_single_example(记录,feature_description)
labels=parsed['pred']
#kdj_gold = tf.reshape(parsed['KDJ_gold'],shape=[3,0])
kdj_gold=tf.cast(parsed['KDJ_gold'],tf.string)
return {'volume':parsed['volume'],
'Quote_asset_volume': parsed['Quote_asset_volume'],
'Number_of_trades': parsed['Number_of_trades'],
'Taker_buy_base_asset_volume': parsed['Taker_buy_base_asset_volume'],
'Taker_buy_quote_asset_volume': parsed['Taker_buy_quote_asset_volume'],
'5day_press': parsed['5day_press'],
'30day_press': parsed['30day_press'],
'60day_press': parsed['60day_press'],
'change': parsed['change'],
'change_day': parsed['change_day'],
'day_ofyear': parsed['day_ofyear'],
'series': parsed['series'],
'num_updowm': parsed['num_updowm'],
'a': parsed['a'],
'up_shaow_day': parsed['up_shaow_day'],
'down_shaow_day': parsed['down_shaow_day'],
'up_shaow_15min': parsed['up_shaow_15min'],
'down_shaow_15min': parsed['down_shaow_15min'],
'szzs': parsed['szzs'],
'dif': parsed['dif'],
'dea': parsed['dea'],
'macd': parsed['macd'],
'K': parsed['K'],
'D': parsed['D'],
'J': parsed['J'],
'KDJ_gold': kdj_gold,
'low_rate': parsed['low_rate'],
'high_rate': parsed['high_rate'],
'time': parsed['time'],
'time_next': parsed['time_next'],
'RSI': parsed['RSI'],
},labels
parsed_dataset = raw_dataset.map(_parse_function)
parsed_dataset= parsed_dataset.repeat()
parsed_dataset= parsed_dataset.batch(batch_size)#n_step)
parsed_dataset = parsed_dataset.window(5)
print(parsed_dataset.output_shapes)
return parsed_dataset
我尝试了很多方法,但都失败了。如何更改代码,以便将输入设置为600行以预测行数?