Setting up a Keras model with a pandas DataFrame

Time: 2018-06-29 16:07:12

Tags: python pandas numpy keras

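This is my first time doing machine learning with Python and Keras; I'm used to MATLAB. Basically, I have a Parquet file with one column of labels and another column of text. I vectorize the text using GloVe embeddings, so after all that I'm left with 2 columns: vectorized, where each entry is a numpy ndarray of 4000 numbers, and the label column. I then try to use the vectorized column as the model input, and this is where I run into trouble.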

pd_df.head(1) #pd_df is my dataframe

Output:

    vectorized  label
0   [-0.10767000168561935, 0.11052999645471573, 0....   0

Then I split the data and convert it to ndarrays:

from sklearn.model_selection import train_test_split

train, test = train_test_split(pd_df, test_size=0.3)

# DataFrame.as_matrix() was removed in pandas 1.0; selecting the column and
# calling .to_numpy() gives the same (n, 1) array
trainLabels = train[['label']].to_numpy()
train = train[['vectorized']].to_numpy()

testLabels = test[['label']].to_numpy()
test = test[['vectorized']].to_numpy()

Then I check the shape of the data:

train.shape
(410750, 1)

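This is where my lack of numpy knowledge shows, because this shape doesn't make sense to me. It seems like it should be (410750, 4000), since each element is an ndarray of 4000 items.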
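The (410750, 1) shape most likely comes from the vectorized column having object dtype: pandas stores each 4000-element array as a single cell, so selecting that one column yields a 2-D array with a single column whose entries are themselves arrays. The sketch below reproduces this on a small hypothetical DataFrame (toy_df, 5 rows of made-up random 4-element vectors instead of 4000) rather than the real data.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for pd_df: each 'vectorized' cell holds a whole 4-element array
rng = np.random.default_rng(0)
toy_df = pd.DataFrame({
    "vectorized": [rng.standard_normal(4) for _ in range(5)],
    "label": [0, 1, 0, 1, 1],
})

# Selecting the column as in the question keeps one "column" of array objects
X_obj = toy_df[["vectorized"]].to_numpy()
print(X_obj.shape)        # (5, 1)
print(X_obj.dtype)        # object
print(X_obj[0, 0].shape)  # (4,)  each cell is an entire vector
```

Keras sees one feature per row here, which is why it reports an input of shape (1,) rather than (4000,).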

After this, I set up the model:

from keras.layers import Input, Dense
from keras.models import Model
from keras.optimizers import SGD
from keras.losses import binary_crossentropy
from keras.metrics import binary_accuracy

inputs = Input(shape=(4000,))

x = Dense(units=2000, activation='relu')(inputs)
x = Dense(units=500, activation='relu')(x)
output = Dense(units=2, activation='softmax')(x)

model = Model(inputs=inputs, outputs=output)
model.compile(optimizer=SGD(), loss=binary_crossentropy, metrics=['accuracy'])
model.fit(train, 
          trainLabels, 
          epochs=50,
          batch_size=50)
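Separately from the input-shape problem, the model ends in Dense(units=2, activation='softmax'), which produces two outputs per sample, while trainLabels holds a single 0/1 value per row (shape (n, 1)). Assuming the labels are the integers 0 and 1, one way to make them match is to one-hot encode them and pair the softmax output with categorical_crossentropy; the other common option is a single Dense(units=1, activation='sigmoid') output with binary_crossentropy. The NumPy sketch below shows only the one-hot step, on a few made-up labels.

```python
import numpy as np

# Hypothetical labels shaped like trainLabels: one 0/1 value per row, shape (n, 1)
labels = np.array([[0], [1], [1], [0]])

# Row i of the 2x2 identity matrix is the one-hot vector for class i
one_hot = np.eye(2)[labels.ravel()]
print(one_hot.shape)  # (4, 2)
```

keras.utils.to_categorical performs the same conversion if you prefer to stay within Keras.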

Then I keep getting this error:

ValueError: Error when checking input: expected input_13 to have shape (4000,) but got array with shape (1,)

Like I said, I'm new to machine learning in the Python world, so any help would be amazing.

Thanks for your help.

2 answers:

Answer 0 (score: 0)

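Each row of your training data has only 1 dimension (a single cell holding an array), while your Input layer specifies 4000. Also, if you want to use pre-trained word embeddings such as GloVe, you should use an Embedding layer. Take a look at this Keras blog post: https://blog.keras.io/using-pre-trained-word-embeddings-in-a-keras-model.html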
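At its core, an Embedding layer initialised with GloVe vectors is a row lookup: the layer's input is a sequence of integer word indices, and each index selects one row of a pre-trained weight matrix. The NumPy sketch below illustrates that lookup with a hypothetical 5-word vocabulary (word_index), a made-up 3-dimensional embedding_matrix standing in for GloVe, and a hypothetical embed helper; it shows the idea, not Keras's implementation.

```python
import numpy as np

# Hypothetical vocabulary: word -> integer index (0 reserved for padding)
word_index = {"<pad>": 0, "the": 1, "cat": 2, "sat": 3, "down": 4}

# Hypothetical pre-trained matrix: row i is the vector for word index i
# (real GloVe vectors typically have 50 to 300 dimensions; 3 keeps this short)
embedding_matrix = np.array([
    [0.0, 0.0, 0.0],  # <pad>
    [0.1, 0.2, 0.3],  # the
    [0.4, 0.5, 0.6],  # cat
    [0.7, 0.8, 0.9],  # sat
    [1.0, 1.1, 1.2],  # down
])

def embed(tokens, max_len):
    """Map tokens to indices, truncate/pad to max_len, then look up one row per index."""
    ids = [word_index[t] for t in tokens][:max_len]
    ids += [word_index["<pad>"]] * (max_len - len(ids))
    return embedding_matrix[np.array(ids)]  # shape (max_len, embedding_dim)

vectors = embed(["the", "cat", "sat"], max_len=5)
print(vectors.shape)  # (5, 3)
```

In Keras, a matrix like this would be handed to the Embedding layer as its initial weights (optionally with trainable=False to keep the GloVe vectors fixed), as the blog post describes, so the model receives integer index sequences instead of pre-computed 4000-number vectors.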

Answer 1 (score: 0)

To fix this issue, I had to unpack my array of arrays. The way I chose to do this was:
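A common way to unpack such an array of arrays, shown here as a sketch and not necessarily the approach used in this answer, is np.stack, which joins the per-row vectors along a new first axis. It assumes train is the (n, 1) object array produced by selecting the vectorized column and that every vector has the same length; the small random data below is made up for illustration.

```python
import numpy as np

# Hypothetical stand-in for `train`: an (n, 1) object array whose cells are 1-D vectors
rng = np.random.default_rng(0)
train = np.empty((3, 1), dtype=object)
for i in range(3):
    train[i, 0] = rng.standard_normal(4)  # 4000 elements in the real data

# Flatten to a 1-D sequence of vectors, then stack them into one 2-D float array
train_2d = np.stack(train.ravel())
print(train_2d.shape)  # (3, 4)
print(train_2d.dtype)  # float64
```

On the real data the same call should give shape (410750, 4000), which matches Input(shape=(4000,)); the same applies to test. Note that this is about 1.6 billion values, roughly 13 GB as float64, so converting with .astype(np.float32) halves the memory.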
