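Converting a Keras data generator to a TensorFlow Estimator

Time: 2018-06-27 22:51:20

Tags: python tensorflow keras tensorflow-datasets tensorflow-estimator

I want to create custom batches from a pandas DataFrame df_train, which holds the entire training set, and a NumPy array arr_train, which holds the start and end index of each batch. Each batch should be generated from its start and end index.

For example, my df_train looks like this: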

index col1 col2 col3
0     100   121  A
1     101   211  A
2     102   213  B

and my arr_train looks like arr_train = [[0 1] [2 2]]

This means that my i-th batch is df_train.loc[arr_train[i,0]:arr_train[i,1],:] (with .loc, both the start and the end label are included).
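
To make the indexing concrete, here is a minimal sketch using the sample data above. The helper name get_batch is only illustrative and is not part of my actual code.

```python
import numpy as np
import pandas as pd

# Sample data copied from the question
df_train = pd.DataFrame(
    {"col1": [100, 101, 102], "col2": [121, 211, 213], "col3": ["A", "A", "B"]}
)
arr_train = np.array([[0, 1], [2, 2]])


def get_batch(i):
    """Return the i-th batch (hypothetical helper).

    .loc slices by label and includes both endpoints.
    """
    return df_train.loc[arr_train[i, 0]:arr_train[i, 1], :]


batch_sizes = [len(get_batch(i)) for i in range(len(arr_train))]
print(batch_sizes)  # [2, 1]
```

So the first batch contains rows 0 and 1, and the second batch contains only row 2; batches do not all have the same length.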

I know how to do this with Keras. However, I want to convert my Keras model to a tensorflow.estimator, so I need to convert my batch generator into a form that a tensorflow.estimator can use.

Here is my Keras data generator:

import numpy as np
from tensorflow import keras


class DataGenerator(keras.utils.Sequence):
    """Generates data for Keras, one variable-length batch per row of arr_df."""

    def __init__(self, df, arr_df, predictors, response, weight=None,
                 shuffle=False, training=False):
        """Store the data frame, the batch boundaries, and the column names."""
        self.arr_df = arr_df
        self.df = df
        self.shuffle = shuffle
        self.predictors = predictors
        self.response = response
        self.weight = weight
        self.training = training
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch."""
        return self.arr_df.shape[0]

    def __getitem__(self, index):
        """Generate one batch of data."""
        # Go through self.indexes so that shuffling, when enabled, takes effect
        row = self.indexes[index]
        a1 = self.arr_df[row, 0]
        a2 = self.arr_df[row, 1]

        # .loc is label-based and includes both a1 and a2
        X1 = self.df.loc[a1:a2, self.predictors].to_numpy().reshape(
            (1, -1, len(self.predictors)))
        if not self.training:
            return X1
        y = self.df.loc[a1:a2, self.response].to_numpy().reshape((1, -1, 1))
        if self.weight is not None:
            w = self.df.loc[a1:a2, self.weight].unique()
            return X1, y, w
        return X1, y

    def on_epoch_end(self):
        """Reset the batch order after each epoch, shuffling it if requested."""
        self.indexes = np.arange(self.arr_df.shape[0])
        if self.shuffle:
            np.random.shuffle(self.indexes)
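
One direction that might work is to wrap the same slicing logic in a plain Python generator and hand it to tf.data.Dataset.from_generator inside an input_fn. The sketch below is only that: the names make_batch_generator and make_input_fn are hypothetical, it assumes the predictor and response columns are numeric and can be cast to float32, it leaves out the weight column, and it assumes a TensorFlow version in which from_generator accepts output_types and output_shapes.

```python
import numpy as np
import pandas as pd


def make_batch_generator(df, arr_df, predictors, response):
    """Return a zero-argument callable yielding one (X, y) pair per row of arr_df.

    Hypothetical helper; assumes numeric columns that can be cast to float32.
    """
    def gen():
        for start, end in arr_df:
            # .loc slicing is label-based and includes both endpoints
            rows = df.loc[start:end]
            X = rows[predictors].to_numpy(dtype=np.float32).reshape(
                (1, -1, len(predictors)))
            y = rows[response].to_numpy(dtype=np.float32).reshape((1, -1, 1))
            yield X, y
    return gen


def make_input_fn(df, arr_df, predictors, response):
    """Build an input_fn that returns a tf.data.Dataset of (features, labels)."""
    def input_fn():
        # Imported here so the generator above stays usable without TensorFlow
        import tensorflow as tf
        return tf.data.Dataset.from_generator(
            make_batch_generator(df, arr_df, predictors, response),
            output_types=(tf.float32, tf.float32),
            # The middle dimension is None because batches differ in length
            output_shapes=(
                tf.TensorShape([1, None, len(predictors)]),
                tf.TensorShape([1, None, 1]),
            ),
        )
    return input_fn
```

With an existing estimator, the result would presumably be passed as estimator.train(input_fn=make_input_fn(df_train, arr_train, predictors, response)). Each yielded element is already a complete batch, so no further .batch() call should be needed on the dataset.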

0 answers:

No answers yet