R中的python keras和keras之间的精度不同

时间:2018-04-18 03:48:44

标签: python r keras

我用R在ke中为R建立了一个图像分类模型。

准确率高达98%,而在python中准确度很高。

R的Keras版本是2.1.3,而python中的2.1.5

以下是R型号代码:

model=keras_model_sequential()
model=model %>% 
  layer_conv_2d(filters = 32,kernel_size = c(3,3),padding = 'same',input_shape = c(187,256,3),activation = 'elu')%>%
  layer_max_pooling_2d(pool_size = c(2,2)) %>%
  layer_dropout(.25) %>% layer_batch_normalization() %>%
  layer_conv_2d(filters = 64,kernel_size = c(3,3),padding = 'same',activation = 'relu') %>%
  layer_max_pooling_2d(pool_size = c(2,2)) %>%
  layer_dropout(.25) %>% layer_batch_normalization() %>% layer_flatten() %>%
  layer_dense(128,activation = 'relu') %>%
  layer_dropout(.25)%>%
  layer_batch_normalization() %>%
  layer_dense(6,activation = 'softmax')


model %>%compile(
  loss='categorical_crossentropy',
  optimizer='adam',
  metrics='accuracy'
)

我尝试在python中使用相同的输入数据重建相同的模型。

虽然,表现完全不同。精度甚至低于30%

因为R keras正在为run keras调用python。使用相同的模型架构,它们应该具有相似的性能。

我想知道这个问题是由preprocess引起的,但仍然显示我的python代码:

model=Sequential()
model.add(Conv2D(32,kernel_size=(3,3),activation='relu',input_shape=(187,256,3),padding='same'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(BatchNormalization())
model.add(Conv2D(64, (3, 3), activation='relu',padding='same'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(BatchNormalization())
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.25))
model.add(BatchNormalization())
model.add(Dense(len(label[1]), activation='softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

这是一个简单的分类。我和大多数教学一样。

找不到其他人遇到同样的问题。所以想问问它是如何发生以及如何解决的。 THX

1 个答案:

答案 0 :(得分:1)

这是一个巨大的差异因此,代码中可能存在错误或数据中出现意外情况,但在Keras中复制R Python的结果比看起来更困难在R侧设置种子是不够的。您应该使用set.seed代替use_session_with_seed,而tensorflowkeras的R库附带use_session_with_seed(..., disable_gpu=TRUE, disable_parallel_cpu=TRUE)。请注意,为了完全重现,您需要kerasformula。另请参阅stacktf文档。此外,这里是使用layer_dropout的github版本和公共数据集的example。另外,请注意接受seed作为参数的compiler-embeddable等功能。