训练、保存模型和加载:加载模型时出错

时间:2021-06-09 21:15:54

标签: tensorflow keras

我正在使用 TensorFlow 和 Keras 训练模型。我想保存模型然后加载它。但我遇到了一些错误。

我是这样编译模型的:

function get_address_withLocation(lat, lng, map, is_origin) {
    infoWindow = new google.maps.InfoWindow();
    const geocoder = new google.maps.Geocoder();
    const latlng = {
        lat: lat,
        lng: lng,
    };
    geocoder.geocode({ location: latlng }, (results, status) => {
        if (status === "OK") {
            if (results[0]) {
                map.setZoom(7);
                const marker = new google.maps.Marker({
                    position: latlng,
                    map: map,
                });
                contentString =
                    '<div id="content-map">' +
                    '<p>' + results[0].formatted_address + '</p>' +
                    "</div>";
                infoWindow.setContent(contentString);
                infoWindow.open(map, marker);


                if (is_origin) {
                    document.getElementById('origin_address').value = results[0].formatted_address;
                    console.log(lat);
                    console.log(lng);
                    document.getElementById("lat_Origin").value = lat;
                    document.getElementById("lon_Origin").value = lng;
                    // document.getElementById("lat_Origin").value = 40;
                    // document.getElementById("lon_Origin").value = 2;

                    remove_mapMarkers('origin_address', marker)


                } else {
                    console.log(lat);
                    console.log(lng);
                    document.getElementById('destination_address').value = results[0].formatted_address;
                    document.getElementById("lat_Dest").value = lat;
                    document.getElementById("lon_Dest").value = lng;
                    remove_mapMarkers('destination_address', marker);


                }



            } else {
                window.alert("No results found");
            }

        } else {
            window.alert("Geocoder failed due to: " + status);
        }
    });



}



在训练后我以这种方式加载模型:

from tensorflow.keras.models import Model
import tensorflow as tf

model.compile(loss='categorical_crossentropy',
              optimizer=adam,
              metrics=['accuracy', top3, top5])

所以我得到一个文件夹“model”,其中包含:

model.save('model')

最后,我尝试使用以下方法加载模型:

---model 
     ---assets
     ---variables
     ---keras_metadata.pb
     ---saved_model.pb

但我收到此错误:

import tensorflow as tf

new_model = tf.keras.models.load_model('model')
new_model.summary()

1 个答案:

答案 0 :(得分:1)

当您的模型使用自定义对象(例如自定义指标)时,您必须使用 custom_objectsload_model 参数指定它们:

new_model = tf.keras.models.load_model('model', custom_objects={'top3': top3, 'top5': top5})

请注意,自定义指标的定义必须在加载模型的同一模块/环境中可用。