I have been trying many different ways, but I'm still stuck trying to combine a Pipeline (from pytorch-nightly) with DeepSpeed.
I adapted the MPU from https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM/mpu, which, based on some tests I ran, seems to produce the groups I need.
But I can't figure out how to launch DeepSpeed so that everything works together.
Setup: 4 gpus
2 model parallel groups: [g0, g1], [g2, g3]
2 data parallel groups: [g0, g2], [g1, g3]
Inside the pipeline I correctly switch model layers to either [g0, g1] or [g2, g3].
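To make the layout concrete, this is roughly how I understand the Megatron-style group construction to work for this 4-GPU / 2-way setup (a plain-Python sketch of my understanding, not the actual mpu code):

```python
# Sketch (my own illustration, not mpu's implementation) of how
# initialize_model_parallel(2) partitions a world of 4 ranks:
# consecutive ranks form model-parallel groups, strided ranks form
# data-parallel groups.
def build_groups(world_size, model_parallel_size):
    data_parallel_size = world_size // model_parallel_size
    # model-parallel groups: consecutive blocks of ranks
    model_groups = [
        list(range(i * model_parallel_size, (i + 1) * model_parallel_size))
        for i in range(data_parallel_size)
    ]
    # data-parallel groups: ranks strided by model_parallel_size
    data_groups = [
        list(range(i, world_size, model_parallel_size))
        for i in range(model_parallel_size)
    ]
    return model_groups, data_groups

mp_groups, dp_groups = build_groups(4, 2)
print(mp_groups)  # [[0, 1], [2, 3]]
print(dp_groups)  # [[0, 2], [1, 3]]
```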
I pass mpu to deepspeed init as suggested in the docs:
mpu.initialize_model_parallel(n_gpus_per_mp)
model, optimizer, _, lr_scheduler = deepspeed.initialize(
    args=SimpleNamespace(**ds_args),  # deepspeed expects an object, not a dict
    model=model,
    model_parameters=model_parameters,
    config_params=config,
    mpu=mpu,
)
I initialize the mpu with 2 GPUs per model-parallel group: mpu.initialize_model_parallel(2).
Now, how do I launch deepspeed so that it uses my pipeline and achieves 2D parallelism? Do I tell it about all 4 GPUs, or just GPUs 0 and 2?
- If I let it see all 4 GPUs, it launches 4 processes, which I think is wrong, and it crashes with:
CUDA_VISIBLE_DEVICES=0,1,2,3, deepspeed program
[...]
File "/home/stas/hf/transformers/src/transformers/models/t5/modeling_t5.py", line 256, in forward
return self.weight * hidden_states
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:2 and cuda:1!
GPUs 2 and 1 shouldn't end up together in the pipeline - they are in the wrong groups.
- If I tell it only about GPUs 0 and 2, it launches 2 processes, but then crashes with:
CUDA_VISIBLE_DEVICES=0,1,2,3, deepspeed --include localhost:0,2 program
0%| | 0/2 [00:00<?, ?it/s]Traceback (most recent call last):
File "./finetune_trainer.py", line 373, in <module>
main()
File "./finetune_trainer.py", line 303, in main
train_result = trainer.train(
File "/home/stas/hf/transformers/src/transformers/trainer.py", line 919, in train
tr_loss += self.training_step(model, inputs)
File "/home/stas/hf/transformers/src/transformers/trainer.py", line 1286, in training_step
loss = self.compute_loss(model, inputs)
File "/home/stas/hf/transformers/src/transformers/trainer.py", line 1316, in compute_loss
outputs = model(**inputs)
File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 702, in forward
output = self.module(*inputs[0], **kwargs[0])
File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/stas/hf/deepspeed/deepspeed/runtime/engine.py", line 824, in forward
loss = self.module(*inputs, **kwargs)
File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/stas/hf/transformers/src/transformers/models/t5/modeling_t5.py", line 1833, in forward
encoder_outputs = self.encoder(
File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/stas/hf/transformers/src/transformers/models/t5/modeling_t5.py", line 1144, in forward
layers[layer_id].to(device_id)
File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 673, in to
return self._apply(convert)
File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 387, in _apply
module._apply(fn)
File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 387, in _apply
module._apply(fn)
File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 387, in _apply
module._apply(fn)
[Previous line repeated 2 more times]
File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 409, in _apply
param_applied = fn(param)
File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 671, in convert
return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
RuntimeError: CUDA error: invalid device ordinal
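I suspect the invalid ordinal comes from device renumbering: when only some GPUs are visible to a process, the physical ids in the mask become logical ordinals 0..N-1, so with 2 visible devices a .to(2) has no target. A toy sketch of that renumbering (my own illustration, not CUDA code):

```python
# Sketch of the renumbering CUDA applies under CUDA_VISIBLE_DEVICES:
# the physical GPU ids listed in the mask become logical ordinals 0..N-1.
def logical_ordinal(visible_devices, physical_id):
    mask = [int(d) for d in visible_devices.split(",")]
    if physical_id not in mask:
        # mirrors the "invalid device ordinal" failure mode
        raise RuntimeError("CUDA error: invalid device ordinal")
    return mask.index(physical_id)

print(logical_ordinal("0,2", 0))  # 0
print(logical_ordinal("0,2", 2))  # 1
# logical_ordinal("0,2", 3) would raise, like layers[layer_id].to(device_id)
```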
Traceback (most recent call last):
File "./finetune_trainer.py", line 373, in <module>
main()
File "./finetune_trainer.py", line 303, in main
train_result = trainer.train(
File "/home/stas/hf/transformers/src/transformers/trainer.py", line 919, in train
tr_loss += self.training_step(model, inputs)
File "/home/stas/hf/transformers/src/transformers/trainer.py", line 1286, in training_step
loss = self.compute_loss(model, inputs)
File "/home/stas/hf/transformers/src/transformers/trainer.py", line 1316, in compute_loss
outputs = model(**inputs)
File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 702, in forward
output = self.module(*inputs[0], **kwargs[0])
File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/stas/hf/deepspeed/deepspeed/runtime/engine.py", line 824, in forward
loss = self.module(*inputs, **kwargs)
File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/stas/hf/transformers/src/transformers/models/t5/modeling_t5.py", line 1833, in forward
encoder_outputs = self.encoder(
File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/stas/hf/transformers/src/transformers/models/t5/modeling_t5.py", line 1167, in forward
outputs = block_pipe(inputs)
File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/distributed/pipeline/sync/pipe.py", line 366, in forward
return RRef(output)
RuntimeError: agent INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/distributed/rpc/rpc_agent.cpp":247, please report a bug to PyTorch. Current RPC agent is not set!
- Running direct tests on your MPU code, it has to see a world of size 4 for this to work, but I can only get that by exposing all 4 GPUs, which ends up with 4 processes that crash. The mpu functions return the wrong groups if the world size is 2.
Somehow I need to tell DeepSpeed that there are 4 GPUs to work with, but that I'm using a pipeline across pairs of them.
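In other words, each of the 4 global ranks should map to a (pipeline_stage, data_parallel_rank) coordinate on the 2x2 grid. Something like this hypothetical helper (my own illustration, not a DeepSpeed API):

```python
# Hypothetical mapping (not a DeepSpeed API) from a global rank to a
# (pipeline_stage, data_parallel_rank) coordinate on a 2x2 grid, matching
# MP groups [0,1],[2,3] and DP groups [0,2],[1,3] above.
def rank_to_coord(rank, pipe_size=2):
    pipeline_stage = rank % pipe_size       # position within the MP group
    data_parallel_rank = rank // pipe_size  # which replica of the pipeline
    return pipeline_stage, data_parallel_rank

for r in range(4):
    print(r, rank_to_coord(r))
# 0 (0, 0)
# 1 (1, 0)
# 2 (0, 1)
# 3 (1, 1)
```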
So I'm stuck :(
Please help!
Thank you!
p.s. Do you by chance have some sort of Slack where quick questions of this kind could be asked? I'm sure the answer is very simple, but I couldn't find it anywhere in the docs. Some of us ask questions on the PyTorch Slack, to which I believe you would be gladly invited - I'm not an admin, but I'd be happy to ask; email me at [email protected] and I will arrange an invite.