如何使用数据集进行估算,将输入设置为600行以预测每行的数量

时间:2019-07-17 13:13:14

标签: tensorflow

我应该如何控制估计器的输入大小?我得到一个数据框67 000行,30列。我想使用每600行30列的数据来预测仅一行的每个输出。这是一个LStm回归问题。但是出于其他考虑。我使用了估算器。 DNNLinear CombinedRegressor模型。但是,如何控制600行数据预测一个数字还不知道。另外,估计器的线性部分是LSTM还是rnn形式?由于使用estimator程序包,我不知道其中的详细结构。我使用tfrecord和数据集输入数据。

例如,前600行用于预测第600行的结果。第二低至601行的数字用于预测第601行的结果

我想用每600行30列的数据来预测每一行的输出这是一个lstm回归的问题。但由于其他方面的考虑。我使用了estimator.DNNLinearCombinedRegressor模型。数据预测1个数字的预测。另外estimator的线性部分,是否是lstm或者rnn的形式??由于estimator封装,我并不清楚里面的详细结构 说白了就是想用前600行数据预测第600行的结果2到601行预测第601行的结果,怎么才能做到?

def read_data_input(tf_dir,batch_size):     打印(tf_dir)     raw_dataset = tf.data.TFRecordDataset(tf_dir)     def _parse_function(记录):         feature_description = {             'volume':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'Quote_asset_volume':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'Number_of_trades':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'Taker_buy_base_asset_volume':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'Taker_buy_quote_asset_volume':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             #'candle_begin_time_GMT8':_float_feature(df ['pred']。loc [i]),             '5day_press':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             '30day_press':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             '60day_press':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             '更改':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'change_day':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'day_ofyear':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             '系列':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'num_updowm':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'a':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'up_shaow_day':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'down_shaow_day':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'up_shaow_15min':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'down_shaow_15min':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'szzs':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             #'sema':_ float_feature(df ['sema']。loc [i]),             #'lema':_float_feature(df ['lema']。loc [i]),             'dif':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'dea':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'macd':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'K':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'D':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'J':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'KDJ_gold':tf.FixedLenFeature([],tf.string,default_value =''),             'low_rate':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'high_rate':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             '时间':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'time_next':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'RSI':tf.FixedLenFeature([],tf.float32,default_value = 0.0),             'pred':tf.FixedLenFeature((),tf.float32,)#这里是标签         }         解析= tf.parse_single_example(记录,feature_description)

    labels=parsed['pred']
    #kdj_gold = tf.reshape(parsed['KDJ_gold'],shape=[3,0])
    kdj_gold=tf.cast(parsed['KDJ_gold'],tf.string)




    return {'volume':parsed['volume'],
            'Quote_asset_volume': parsed['Quote_asset_volume'],
            'Number_of_trades': parsed['Number_of_trades'],
            'Taker_buy_base_asset_volume': parsed['Taker_buy_base_asset_volume'],
            'Taker_buy_quote_asset_volume': parsed['Taker_buy_quote_asset_volume'],
            '5day_press': parsed['5day_press'],
            '30day_press': parsed['30day_press'],
            '60day_press': parsed['60day_press'],
            'change': parsed['change'],
            'change_day': parsed['change_day'],
            'day_ofyear': parsed['day_ofyear'],
            'series': parsed['series'],
            'num_updowm': parsed['num_updowm'],
            'a': parsed['a'],
            'up_shaow_day': parsed['up_shaow_day'],
            'down_shaow_day': parsed['down_shaow_day'],
            'up_shaow_15min': parsed['up_shaow_15min'],
            'down_shaow_15min': parsed['down_shaow_15min'],
            'szzs': parsed['szzs'],
            'dif': parsed['dif'],
            'dea': parsed['dea'],
            'macd': parsed['macd'],
            'K': parsed['K'],
            'D': parsed['D'],
            'J': parsed['J'],
            'KDJ_gold': kdj_gold,
            'low_rate': parsed['low_rate'],
            'high_rate': parsed['high_rate'],
            'time': parsed['time'],
            'time_next': parsed['time_next'],
            'RSI': parsed['RSI'],
            },labels

parsed_dataset = raw_dataset.map(_parse_function)
parsed_dataset= parsed_dataset.repeat()
parsed_dataset= parsed_dataset.batch(batch_size)#n_step)
parsed_dataset = parsed_dataset.window(5)
print(parsed_dataset.output_shapes)
return parsed_dataset

我尝试了很多方法,但都失败了。如何更改代码,以便将输入设置为600行以预测行数?

0 个答案:

没有答案