使用tf keras api定义图层的最佳方法?

时间:2018-04-26 17:45:16

标签: tensorflow

tensorflow.keras api在创建图层引用时没有工作,任何其他创建图层引用的方法? 代码: 层= keras.layers

错误消息:NameError:未定义名称“leyer”

此处粘贴完整代码......

import tensorflow as tf   
from tensorflow import keras   
import pandas as pd   
from sklearn.model_selection import KFold   
from sklearn.model_selection import cross_val_score   
from sklearn.preprocessing import LabelEncoder   
import numpy as np   

#makin seed values   
seed=7   
np.random.seed(seed)   

#setting up the dataset for training    
dataframe=pd.read_csv("../datasets/iris.csv",header=None)   
data=dataframe.values   
input_x = data[:,0:4]   
true_y = data[:,4]    

#Encoding the true_y data to one hot encoding   
le=LabelEncoder()   
le.fit(true_y)    
y_encoded = le.transform(true_y)    
y_encoded = keras.utils.to_categorical(y_encoded,num_classes=3)    

# creating the model    
def base_fun():    
    layer=keras.layers     
    model = keras.models.Sequential()
    model.add(layer.Dense(4,input_dim=4,kernel_initializer='normal',activation='relu'))   
    model.add(leyer.Dense(3, kernel_initializer='normal', activation='relu'))     

estimator=keras.wrappers.scikit_learn.KerasClassifier(build_fn=base_fun,epochs=20,batch_size=10)     
kfold = KFold(n_splits=10, shuffle=True, random_state=seed)    
result = cross_val_score(estimator, input_x, y_encoded,cv=kfold)    

print("Accuracy : %.2%% (%.2%%)" %(result.mean()*100, result.std()*100))     

1 个答案:

答案 0 :(得分:0)

好吧,这一行:

model.add(leyer.Dnese(3, kernel_initializer='normal', activation='relu')) 

有两个拼写错误,即leyer应为layerDnese应为Dense

model.add(layer.Dense(3, kernel_initializer='normal', activation='relu'))

根据您的评论,此行也会导致错误:

estimator = keras.wrappers.scikit_learn.KerasClassifier( build_fn = base_fun, epochs = 20, batch_size = 10 )

来自Keras Scikit documentation

  

build_fn应构造,编译并返回一个Keras模型,然后用于拟合/预测。

但你的功能base_fun()不会返回任何内容。在base_fun()

的末尾添加此行
return model

根据您的评论,最后print行可以更改为此(我不知道%格式,我通常使用下面的语法):

print( "Accuracy : {:.2%} ({:.2%})".format( result.mean(), result.std() ) )