How to update code to use the Keras functional API

Asked: 2019-03-01 20:52:15

Tags: python keras layer

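I need to upgrade some code I found on GitHub so that the model builds correctly. I need to concatenate several layers (in older versions of Keras this was done with Merge(mode='concat')), but now I have to use the concatenate function instead, and for that I need to use the functional API.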
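As a minimal sketch of the functional pattern, assuming Keras 2.2 or later (where the legacy Merge layer no longer exists), two hypothetical input branches of width 4 and 6 are joined with the concatenate function. The layer sizes are arbitrary illustration values, not taken from the original code.

```python
import numpy as np
from keras.layers import Input, Dense, concatenate
from keras.models import Model

# Two independent input branches (sizes chosen arbitrarily for illustration)
a_in = Input(shape=(4,))
b_in = Input(shape=(6,))
a_out = Dense(3, activation='relu')(a_in)
b_out = Dense(5, activation='relu')(b_in)

# Functional replacement for the old Merge([...], mode='concat'):
# concatenate joins the branch outputs along the last axis -> width 3 + 5 = 8
merged = concatenate([a_out, b_out])
out = Dense(2, activation='softmax')(merged)

model = Model(inputs=[a_in, b_in], outputs=out)

# A multi-input model takes a list of arrays, in the same order as `inputs`
pred = model.predict([np.zeros((2, 4)), np.zeros((2, 6))], verbose=0)
print(pred.shape)  # (2, 2)
```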

For example:

from keras.models import Sequential
from keras.layers import Dense

model1 = Sequential()
model1.add(Dense(300, input_dim=40, activation='relu', name='layer_1'))

would be updated to:

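from keras.layers import Input, Dense
from keras.models import Model

model1_in = Input(shape=(40,))  # the Input layer takes over the role of input_dim=40
model1_out = Dense(300, activation='relu', name='layer_1')(model1_in)
model1 = Model(model1_in, model1_out)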
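A sketch of the same single-layer example, assuming Keras 2.2 or later, which also shows that a functional Model can itself be called on a new Input tensor. The names other_in and wrapper are hypothetical. This is useful during a migration because per-branch models built this way can still be reused and combined later; the reused branch shares its weights rather than copying them.

```python
import numpy as np
from keras.layers import Input, Dense
from keras.models import Model

# Functional version of the single-layer example: Input(shape=(40,)) replaces input_dim=40
model1_in = Input(shape=(40,))
model1_out = Dense(300, activation='relu', name='layer_1')(model1_in)
model1 = Model(model1_in, model1_out)

# A Model is callable like a layer: applying it to a new Input tensor
# reuses (shares) the weights of layer_1 instead of creating new ones
other_in = Input(shape=(40,))
wrapper = Model(other_in, model1(other_in))

x = np.random.rand(5, 40).astype('float32')
same = np.allclose(model1.predict(x, verbose=0), wrapper.predict(x, verbose=0))
print(same)  # True
```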

The part I need to update looks like this:

embed_quarter_hour = Sequential()
embed_quarter_hour.add(Embedding(metadata['n_quarter_hours'], embedding_dim, input_length=1))
embed_quarter_hour.add(Reshape((embedding_dim,)))
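The embedding branch above might translate to the functional API roughly as follows, assuming Keras 2.2 or later. The vocabulary size 96 (24 hours times 4 quarter hours) is a hypothetical stand-in for metadata['n_quarter_hours']. In the functional version, the shape previously given by input_length=1 comes from Input(shape=(1,)).

```python
import numpy as np
from keras.layers import Input, Embedding, Reshape
from keras.models import Model

embedding_dim = 10
n_quarter_hours = 96  # hypothetical stand-in for metadata['n_quarter_hours']

# The Input layer carries the sequence length that input_length=1 used to specify
quarter_hour_in = Input(shape=(1,), dtype='int32', name='quarter_hour')
x = Embedding(n_quarter_hours, embedding_dim)(quarter_hour_in)  # -> (batch, 1, 10)
quarter_hour_vec = Reshape((embedding_dim,))(x)                 # -> (batch, 10)

# Optional: wrap the branch as its own model, e.g. to inspect it in isolation
embed_quarter_hour = Model(quarter_hour_in, quarter_hour_vec)
out = embed_quarter_hour.predict(np.array([[0], [95]]), verbose=0)
```

In the full model, quarter_hour_vec (not the wrapping Model) is what would be passed to concatenate.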

The full code to update:

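# Legacy Keras 1.x API (the Merge layer was removed in Keras 2.2.0)
from keras.models import Sequential
from keras.layers import Embedding, Reshape, Dense, Activation, Merge

# Arbitrary dimension for all embeddings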
embedding_dim = 10

# Quarter hour of the day embedding
embed_quarter_hour = Sequential()
embed_quarter_hour.add(Embedding(metadata['n_quarter_hours'], embedding_dim, input_length=1))
embed_quarter_hour.add(Reshape((embedding_dim,)))

# Day of the week embedding
embed_day_of_week = Sequential()
embed_day_of_week.add(Embedding(metadata['n_days_per_week'], embedding_dim, input_length=1))
embed_day_of_week.add(Reshape((embedding_dim,)))


# Week of the year embedding
embed_week_of_year = Sequential()
embed_week_of_year.add(Embedding(metadata['n_weeks_per_year'], embedding_dim, input_length=1))
embed_week_of_year.add(Reshape((embedding_dim,)))


# Client ID embedding
embed_client_ids = Sequential()
embed_client_ids.add(Embedding(metadata['n_client_ids'], embedding_dim, input_length=1))
embed_client_ids.add(Reshape((embedding_dim,)))


# Taxi ID embedding
embed_taxi_ids = Sequential()
embed_taxi_ids.add(Embedding(metadata['n_taxi_ids'], embedding_dim, input_length=1))
embed_taxi_ids.add(Reshape((embedding_dim,)))


# Taxi stand ID embedding
embed_stand_ids = Sequential()
embed_stand_ids.add(Embedding(metadata['n_stand_ids'], embedding_dim, input_length=1))
embed_stand_ids.add(Reshape((embedding_dim,)))



# GPS coordinates (first 5 and last 5 lat/long pairs, therefore 20 values)
coords = Sequential()
coords.add(Dense(1, input_dim=20, init='normal'))


model = Sequential()
model.add(Merge([
            embed_quarter_hour,
            embed_day_of_week,
            embed_week_of_year,
            embed_client_ids,
            embed_taxi_ids,
            embed_stand_ids,
            coords
        ], mode='concat'))

# Simple hidden layer
model.add(Dense(500))
model.add(Activation('relu'))

# Determine cluster probabilities using softmax
model.add(Dense(len(clusters)))
model.add(Activation('softmax'))
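A sketch of the whole model rewritten with the functional API, assuming Keras 2.2 or later. The metadata values and n_clusters are hypothetical placeholders (the real ones come from the dataset), and embedding_branch is a hypothetical helper, not part of Keras. Keras 1's init='normal' is assumed to correspond to kernel_initializer=RandomNormal(stddev=0.05), and the separate Activation layers are folded into each Dense layer's activation argument, which computes the same thing.

```python
import numpy as np
from keras.layers import Input, Embedding, Reshape, Dense, concatenate
from keras.initializers import RandomNormal
from keras.models import Model

# Hypothetical placeholders: in the original code these come from the dataset
metadata = {
    'n_quarter_hours': 96,
    'n_days_per_week': 7,
    'n_weeks_per_year': 52,
    'n_client_ids': 1000,
    'n_taxi_ids': 500,
    'n_stand_ids': 64,
}
n_clusters = 100  # hypothetical stand-in for len(clusters)

# Arbitrary dimension for all embeddings
embedding_dim = 10


def embedding_branch(name, input_dim):
    """Hypothetical helper: one integer input -> flat vector of size embedding_dim."""
    inp = Input(shape=(1,), dtype='int32', name=name)
    x = Embedding(input_dim, embedding_dim)(inp)
    x = Reshape((embedding_dim,))(x)
    return inp, x


branches = [
    embedding_branch('quarter_hour', metadata['n_quarter_hours']),
    embedding_branch('day_of_week', metadata['n_days_per_week']),
    embedding_branch('week_of_year', metadata['n_weeks_per_year']),
    embedding_branch('client_id', metadata['n_client_ids']),
    embedding_branch('taxi_id', metadata['n_taxi_ids']),
    embedding_branch('stand_id', metadata['n_stand_ids']),
]
inputs = [inp for inp, _ in branches]
features = [vec for _, vec in branches]

# GPS coordinates (first 5 and last 5 lat/long pairs, therefore 20 values)
coords_in = Input(shape=(20,), name='coords')
coords_out = Dense(1, kernel_initializer=RandomNormal(stddev=0.05))(coords_in)
inputs.append(coords_in)
features.append(coords_out)

# Replacement for Merge([...], mode='concat'): 6 * 10 + 1 = 61 features
merged = concatenate(features)

# Simple hidden layer
hidden = Dense(500, activation='relu')(merged)

# Determine cluster probabilities using softmax
probs = Dense(n_clusters, activation='softmax')(hidden)

model = Model(inputs=inputs, outputs=probs)

# Feed one array per input, in the same order as `inputs`
batch = 4
x = [np.zeros((batch, 1), dtype='int32') for _ in range(6)] + [np.zeros((batch, 20))]
pred = model.predict(x, verbose=0)
```

Each of the seven inputs receives its own array at fit/predict time, in the order of the inputs list, and every integer ID must lie in the range [0, input_dim) of its matching Embedding layer.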
