将keras NN映射到R中的数据列表

时间:2020-01-22 15:39:20

标签: r purrr furrr

我正在尝试找出对每个列表应用keras模型的正确方法。我使用了iris数据集并创建了4个列表,目标是正确预测versicolorvirginica(由于需要二进制分类模型,因此省略了setosa)。 / p>

data(iris)
iris <- iris %>% 
  mutate(
    splt = sample(4, size = nrow(.), replace = TRUE),
    binary = case_when(
      Species == "versicolor" ~ 0,
      Species == "virginica" ~ 1
    )
  ) %>%  
  filter(Species != "setosa") %>% 
  split(., .$splt)

iris_x_train <- iris %>% 
  map(., ~select(., Sepal.Length, Sepal.Width, Petal.Length, Petal.Width) %>% 
        as.matrix())

iris_y_train <- iris %>% 
  map(., ~select(., binary) %>% 
        to_categorical(2))

NN_model <- keras_model_sequential() %>% 
  layer_dense(units = 4, activation = 'relu', input_shape = 4) %>% 
  layer_dense(units = 2, activation = 'softmax')

NN_model %>% 
  summary

NN_model %>% 
  compile(
    loss = 'binary_crossentropy',
    optimizer_sgd(lr = 0.01, momentum = 0.9),
    metrics = c('accuracy')
  )

我的问题出现在这里。当我应用以下代码时:

NN_model %>%
  future_map(., ~future_map2(
    .x = iris_x_train,
    .y = iris_y_train,
    ~fit(
      x = .x,
      y = .y,
      epochs = 5,
      batch_size = 20,
      validation_split = 0
    )
  )
  )

我收到此错误:

py_get_item_impl(x,key,FALSE)中的错误:TypeError:“顺序” 对象不支持索引

当我应用此代码时:

NN_model %>%
  future_map2(
    .x = iris_x_train,
    .y = iris_y_train,
    ~fit(
      x = .x,
      y = .y,
      epochs = 5,
      batch_size = 20,
      validation_split = 0
      )
    )

我收到此错误:

〜fit(x = .x,y = .y,时期= 5,batch_size = 20,validation_split = 0)py_call_impl(可调用, 点($ args),点($ keywords):评估错误:无法转换R 对象为Python类型。

如何将keras模型映射到4个数据集中?

library(keras)
library(tensorflow)
library(furrr)
library(purrr)

以下内容适用于第一个列表:

NN_model %>% 
  fit(
    x = iris_x_train[[1]],
    y = iris_y_train[[1]],
    epochs = 50,
    batch_size = 20,
    validation_split = 0
  )

编辑:我似乎已经解决了。

NN_model放入fit()函数中似乎可以正常工作。

future_map2(
    .x = iris_x_train,
    .y = iris_y_train,
    ~fit(NN_model,
      .x,
      .y,
      epochs = 5,
      batch_size = 20,
      validation_split = 0
    )
  )

0 个答案:

没有答案