Keras-CNTK保存模型-v2格式

时间:2017-06-06 15:34:29

标签: keras cntk

我正在使用CNTK作为Keras的后端。我正在尝试使用我在C ++中使用Keras训练过的模型。

我使用了HDF5中的Keras训练并保存了我的模型。我现在如何使用CNTK API将其保存为model-v2格式?

我试过了:

model = load_model('model2.h5')
cntk.ops.functions.Function.save(model, 'CNTK_model2.pb')

但我收到以下错误:

TypeError: save() missing 1 required positional argument: 'filename'

如果tensorflow是后端,我会这样做:

model = load_model('model2.h5')
sess = K.get_session()
tf_saver = tf.train.Saver()
tf_saver.save(sess=sess, save_path=checkpoint_path)

我怎样才能达到同样的目的?

3 个答案:

答案 0 :(得分:3)

根据评论here,我能够使用它:

import cntk as C
import keras.backend as K

keras_model = K.load_model('my_keras_model.h5')

C.combine(keras_model.model.outputs).save('my_cntk_model')
cntk_model = C.load_model('my_cntk_model')

答案 1 :(得分:0)

你可以这样做 private int SMSMt(String pstrLoginName, String pstrServiceID, String pstrCPID, String pstrMSISDN, String pstrKeyword, String pstrPriceCode, String pstrChargeMSISDN, String pstrSubID, String pstrDstTrxID, String pstrShortCode, String pstrSMS, String pstrLanguage, SubmitResult pobjRst) { int intResult = 0; try { if (!this.mobjSMPP.mblnBound) { intResult = 9910; } else { SubmitSM objReq = new SubmitSM(); Address objSrcAddress = new Address(); Address objDstAddress = new Address(); objSrcAddress.setNpi((byte)1); objSrcAddress.setTon((byte)0); objSrcAddress.setAddress(pstrShortCode); objDstAddress.setNpi((byte)1); objDstAddress.setTon((byte)0); objDstAddress.setAddress(pstrMSISDN); objReq.setServiceType(""); objReq.setSourceAddr(objSrcAddress); objReq.setDestAddr(objDstAddress); objReq.setReplaceIfPresentFlag((byte)0); objReq.setScheduleDeliveryTime(""); objReq.setValidityPeriod(""); objReq.setPriorityFlag((byte)1); objReq.setRegisteredDelivery((byte)1); objReq.setUserMessageReference((short)25); WriteLog("pstrLanguage=" + pstrLanguage); if (pstrLanguage.equals("1")) { objReq.setDataCoding((byte)8); WriteLog("Ori pstrSMS=" + pstrSMS); objReq.setShortMessage(pstrSMS,"UTF_16BE"); //String HexStr = new String("A" + "\u00ea" + "\u00f1" + "\u00fc" + "\u0eaa" + "C"); //WriteLog("HexStr=" + HexStr); //pstrSMS = stringToHex(pstrSMS); //StringConverter TestRun = new StringConverter(pstrSMS); /*pstrSMS = HexStr;*/ //pstrSMS = new String("\u0eaa"); //pstrSMS = new String("0x0eaa"); /*pstrSMS = "世界您好"; WriteLog("Ori pstrSMS=" + pstrSMS); //byte[] messageData = new ASCIIEncoding().encode(pstrSMS); byte[] b = s.getBytes(StandardCharsets.US_ASCII);*/ WriteLog("Ori pstrSMS [unicode]=" + pstrSMS); //byte[] textByte = pstrSMS.getBytes("UTF-16BE"); //WriteLog("HexCode [encoded with UTF-16BE]= " + textByte); //pstrSMS = textByte.toString(); //cannot convert byte[] to string } else { objReq.setDataCoding((byte)0); objReq.setShortMessage(pstrSMS); } 我在这里假设你已经调用了model.compile(即那是我尝试的唯一案例: - )

答案 2 :(得分:0)

您看到此错误的原因是因为keras的cntk后端使用用户定义的函数在批处理轴上进行重新整形,而无法对其进行序列化。我们已经在CNTK v2.2中解决了这个问题。请将您的cntk升级到v2.2,并将keras升级到last master。 请看这个拉取请求: https://github.com/fchollet/keras/pull/7907