@@ -79,22 +79,21 @@ def build_pipeline(model, split_points, args, kwargs, num_chunks):
     `AcceleratorState.num_processes`
     """
     # Note: We import here to reduce import time from general modules, and isolate outside dependencies
-    from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points
-    from pippy.PipelineStage import PipelineStage
+    from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline

     # We need to annotate the split points in the model for PiPPy
     state = PartialState()
-    annotate_split_points(model, {split_point: PipeSplitWrapper.SplitPoint.BEGINNING for split_point in split_points})
-    found_batch_size = find_pippy_batch_size(args, kwargs)
-    if found_batch_size != num_chunks:
-        if args is not None:
-            args = pad_input_tensors(args, found_batch_size, num_chunks)
-        if kwargs is not None:
-            kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
-    pipe = Pipe.from_tracing(model, num_chunks=num_chunks, example_args=args, example_kwargs=kwargs)
-    stage = PipelineStage(pipe, state.local_process_index, device=state.device)
+    split_spec = {split_point: SplitPoint.BEGINNING for split_point in split_points}
+    pipe = pipeline(
+        model,
+        mb_args=args,
+        mb_kwargs=kwargs,
+        split_spec=split_spec,
+    )
+    stage = pipe.build_stage(state.local_process_index, device=state.device)
+    schedule = ScheduleGPipe(stage, num_chunks)

-    return stage
+    return schedule


 def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs):
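For context, here is a minimal standalone sketch of the `torch.distributed.pipelining` flow this hunk adopts. It assumes PyTorch >= 2.4 and a two-process `torchrun` launch; the `MLP` model and the `layers.2` split point are illustrative, not part of this diff:

# Illustrative sketch only: mirrors the pipeline()/build_stage()/ScheduleGPipe
# sequence added above. Run with: torchrun --nproc-per-node 2 sketch.py
import torch
import torch.distributed as dist
from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline


class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(*[torch.nn.Linear(16, 16) for _ in range(4)])

    def forward(self, x):
        return self.layers(x)


dist.init_process_group()  # env:// init from the torchrun environment
rank, device = dist.get_rank(), torch.device("cpu")

num_chunks = 2
micro_batch = torch.randn(4, 16)  # the example *micro-batch* the tracer sees (mb_args)

# One split at the beginning of `layers.2` -> two stages, one per process.
pipe = pipeline(MLP(), mb_args=(micro_batch,), split_spec={"layers.2": SplitPoint.BEGINNING})
stage = pipe.build_stage(rank, device=device)
schedule = ScheduleGPipe(stage, num_chunks)

# The first rank feeds the *full* batch; the schedule chunks it into micro-batches.
if rank == 0:
    schedule.step(torch.randn(4 * num_chunks, 16))
else:
    output = schedule.step()  # only the last stage returns the real output

dist.destroy_process_group()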
@@ -143,22 +142,20 @@ def prepare_pippy(
         no_split_module_classes (`List[str]`):
             A list of class names for layers we don't want to be split.
         example_args (tuple of model inputs):
-            The expected inputs for the model that uses order-based inputs. Recommended to use this method if possible.
+            The expected inputs for the model that uses order-based inputs for a *single process*. Recommended to use
+            this method if possible.
         example_kwargs (dict of model inputs)
-            The expected inputs for the model that uses dictionary-based inputs. This is a *highly* limiting structure
-            that requires the same keys be present at *all* inference calls. Not recommended unless the prior condition
-            is true for all cases.
+            The expected inputs for the model that uses dictionary-based inputs for a *single process*. This is a
+            *highly* limiting structure that requires the same keys be present at *all* inference calls. Not
+            recommended unless the prior condition is true for all cases.
         num_chunks (`int`, defaults to the number of available GPUs):
             The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but
             this can be tuned and played with. In general one should have num_chunks >= num_gpus.
         gather_output (`bool`, defaults to `False`):
             If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs.
     """
     if not is_pippy_available():
-        raise ImportError(
-            "`pippy` was not found to be installed on your system. Please "
-            "install using `pip install torchpippy` or ensure you have at least version 0.2.0"
-        )
+        raise ImportError("Using `torch.distributed.pipelining` requires PyTorch 2.4.0 or later.")
     state = PartialState()
     example_args = send_to_device(example_args, "cpu")
     example_kwargs = send_to_device(example_kwargs, "cpu")
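The new error message ties availability to the PyTorch version rather than to a separate `pippy` package. This diff does not show `is_pippy_available`'s implementation; the helper below is a hypothetical sketch of what such a version gate could look like:

import importlib.metadata

from packaging.version import Version


def is_pipelining_available() -> bool:
    # Hypothetical stand-in for is_pippy_available: gate on torch >= 2.4.0,
    # matching the requirement named in the new error message.
    try:
        return Version(importlib.metadata.version("torch")) >= Version("2.4.0")
    except importlib.metadata.PackageNotFoundError:
        return False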
@@ -177,7 +174,7 @@ def prepare_pippy(
     model.hf_split_points = split_points

     def forward(*args, **kwargs):
-        return pippy_forward(stage.forward, num_chunks, gather_output, *args, **kwargs)
+        return pippy_forward(stage.step, num_chunks, gather_output, *args, **kwargs)

     # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
     # Note: creates an infinite recursion loop with `generate`
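The `stage.forward` -> `stage.step` swap follows from `build_pipeline` now returning a `ScheduleGPipe` instead of a raw stage. A hedged sketch of the resulting calling convention (assumption: `schedule` is the object returned by `build_pipeline`; `run_stage` is an illustrative helper, not part of this diff):

import torch.distributed as dist


def run_stage(schedule, *args, **kwargs):
    # The first stage consumes the real inputs; downstream stages take no
    # arguments and receive activations over the process group instead.
    if dist.get_rank() == 0:
        output = schedule.step(*args, **kwargs)
    else:
        output = schedule.step()
    # step() returns the merged output on the last stage and None elsewhere,
    # which is why gather_output exists to broadcast it when requested.
    return output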