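I need to upgrade some code I found on GitHub so that the model builds correctly. I need to concatenate some layers (in older versions of Keras this was done with the Merge layer in 'concat' mode), but now I need to use the concatenate function. To do that, I need to use the functional API.
For example: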
model1 = Sequential()
model1.add(Dense(300, input_dim=40, activation='relu', name='layer_1'))
would be "updated" to:
# The Input shape replaces input_dim=40 from the Sequential version
model1_in = Input(shape=(40,))
model1_out = Dense(300, activation='relu', name='layer_1')(model1_in)
model1 = Model(model1_in, model1_out)
The code I need to update looks like this:
embed_quarter_hour = Sequential()
embed_quarter_hour.add(Embedding(metadata['n_quarter_hours'], embedding_dim, input_length=1))
embed_quarter_hour.add(Reshape((embedding_dim,)))
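As a possible starting point, a single embedding branch like the one above might be written in the functional API roughly as follows. This is only a sketch: `n_quarter_hours = 96` is a hypothetical stand-in for `metadata['n_quarter_hours']`, the imports assume `tensorflow.keras`, and `input_length=1` is replaced by an `Input` of shape `(1,)`.

```python
from tensorflow.keras.layers import Input, Embedding, Reshape
from tensorflow.keras.models import Model

embedding_dim = 10
n_quarter_hours = 96  # hypothetical stand-in for metadata['n_quarter_hours']

# One integer index per sample, so the input shape is (1,)
quarter_hour_in = Input(shape=(1,), name='quarter_hour')
# The embedding output has shape (batch, 1, embedding_dim)
x = Embedding(n_quarter_hours, embedding_dim)(quarter_hour_in)
# Drop the length-1 axis, giving (batch, embedding_dim)
quarter_hour_out = Reshape((embedding_dim,))(x)

embed_quarter_hour = Model(quarter_hour_in, quarter_hour_out)
```

Keeping both the input tensor and the output tensor around matters here, because the merged model needs the inputs for `Model(...)` and the outputs for `concatenate(...)`.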
The entire code to update:
# Arbitrary dimension for all embeddings
embedding_dim = 10
# Quarter hour of the day embedding
embed_quarter_hour = Sequential()
embed_quarter_hour.add(Embedding(metadata['n_quarter_hours'], embedding_dim, input_length=1))
embed_quarter_hour.add(Reshape((embedding_dim,)))
# Day of the week embedding
embed_day_of_week = Sequential()
embed_day_of_week.add(Embedding(metadata['n_days_per_week'], embedding_dim, input_length=1))
embed_day_of_week.add(Reshape((embedding_dim,)))
# Week of the year embedding
embed_week_of_year = Sequential()
embed_week_of_year.add(Embedding(metadata['n_weeks_per_year'], embedding_dim, input_length=1))
embed_week_of_year.add(Reshape((embedding_dim,)))
# Client ID embedding
embed_client_ids = Sequential()
embed_client_ids.add(Embedding(metadata['n_client_ids'], embedding_dim, input_length=1))
embed_client_ids.add(Reshape((embedding_dim,)))
# Taxi ID embedding
embed_taxi_ids = Sequential()
embed_taxi_ids.add(Embedding(metadata['n_taxi_ids'], embedding_dim, input_length=1))
embed_taxi_ids.add(Reshape((embedding_dim,)))
# Taxi stand ID embedding
embed_stand_ids = Sequential()
embed_stand_ids.add(Embedding(metadata['n_stand_ids'], embedding_dim, input_length=1))
embed_stand_ids.add(Reshape((embedding_dim,)))
# GPS coordinates (5 first lat/long and 5 latest lat/long, therefore 20 values)
coords = Sequential()
coords.add(Dense(1, input_dim=20, init='normal'))
model = Sequential()
model.add(Merge([
embed_quarter_hour,
embed_day_of_week,
embed_week_of_year,
embed_client_ids,
embed_taxi_ids,
embed_stand_ids,
coords
], mode='concat'))
# Simple hidden layer
model.add(Dense(500))
model.add(Activation('relu'))
# Determine cluster probabilities using softmax
model.add(Dense(len(clusters)))
model.add(Activation('softmax'))
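One way the whole model might look in the functional API is sketched below. It is not a definitive port: the `metadata` values and `n_clusters` are hypothetical placeholders (the post does not show `metadata` or `clusters`), the helper `embedding_branch` is introduced only for illustration, the imports assume `tensorflow.keras`, and the legacy `init='normal'` is swapped for `kernel_initializer='random_normal'`.

```python
from tensorflow.keras.layers import Input, Embedding, Reshape, Dense, concatenate
from tensorflow.keras.models import Model

# Hypothetical stand-ins for `metadata` and `clusters` from the original code
metadata = {
    'n_quarter_hours': 96,
    'n_days_per_week': 7,
    'n_weeks_per_year': 52,
    'n_client_ids': 100,
    'n_taxi_ids': 100,
    'n_stand_ids': 100,
}
n_clusters = 50  # stands in for len(clusters)

# Arbitrary dimension for all embeddings
embedding_dim = 10


def embedding_branch(n_values, name):
    """Return (input, output) tensors for one categorical embedding branch."""
    branch_in = Input(shape=(1,), name=name)
    x = Embedding(n_values, embedding_dim)(branch_in)
    branch_out = Reshape((embedding_dim,))(x)
    return branch_in, branch_out


# Same order as the list passed to Merge in the original code
keys = ['n_quarter_hours', 'n_days_per_week', 'n_weeks_per_year',
        'n_client_ids', 'n_taxi_ids', 'n_stand_ids']
branches = [embedding_branch(metadata[k], k) for k in keys]

# GPS coordinates (5 first lat/long and 5 latest lat/long, therefore 20 values)
coords_in = Input(shape=(20,), name='coords')
coords_out = Dense(1, kernel_initializer='random_normal')(coords_in)

# concatenate joins the branch outputs along the last axis: 6 * 10 + 1 = 61 features
merged = concatenate([out for _, out in branches] + [coords_out])

# Simple hidden layer
hidden = Dense(500, activation='relu')(merged)
# Determine cluster probabilities using softmax
outputs = Dense(n_clusters, activation='softmax')(hidden)

model = Model(inputs=[inp for inp, _ in branches] + [coords_in], outputs=outputs)
```

With this layout, `model.fit` and `model.predict` would take a list of seven arrays, one per input, in the same order as the `inputs` list.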