PySyft asynchronous learning: Websocket error during training (tuple mismatch)

Time: 2020-05-06 16:17:20

Tags: python websocket pytorch federated-learning pysyft

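I am trying to replicate the PySyft asynchronous FedML tutorial on MNIST using my own data. I have that working, but I now cannot train when the training and the data hosting happen on two different devices. When I run the tutorial on a single device, everything goes smoothly. On two devices, however, I get the following ValueError caused by a tuple size mismatch: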

ValueError                                Traceback (most recent call last)
<ipython-input-24-40944cb95694> in <module>
      5     logger.info("Training round %s/%s", curr_round, args.training_rounds)
      6 
----> 7     results = await asyncio.gather(
      8         *[
      9             rwc.fit_model_on_worker(

~/Documents/async_learning/run_websocket_client.py in fit_model_on_worker(worker, traced_model, batch_size, curr_round, max_nr_batches, lr)
    368     )
    369     train_config.send(worker)
--> 370     loss = await worker.async_fit(dataset_key="bank", return_ids=[0])
    371     model = train_config.model_ptr.get().obj
    372     return worker.id, model, loss

/opt/anaconda3/envs/fedML/lib/python3.8/site-packages/syft/workers/websocket_client.py in async_fit(self, dataset_key, device, return_ids)
    173 
    174         # Return the deserialized response.
--> 175         return sy.serde.deserialize(response)
    176 
    177     def fit(self, dataset_key: str, **kwargs):

/opt/anaconda3/envs/fedML/lib/python3.8/site-packages/syft/serde/serde.py in deserialize(binary, worker, strategy)
     67         object: the deserialized form of the binary input.
     68     """
---> 69     return strategy(binary, worker)

/opt/anaconda3/envs/fedML/lib/python3.8/site-packages/syft/serde/msgpack/serde.py in deserialize(binary, worker)
    380 
    381     simple_objects = _deserialize_msgpack_binary(binary, worker)
--> 382     return _deserialize_msgpack_simple(simple_objects, worker)
    383 
    384 

/opt/anaconda3/envs/fedML/lib/python3.8/site-packages/syft/serde/msgpack/serde.py in _deserialize_msgpack_simple(simple_objects, worker)
    371     # as msgpack's inability to serialize torch tensors or ... or
    372     # python slice objects
--> 373     return _detail(worker, simple_objects)
    374 
    375 

/opt/anaconda3/envs/fedML/lib/python3.8/site-packages/syft/serde/msgpack/serde.py in _detail(worker, obj, **kwargs)
    497     """
    498     if type(obj) in (list, tuple):
--> 499         val = detailers[obj[0]](worker, obj[1], **kwargs)
    500         return _detail_field(obj[0], val)
    501     else:

/opt/anaconda3/envs/fedML/lib/python3.8/site-packages/syft/serde/msgpack/torch_serde.py in _detail_torch_tensor(worker, tensor_tuple)
    180     """
    181 
--> 182     (
    183         tensor_id,
    184         tensor_bin,

ValueError: not enough values to unpack (expected 9, got 7)
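
For reference, the last frame of the traceback is an ordinary Python tuple unpacking: the receiving side expects nine fields for a serialized tensor, but the payload it received carries only seven. The sketch below reproduces that failure mode without PySyft. The function `detail_tensor` and the field names `field_3` through `field_9` are hypothetical placeholders; only `tensor_id` and `tensor_bin` appear in the traceback, and the real field layout is not shown there.

```python
def detail_tensor(tensor_tuple):
    """Hypothetical stand-in for a deserializer that expects nine fields."""
    (
        tensor_id,
        tensor_bin,
        field_3,  # placeholder names: the real fields are not shown in the traceback
        field_4,
        field_5,
        field_6,
        field_7,
        field_8,
        field_9,
    ) = tensor_tuple
    return tensor_id


# A seven-field payload, as a sender using a different tuple layout might produce.
payload = tuple(range(7))

try:
    detail_tensor(payload)
except ValueError as err:
    message = str(err)
    print(message)
```

The point of the sketch is that this error arises whenever the code that packs the tuple and the code that unpacks it disagree on the number of fields, which is consistent with it appearing only when two separate devices are involved.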

Any ideas about why this might be happening? I believe it is related to the following line in run_websocket_client.py: loss = await worker.async_fit(dataset_key="bank", return_ids=[0])
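
One possible cause, which is an assumption and not confirmed anywhere in this post, is that the two devices have different versions of PySyft installed, so the server serializes tensors with a different number of fields than the client expects. A minimal way to check is to run a snippet like the following on both devices and compare the output. The helper names `installed_version` and `mismatched_packages` are made up for illustration; `importlib.metadata` is part of the Python 3.8 standard library.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version of a package, or a marker if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return "not installed"


def mismatched_packages(local, remote):
    """Return the names of packages whose versions differ between two machines."""
    return sorted(name for name in local if local[name] != remote.get(name))


# Run this on each device and compare the printed dictionaries by hand,
# or paste the other device's output into mismatched_packages().
local_versions = {name: installed_version(name) for name in ("syft", "torch")}
print(local_versions)
```

If the versions differ, aligning the two environments would be a reasonable first thing to try; if they match, the cause lies elsewhere.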

0 answers:

No answers yet