我想知道如何对使用Keras API构建的模型调用tf.gradients来获取梯度。
import tensorflow as tf  # the package name is lowercase; "Tensorflow" cannot be imported
from tensorflow import keras
from sklearn.datasets import make_blobs  # sklearn.datasets.samples_generator no longer exists

# Create the model
inputs = keras.Input(shape=(2,))
x = keras.layers.Dense(12, activation='relu')(inputs)
x = keras.layers.Dense(8, activation='relu')(x)
predictions = keras.layers.Dense(3, activation='softmax')(x)
model = keras.models.Model(inputs=inputs, outputs=predictions)
model.compile(optimizer=tf.train.AdamOptimizer(0.001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Generate random data
X, y = make_blobs(n_samples=1000, centers=3, n_features=2)
labels = keras.utils.to_categorical(y, num_classes=3)

# Compute the gradients wrt inputs.
# The loss has to be built from the model's *symbolic* output tensor.
# Wrapping the NumPy result of model.predict() in a constant tensor (as the
# first attempt did) cuts the loss off from model.inputs, which is why
# tf.gradients returned None. The true labels are fed through a placeholder.
y_true = tf.placeholder(tf.float32, shape=(None, 3))
loss = keras.losses.categorical_crossentropy(y_true, model.outputs[0])
# tf.gradients differentiates the sum of the per-sample losses; each row of
# the result is therefore the gradient of that sample's loss w.r.t. its input.
grads = tf.gradients(loss, model.inputs[0])

# Reuse the session Keras itself uses for the model. A separate tf.Session()
# followed by global_variables_initializer() would hold its own, independently
# initialized copy of the weights.
sess = keras.backend.get_session()
# Session.run's keyword is feed_dict (not input_dict). Keys must be tensors,
# not the lists model.inputs / model.outputs, and the one-hot labels go to the
# y_true placeholder instead of being fed into the model's output.
grad_values = sess.run(grads, feed_dict={model.inputs[0]: X, y_true: labels})
以上是第一次尝试:得到的梯度(grads)是None。下面是第二次尝试:
# Second attempt. The TypeError below is raised while the dict literal is being
# built, before sess.run is even called: model.inputs and model.outputs are
# Python lists (see model.inputs[0] above), and lists cannot be dict keys.
# Note also that Session.run's keyword argument is feed_dict, not input_dict.
sess.run(grads, input_dict={model.inputs: X, model.outputs: y })
我收到以下错误:
TypeError: unhashable type: 'list'
答案 0(得分:0)
我认为在使用Keras时,您不应直接用TensorFlow新建会话,而最好使用Keras隐式创建的会话:
-- Create SAMPLE_TABLE (id_num, general_data) from a one-row SELECT on DUAL,
-- seeding it with the '0100' record and its XML payload.
CREATE TABLE SAMPLE_TABLE AS
SELECT '0100' AS id_num,
XMLTYPE('<TRX>
<DATA>
<Request APIType="null">
<SubscriberIdsInfo>
<ExternalId>
<ExternalId>0100</ExternalId>
</ExternalId>
<SubscriberId>
<SubscrNumber/>
</SubscriberId>
</SubscriberIdsInfo>
<UpdateServices>
<Soc>ABC</Soc>
<ServiceAgreementSequenceNo/>
<DealerCode/>
<DeployMode/>
<EffectiveDate>2019-10-16T00:00:00</EffectiveDate>
<ExpirationDate/>
<OfferInstanceId/>
</UpdateServices>
<AddServices>
<Soc>ABC1</Soc>
<ServiceAgreementSequenceNo/>
<DealerCode/>
<DeployMode/>
<EffectiveDate>2018-10-16T00:00:00</EffectiveDate>
<ExpirationDate/>
<OfferInstanceId/>
</AddServices>
<RemoveServices>
<Soc>ABC2</Soc>
<ServiceAgreementSequenceNo/>
<DealerCode/>
<DeployMode/>
<EffectiveDate>2017-10-16T00:00:00</EffectiveDate>
<ExpirationDate/>
<OfferInstanceId/>
</RemoveServices>
<SubParameters>
<Name>PoolID</Name>
<Values>POOL0100</Values>
<EffectiveDate>2014-10-16T14:08:37</EffectiveDate>
<ExpirationDate/>
</SubParameters>
<ActivityInfo/>
</Request>
</DATA>
</TRX>') AS general_data
FROM dual;
-- Add the '0200' record. Its XML carries UpdateServices but no AddServices or
-- RemoveServices sections. The COMMIT at the end makes the insert permanent.
INSERT INTO SAMPLE_TABLE (id_num, general_data)
SELECT '0200' AS id_num,
XMLTYPE('<TRX>
<DATA>
<Request APIType="null">
<SubscriberIdsInfo>
<ExternalId>
<ExternalId>0200</ExternalId>
</ExternalId>
<SubscriberId>
<SubscrNumber/>
</SubscriberId>
</SubscriberIdsInfo>
<UpdateServices>
<Soc>ABC</Soc>
<ServiceAgreementSequenceNo/>
<DealerCode/>
<DeployMode/>
<EffectiveDate>2019-10-16T00:00:00</EffectiveDate>
<ExpirationDate/>
<OfferInstanceId/>
</UpdateServices>
<SubParameters>
<Name>PoolID</Name>
<Values>POOL0200</Values>
<EffectiveDate>2014-10-16T14:08:37</EffectiveDate>
<ExpirationDate/>
</SubParameters>
<ActivityInfo/>
</Request>
</DATA>
</TRX>') AS general_data
FROM dual;
COMMIT;
但是,我认为在这种情况下,您根本不需要检索会话。您可以轻松使用datePickerDialog = new DatePickerDialog(this, AlertDialog.THEME_HOLO_LIGHT, new DatePickerDialog.OnDateSetListener() {
    @Override
    public void onDateSet(DatePicker view, int year, int month, int dayOfMonth) {
        // Build the picked date with Calendar instead of the deprecated
        // Date(int, int, int) constructor. The month is 0-based in both APIs,
        // and clear() zeroes the time fields, so the result is local midnight
        // exactly as before.
        java.util.Calendar picked = java.util.Calendar.getInstance();
        picked.clear();
        picked.set(year, month, dayOfMonth);
        // Pin the locale so the stored string always uses ASCII digits,
        // whatever the device's default locale is.
        DateFormat dateFormat = new SimpleDateFormat("MM/dd/yyyy", java.util.Locale.US);
        USERBIRTHDATE = dateFormat.format(picked.getTime());
        dateTxt.setText(USERBIRTHDATE);
        flag = true;
    }
}, 1997, 0, 1); // initial selection: January (0-based month) 1, 1997; plain 1 instead of the octal-looking 01
和import keras.backend as K
sess = K.get_session()
之类的后端功能来实现您的目标。