
Commit 939ce40

Update torchpippy (#2938)
* rm warning
* Take 3
* Take 4
* Annotate
* Take 6
* Updated
* Spec
* Last fix
* Don't pad input
* Finished
* Continue refactor
* Rm comment
* Adjust the err
* Start adjustment
* GPT2 works, T5 does not
* llama too now I think
* Flag the t5 example
1 parent c212092 commit 939ce40

File tree: 6 files changed (+53 lines, -29 lines)

examples/inference/pippy/bert.py
Lines changed: 11 additions & 1 deletion

@@ -32,7 +32,7 @@
 input = torch.randint(
     low=0,
     high=model.config.vocab_size,
-    size=(2, 512),  # bs x seq_len
+    size=(1, 512),  # bs x seq_len
     device="cpu",
     dtype=torch.int64,
     requires_grad=False,
@@ -49,6 +49,16 @@
 # available on all GPUs
 # model = prepare_pippy(model, split_points="auto", example_args=(input,), gather_output=True)
 
+# Create new inputs of the expected size (n_processes)
+input = torch.randint(
+    low=0,
+    high=model.config.vocab_size,
+    size=(2, 512),  # bs x seq_len
+    device="cpu",
+    dtype=torch.int64,
+    requires_grad=False,
+)
+
 # Move the inputs to the first device
 input = input.to("cuda:0")
 
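Note on the change above: a minimal sketch of how the two inputs are meant to be used, assuming 2 processes. The example input traced by prepare_pippy now describes what a single process sees, while the batch actually pushed through the pipelined model covers all processes. The checkpoint name, model class, and final forward call below are assumptions for illustration; the prepare_pippy call itself is the one commented in the example.

import torch
from transformers import AutoModelForMaskedLM
from accelerate import prepare_pippy

model = AutoModelForMaskedLM.from_pretrained("bert-base-cased")  # illustrative checkpoint

# Trace with a single-process batch (bs = 1) ...
example = torch.randint(0, model.config.vocab_size, (1, 512))
model = prepare_pippy(model, split_points="auto", example_args=(example,))

# ... then run a batch sized for all processes (bs = n_processes = 2)
batch = torch.randint(0, model.config.vocab_size, (2, 512)).to("cuda:0")
with torch.no_grad():
    output = model(batch)  # forward call shape assumed

These examples are launched with one process per GPU, e.g. accelerate launch --num_processes 2 examples/inference/pippy/bert.py (launch command is illustrative, not part of this commit).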

examples/inference/pippy/gpt2.py
Lines changed: 11 additions & 1 deletion

@@ -32,7 +32,7 @@
 input = torch.randint(
     low=0,
     high=model.config.vocab_size,
-    size=(2, 1024),  # bs x seq_len
+    size=(1, 1024),  # bs x seq_len
     device="cpu",
     dtype=torch.int64,
     requires_grad=False,
@@ -48,6 +48,16 @@
 # available on all GPUs
 # model = prepare_pippy(model, split_points="auto", example_args=(input,), gather_output=True)
 
+# Create new inputs of the expected size (n_processes)
+input = torch.randint(
+    low=0,
+    high=model.config.vocab_size,
+    size=(2, 1024),  # bs x seq_len
+    device="cpu",
+    dtype=torch.int64,
+    requires_grad=False,
+)
+
 # Move the inputs to the first device
 input = input.to("cuda:0")
 

examples/inference/pippy/llama.py
Lines changed: 3 additions & 1 deletion

@@ -27,7 +27,7 @@
 # Input configs
 # Create example inputs for the model
 tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
-prompts = ("I would like to", "I really like to", "The weather is pretty")  # bs = 3
+prompts = ("I would like to", "I really like to")  # bs = 2, sending 2 per process
 tokenizer.pad_token = tokenizer.eos_token
 inputs = tokenizer(prompts, return_tensors="pt", padding=True)
 
@@ -43,6 +43,8 @@
 
 # currently we don't support `model.generate`
 # output = model.generate(**inputs, max_new_tokens=1)
+prompts = ("I would like to", "I really like to", "The weather is pretty")  # bs = 3
+inputs = tokenizer(prompts, return_tensors="pt", padding=True)
 inputs = inputs.to(0)
 with torch.no_grad():
     output = model(**inputs)
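To make the intent of the llama change explicit, here is a condensed sketch of the flow after this commit, assuming 2 processes: the traced prompts describe what one process sees, while a fresh, larger batch is tokenized for the real forward pass. The model class and the exact prepare_pippy call are assumptions not shown in this diff; the prompt tuples and tokenizer calls come from the example.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import prepare_pippy

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")  # model class assumed

# Example inputs used only for tracing/splitting: 2 prompts, i.e. 2 per process
trace_inputs = tokenizer(("I would like to", "I really like to"), return_tensors="pt", padding=True)
model = prepare_pippy(model, split_points="auto", example_kwargs=dict(trace_inputs))  # exact call assumed

# Fresh inputs for the actual forward pass: the full batch of 3 prompts
prompts = ("I would like to", "I really like to", "The weather is pretty")
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(0)
with torch.no_grad():
    output = model(**inputs)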

examples/inference/pippy/t5.py
Lines changed: 9 additions & 0 deletions

@@ -14,12 +14,21 @@
 import time
 
 import torch
+from packaging import version
 from transformers import AutoModelForSeq2SeqLM
 
 from accelerate import PartialState, prepare_pippy
+from accelerate import __version__ as accelerate_version
 from accelerate.utils import set_seed
 
 
+if version.parse(accelerate_version) > version.parse("0.33.0"):
+    raise RuntimeError(
+        "Using encoder/decoder models is not supported with the `torch.pipelining` integration or accelerate>=0.34.0. "
+        "Please use a lower accelerate version and `torchpippy`, which this example uses."
+    )
+
+
 # Set the random seed to have reproducable outputs
 set_seed(42)
 
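In practice, the guard above means this T5 example only runs on the older stack: encoder/decoder models are flagged as unsupported by the new torch.distributed.pipelining integration. A fallback along the lines of pip install "accelerate<0.34" torchpippy (exact pins are illustrative) keeps this example on the torchpippy backend it was written against.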

src/accelerate/inference.py
Lines changed: 18 additions & 21 deletions

@@ -79,22 +79,21 @@ def build_pipeline(model, split_points, args, kwargs, num_chunks):
     `AcceleratorState.num_processes`
     """
     # Note: We import here to reduce import time from general modules, and isolate outside dependencies
-    from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points
-    from pippy.PipelineStage import PipelineStage
+    from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline
 
     # We need to annotate the split points in the model for PiPPy
     state = PartialState()
-    annotate_split_points(model, {split_point: PipeSplitWrapper.SplitPoint.BEGINNING for split_point in split_points})
-    found_batch_size = find_pippy_batch_size(args, kwargs)
-    if found_batch_size != num_chunks:
-        if args is not None:
-            args = pad_input_tensors(args, found_batch_size, num_chunks)
-        if kwargs is not None:
-            kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
-    pipe = Pipe.from_tracing(model, num_chunks=num_chunks, example_args=args, example_kwargs=kwargs)
-    stage = PipelineStage(pipe, state.local_process_index, device=state.device)
+    split_spec = {split_point: SplitPoint.BEGINNING for split_point in split_points}
+    pipe = pipeline(
+        model,
+        mb_args=args,
+        mb_kwargs=kwargs,
+        split_spec=split_spec,
+    )
+    stage = pipe.build_stage(state.local_process_index, device=state.device)
+    schedule = ScheduleGPipe(stage, num_chunks)
 
-    return stage
+    return schedule
 
 
 def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs):
@@ -143,22 +142,20 @@ def prepare_pippy(
         no_split_module_classes (`List[str]`):
             A list of class names for layers we don't want to be split.
         example_args (tuple of model inputs):
-            The expected inputs for the model that uses order-based inputs. Recommended to use this method if possible.
+            The expected inputs for the model that uses order-based inputs for a *single process*. Recommended to use
+            this method if possible.
         example_kwargs (dict of model inputs)
-            The expected inputs for the model that uses dictionary-based inputs. This is a *highly* limiting structure
-            that requires the same keys be present at *all* inference calls. Not recommended unless the prior condition
-            is true for all cases.
+            The expected inputs for the model that uses dictionary-based inputs for a *single process*. This is a
+            *highly* limiting structure that requires the same keys be present at *all* inference calls. Not
+            recommended unless the prior condition is true for all cases.
         num_chunks (`int`, defaults to the number of available GPUs):
             The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but
             this can be tuned and played with. In general one should have num_chunks >= num_gpus.
         gather_output (`bool`, defaults to `False`):
             If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs.
     """
     if not is_pippy_available():
-        raise ImportError(
-            "`pippy` was not found to be installed on your system. Please "
-            "install using `pip install torchpippy` or ensure you have at least version 0.2.0"
-        )
+        raise ImportError("Using `torch.distributed.pipelining` requires PyTorch 2.4.0 or later.")
     state = PartialState()
     example_args = send_to_device(example_args, "cpu")
     example_kwargs = send_to_device(example_kwargs, "cpu")
@@ -177,7 +174,7 @@
     model.hf_split_points = split_points
 
     def forward(*args, **kwargs):
-        return pippy_forward(stage.forward, num_chunks, gather_output, *args, **kwargs)
+        return pippy_forward(stage.step, num_chunks, gather_output, *args, **kwargs)
 
     # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
     # Note: creates an infinite recursion loop with `generate`
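For readers new to the backend this commit switches to, here is a small standalone sketch of the same flow build_pipeline now implements: declare a split point, trace with a microbatch-sized example, then build a per-rank stage and wrap it in a GPipe schedule. The toy module and split-point name are illustrative; the pipeline, SplitPoint, build_stage, and ScheduleGPipe calls are the ones used in the diff above and require PyTorch >= 2.4.

import torch
import torch.nn as nn
from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Linear(16, 32)
        self.head = nn.Linear(32, 4)

    def forward(self, x):
        return self.head(self.embed(x))


# Ask for a stage boundary at the start of `head`, mirroring
# {split_point: SplitPoint.BEGINNING for split_point in split_points}
split_spec = {"head": SplitPoint.BEGINNING}

# Trace/split the model with a microbatch-shaped example input
pipe = pipeline(TinyNet(), mb_args=(torch.randn(4, 16),), split_spec=split_spec)

# Inside an initialized process group, each rank would then do roughly:
# stage = pipe.build_stage(rank, device=device)
# schedule = ScheduleGPipe(stage, n_microbatches)
# output = schedule.step(full_batch)  # rank 0 feeds inputs, the last rank returns outputs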

src/accelerate/utils/imports.py
Lines changed: 1 addition & 5 deletions

@@ -178,11 +178,7 @@ def is_deepspeed_available():
 
 
 def is_pippy_available():
-    package_exists = _is_package_available("pippy", "torchpippy")
-    if package_exists:
-        pippy_version = version.parse(importlib.metadata.version("torchpippy"))
-        return compare_versions(pippy_version, ">", "0.1.1")
-    return False
+    return is_torch_version(">=", "2.4.0")
 
 
 def is_bf16_available(ignore_tpu=False):
