How do I obtain tf.gradients from a Keras API model?

Date: 2018-08-04 11:05:20

Tags: python tensorflow neural-network keras deep-learning

I would like to know how to obtain tf.gradients from a model built with the Keras API.

import numpy as np
import tensorflow as tf
from tensorflow import keras
from sklearn.datasets import make_blobs

# Create the model
inputs = keras.Input(shape=(2,))
x = keras.layers.Dense(12, activation='relu')(inputs)
x = keras.layers.Dense(8, activation='relu')(x)
predictions = keras.layers.Dense(3, activation='softmax')(x)
model = keras.models.Model(inputs=inputs, outputs=predictions)
model.compile(optimizer=tf.train.AdamOptimizer(0.001),
          loss='categorical_crossentropy',
          metrics=['accuracy'])

# Generate random data
X, y = make_blobs(n_samples=1000, centers=3, n_features=2)
labels = keras.utils.to_categorical(y, num_classes=3)

# Compute the gradients wrt inputs
y_true = tf.convert_to_tensor(labels)
y_pred = tf.convert_to_tensor(np.round(model.predict(X)))
sess = tf.Session()
sess.run(tf.global_variables_initializer())
grads = tf.gradients(model.loss_functions[0](y_true, y_pred), 
                      model.inputs[0])
sess.run(grads, feed_dict={model.inputs[0]: X, model.outputs: y})

With the first attempt above, my grads are None. With the second attempt below:

sess.run(grads, feed_dict={model.inputs: X, model.outputs: y})

I get the following error:

TypeError: unhashable type: 'list'

1 Answer:

Answer 0: (score: 0)

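I don't think you should create a new session directly with TensorFlow when using Keras. Instead, it is better to use the session that Keras creates implicitly: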

import keras.backend as K
sess = K.get_session()

However, I don't think you need to retrieve the session at all in this case. You can simply use the Keras backend functions to achieve your goal.
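The answer does not show which functions it has in mind. For the TensorFlow 1.x graph-mode setup in the question, K.gradients and K.function from the Keras backend are presumably the relevant ones. Whichever API is used, the loss has to be built from the model's output tensor. In the question, y_pred comes from model.predict(X), which returns a NumPy array with no connection to model.inputs, so tf.gradients has nothing to differentiate through and returns None. Below is a minimal sketch of the same idea, assuming TensorFlow 2.x with eager execution and using tf.GradientTape in place of a session (a substitution, not something the answer names). The toy data and the names x_tensor and grads_np are hypothetical.

```python
import numpy as np
import tensorflow as tf

# Same architecture as in the question, built with tf.keras.
inputs = tf.keras.Input(shape=(2,))
h = tf.keras.layers.Dense(12, activation="relu")(inputs)
h = tf.keras.layers.Dense(8, activation="relu")(h)
predictions = tf.keras.layers.Dense(3, activation="softmax")(h)
model = tf.keras.Model(inputs=inputs, outputs=predictions)

# Hypothetical toy data standing in for the make_blobs samples.
rng = np.random.default_rng(0)
X = rng.normal(size=(6, 2)).astype("float32")
y = rng.integers(0, 3, size=6)
labels = np.eye(3, dtype="float32")[y]  # one-hot labels

x_tensor = tf.convert_to_tensor(X)
y_true = tf.convert_to_tensor(labels)

with tf.GradientTape() as tape:
    # Inputs are not variables, so the tape must be told to track them.
    tape.watch(x_tensor)
    # Call the model on the tensor so the output stays connected to it.
    y_pred = model(x_tensor)
    loss = tf.reduce_mean(
        tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    )

# Gradient of the mean loss with respect to every input feature.
grads = tape.gradient(loss, x_tensor)
grads_np = grads.numpy()
print(grads_np.shape)
```

The result has the same shape as the input batch: one gradient value per sample and feature. Replacing model(x_tensor) with model.predict(X) would break the link between the loss and the inputs, and tape.gradient would return None, as in the first attempt in the question.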