我正在尝试用自己的数据复现PySyft基于MNIST的异步联邦学习(FedML)教程。在单台设备上我已经成功跑通,但当我尝试在两台不同的设备上分别托管数据和进行训练时遇到了问题。在同一台设备上运行该教程时一切正常;而在两台设备上运行时,会因元组长度不匹配而抛出以下ValueError:
ValueError Traceback (most recent call last)
<ipython-input-24-40944cb95694> in <module>
5 logger.info("Training round %s/%s", curr_round, args.training_rounds)
6
----> 7 results = await asyncio.gather(
8 *[
9 rwc.fit_model_on_worker(
~/Documents/async_learning/run_websocket_client.py in fit_model_on_worker(worker, traced_model, batch_size, curr_round, max_nr_batches, lr)
368 )
369 train_config.send(worker)
--> 370 loss = await worker.async_fit(dataset_key="bank", return_ids=[0])
371 model = train_config.model_ptr.get().obj
372 return worker.id, model, loss
/opt/anaconda3/envs/fedML/lib/python3.8/site-packages/syft/workers/websocket_client.py in async_fit(self, dataset_key, device, return_ids)
173
174 # Return the deserialized response.
--> 175 return sy.serde.deserialize(response)
176
177 def fit(self, dataset_key: str, **kwargs):
/opt/anaconda3/envs/fedML/lib/python3.8/site-packages/syft/serde/serde.py in deserialize(binary, worker, strategy)
67 object: the deserialized form of the binary input.
68 """
---> 69 return strategy(binary, worker)
/opt/anaconda3/envs/fedML/lib/python3.8/site-packages/syft/serde/msgpack/serde.py in deserialize(binary, worker)
380
381 simple_objects = _deserialize_msgpack_binary(binary, worker)
--> 382 return _deserialize_msgpack_simple(simple_objects, worker)
383
384
/opt/anaconda3/envs/fedML/lib/python3.8/site-packages/syft/serde/msgpack/serde.py in _deserialize_msgpack_simple(simple_objects, worker)
371 # as msgpack's inability to serialize torch tensors or ... or
372 # python slice objects
--> 373 return _detail(worker, simple_objects)
374
375
/opt/anaconda3/envs/fedML/lib/python3.8/site-packages/syft/serde/msgpack/serde.py in _detail(worker, obj, **kwargs)
497 """
498 if type(obj) in (list, tuple):
--> 499 val = detailers[obj[0]](worker, obj[1], **kwargs)
500 return _detail_field(obj[0], val)
501 else:
/opt/anaconda3/envs/fedML/lib/python3.8/site-packages/syft/serde/msgpack/torch_serde.py in _detail_torch_tensor(worker, tensor_tuple)
180 """
181
--> 182 (
183 tensor_id,
184 tensor_bin,
ValueError: not enough values to unpack (expected 9, got 7)
有人知道为什么会出现这种情况吗?我认为这与run_websocket_client.py中的以下这一行有关:
loss = await worker.async_fit(dataset_key="bank", return_ids=[0])