我正在使用CNTK作为Keras的后端。我正在尝试使用我在C ++中使用Keras训练过的模型。
我使用了HDF5中的Keras训练并保存了我的模型。我现在如何使用CNTK API将其保存为model-v2格式?
我试过了:
model = load_model('model2.h5')
cntk.ops.functions.Function.save(model, 'CNTK_model2.pb')
但我收到以下错误:
TypeError: save() missing 1 required positional argument: 'filename'
如果tensorflow是后端,我会这样做:
model = load_model('model2.h5')
sess = K.get_session()
tf_saver = tf.train.Saver()
tf_saver.save(sess=sess, save_path=checkpoint_path)
我怎样才能达到同样的目的?
答案 0 :(得分:3)
根据评论here,我能够使用它:
import cntk as C
import keras.backend as K
keras_model = K.load_model('my_keras_model.h5')
C.combine(keras_model.model.outputs).save('my_cntk_model')
cntk_model = C.load_model('my_cntk_model')
答案 1 :(得分:0)
你可以这样做
private int SMSMt(String pstrLoginName, String pstrServiceID, String
pstrCPID, String pstrMSISDN, String pstrKeyword, String pstrPriceCode,
String pstrChargeMSISDN, String pstrSubID, String pstrDstTrxID, String
pstrShortCode, String pstrSMS, String pstrLanguage, SubmitResult pobjRst)
{
int intResult = 0;
try
{
if (!this.mobjSMPP.mblnBound)
{
intResult = 9910;
}
else
{
SubmitSM objReq = new SubmitSM();
Address objSrcAddress = new Address();
Address objDstAddress = new Address();
objSrcAddress.setNpi((byte)1);
objSrcAddress.setTon((byte)0);
objSrcAddress.setAddress(pstrShortCode);
objDstAddress.setNpi((byte)1);
objDstAddress.setTon((byte)0);
objDstAddress.setAddress(pstrMSISDN);
objReq.setServiceType("");
objReq.setSourceAddr(objSrcAddress);
objReq.setDestAddr(objDstAddress);
objReq.setReplaceIfPresentFlag((byte)0);
objReq.setScheduleDeliveryTime("");
objReq.setValidityPeriod("");
objReq.setPriorityFlag((byte)1);
objReq.setRegisteredDelivery((byte)1);
objReq.setUserMessageReference((short)25);
WriteLog("pstrLanguage=" + pstrLanguage);
if (pstrLanguage.equals("1")) {
objReq.setDataCoding((byte)8);
WriteLog("Ori pstrSMS=" + pstrSMS);
objReq.setShortMessage(pstrSMS,"UTF_16BE");
//String HexStr = new String("A" + "\u00ea" + "\u00f1" + "\u00fc" + "\u0eaa" + "C");
//WriteLog("HexStr=" + HexStr);
//pstrSMS = stringToHex(pstrSMS);
//StringConverter TestRun = new StringConverter(pstrSMS);
/*pstrSMS = HexStr;*/
//pstrSMS = new String("\u0eaa");
//pstrSMS = new String("0x0eaa");
/*pstrSMS = "世界您好";
WriteLog("Ori pstrSMS=" + pstrSMS);
//byte[] messageData = new ASCIIEncoding().encode(pstrSMS);
byte[] b = s.getBytes(StandardCharsets.US_ASCII);*/
WriteLog("Ori pstrSMS [unicode]=" + pstrSMS);
//byte[] textByte = pstrSMS.getBytes("UTF-16BE");
//WriteLog("HexCode [encoded with UTF-16BE]= " + textByte);
//pstrSMS = textByte.toString(); //cannot convert byte[] to string
}
else {
objReq.setDataCoding((byte)0);
objReq.setShortMessage(pstrSMS);
}
我在这里假设你已经调用了model.compile(即那是我尝试的唯一案例: - )
答案 2 :(得分:0)
您看到此错误的原因是因为keras的cntk后端使用用户定义的函数在批处理轴上进行重新整形,而无法对其进行序列化。我们已经在CNTK v2.2中解决了这个问题。请将您的cntk升级到v2.2,并将keras升级到last master。 请看这个拉取请求: https://github.com/fchollet/keras/pull/7907