[Feature] vllm inferencer and memory safe vllm inferencer #860

Merged 11 commits on Jun 19, 2024
6 changes: 3 additions & 3 deletions requirements.txt
@@ -5,7 +5,7 @@ tokenizers>=0.13.3
peft>=0.10.0
torch>=2.0.1
wandb==0.14.0
deepspeed==0.10.0
deepspeed<=0.14.0
trl>=0.7.11
sentencepiece
transformers>=4.31.0
@@ -18,8 +18,8 @@ scikit-learn==1.2.2
lm-eval==0.3.0
dill<0.3.5
bitsandbytes>=0.40.0
pydantic<=1.10.9
pydantic
gradio
accelerate>=0.27.2
einops>=0.6.1
scikit-learn==1.2.2
vllm>=0.4.1
116 changes: 111 additions & 5 deletions src/lmflow/args.py
@@ -12,15 +12,15 @@
extracted from the MODEL_CONFIG_CLASSES.
"""
import logging
from dataclasses import dataclass, field
from typing import Optional, List

from transformers.utils.versions import require_version
from dataclasses import dataclass, field, fields, Field, make_dataclass
from pathlib import Path
from typing import Optional, List, Union, Dict

from transformers import (
MODEL_FOR_CAUSAL_LM_MAPPING,
TrainingArguments,
)
from transformers.utils.versions import require_version

MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@@ -308,6 +308,7 @@ class ModelArguments:
"choices": ["right", "left", "auto"],
}
)


def __post_init__(self):
if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
@@ -838,6 +839,40 @@ class InferencerArguments:

repetition_penalty : float
An argument of model.generate in huggingface to penalize repetitions.
use_beam_search : Optional[bool]
Whether to use beam search during inference. By default False.
num_output_sequences : Optional[int]
Number of output sequences to return for the given prompt;
currently only used in vllm inference. By default 8.
top_p : Optional[float]
top_p for sampling. By default 1.0.
top_k : Optional[int]
top_k for sampling. By default -1 (no top_k).
additional_stop_token_ids : Optional[List[int]]
The ids of the end-of-sentence tokens. By default [].
apply_chat_template : Optional[bool]
Whether to apply the chat template. By default True.
save_results : Optional[bool]
Whether to save inference results. By default False.
results_path : Optional[str]
The **json file** path of inference results. By default None.
memory_safe_vllm_inference_detokenize : Optional[bool]
Whether to detokenize the memory safe vllm inference results.

NOTE: For iterative align pipelines, whether to detokenize depends on
whether the policy model and the reward model are homogeneous
(i.e., share the same tokenizer).
`detokenize` is exposed as an argument for memory safe vllm inference
because of its implementation: it runs as a subprocess rather than
within the Python code, and thus has to receive options through
command line arguments.
use_vllm : bool, optional
Whether to use VLLM for inference. By default False.
vllm_tensor_parallel_size : int, optional
The tensor parallel size for VLLM inference.
vllm_gpu_memory_utilization : float, optional
The GPU memory utilization for VLLM inference, i.e., the proportion
of GPU memory (per GPU) to use.
"""
device: str = field(
default="gpu",
@@ -902,6 +937,69 @@ class InferencerArguments:
use_accelerator: bool = field(
default=False, metadata={"help": "Whether to use Huggingface Accelerator instead of Deepspeed"},
)
use_beam_search: Optional[bool] = field(
default=False,
metadata={"help": "whether to use beam search during inference."},
)
num_output_sequences: Optional[int] = field(
default=8,
metadata={"help": (
"number of output sequences to return for the given prompt, "
"currently only used in vllm inference."
)},
)
top_p: Optional[float] = field(
default=1.0,
metadata={"help": "top_p for sampling."},
)
top_k: Optional[int] = field(
default=-1,
metadata={"help": "top_k for sampling."},
)
additional_stop_token_ids: Optional[List[int]] = field(
default_factory=lambda: [],
metadata={"help": "the ids of the end of sentence tokens"},
)
apply_chat_template: Optional[bool] = field(
default=True,
metadata={"help": "whether to apply chat template"},
)
memory_safe_vllm_inference_detokenize: Optional[bool] = field(
default=False,
metadata={"help": "Whether to detokenize the memory safe vllm inference results."},
)

# vllm inference args
use_vllm: bool = field(
default=False,
metadata={"help": "Whether to use VLLM for inference, By default False."}
)
vllm_tensor_parallel_size: Optional[int] = field(
default=1,
metadata={"help": "The tensor parallel size for VLLM inference."}
)
vllm_gpu_memory_utilization: Optional[float] = field(
default=0.95,
metadata={"help": "The GPU memory utilization for VLLM inference."}
)

# Args for result saving
save_results: Optional[bool] = field(
default=False, metadata={"help": "Whether to save inference results."}
)
results_path: Optional[str] = field(
default=None, metadata={"help": "The path of inference results."}
)

def __post_init__(self):
    if self.save_results:
        if self.results_path is None:
            raise ValueError("Need to specify results_path when save_results is True.")
        else:
            if not self.results_path.endswith(".json"):
                raise ValueError("The results_path must be a json file.")
            else:
                Path(self.results_path).parent.mkdir(parents=True, exist_ok=True)
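A minimal sketch (not part of this PR's diff) of how the new sampling and vllm fields above could map onto vllm's API, assuming `args` is a parsed InferencerArguments, a vllm version satisfying the `vllm>=0.4.1` pin in requirements.txt, and a placeholder model path and prompt:

from vllm import LLM, SamplingParams

# Sampling options come straight from the new InferencerArguments fields.
sampling_params = SamplingParams(
    n=args.num_output_sequences,
    use_beam_search=args.use_beam_search,
    top_p=args.top_p,
    top_k=args.top_k,
    stop_token_ids=args.additional_stop_token_ids,
)

# Tensor parallelism and the per-GPU memory budget come from the vllm-specific fields.
llm = LLM(
    model="path/to/model",  # placeholder
    tensor_parallel_size=args.vllm_tensor_parallel_size,
    gpu_memory_utilization=args.vllm_gpu_memory_utilization,
)
outputs = llm.generate(["An example prompt"], sampling_params)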


@dataclass
@@ -1144,13 +1242,21 @@ class DPOAlignerArguments:
)


@dataclass
class IterativeAlignerArguments(InferencerArguments):
"""
Arguments for iterative aligners.
"""
pass


PIPELINE_ARGUMENT_MAPPING = {
"finetuner": FinetunerArguments,
"evaluator": EvaluatorArguments,
"inferencer": InferencerArguments,
"raft_aligner": RaftAlignerArguments,
"dpo_aligner": DPOAlignerArguments,
"rm_tuner": RewardModelingArguments
"rm_tuner": RewardModelingArguments,
}
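
On the memory safe vllm inferencer described in the InferencerArguments docstring: because it runs as a subprocess, its options have to travel through command line flags. A hedged sketch of such an invocation, with a hypothetical entry-point script and flag spellings derived from the field names above (neither is taken from this PR):

import subprocess
import sys

cmd = [
    sys.executable, "examples/vllm_inference.py",  # hypothetical entry point
    "--use_vllm", "True",
    "--vllm_tensor_parallel_size", "2",
    "--vllm_gpu_memory_utilization", "0.95",
    "--memory_safe_vllm_inference_detokenize", "True",  # depends on the policy/reward tokenizer setup
    "--save_results", "True",
    "--results_path", "output/results.json",
]

# Running inference in a child process means vllm's GPU memory is released
# when that process exits, which is the "memory safe" part.
subprocess.run(cmd, check=True)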


10 changes: 5 additions & 5 deletions src/lmflow/datasets/dataset.py
@@ -61,7 +61,7 @@ class Dataset:
kwargs : Optional.
Keyword arguments.
"""
def __init__(self, data_args=None, backend: str="huggingface", *args, **kwargs):
def __init__(self, data_args: DatasetArguments=None, backend: str="huggingface", *args, **kwargs):
self.data_args = data_args
self.backend = backend
self.backend_dataset = None
@@ -263,7 +263,7 @@ def from_dict(self, dict_obj: dict, *args, **kwargs):
return self
else:
raise NotImplementedError(
f'Currently .from_dict is not supported for backend "{backend}"'
f'Currently .from_dict is not supported for backend "{self.backend}"'
)


@@ -331,7 +331,7 @@ def to_dict(self):
return dict_obj
else:
raise NotImplementedError(
f'Current .to_dict is not supported for backend "{backend}"'
f'Current .to_dict is not supported for backend "{self.backend}"'
)


@@ -347,7 +347,7 @@ def to_list(self):
return instance_list
else:
raise NotImplementedError(
f'Current .to_list is not supported for backend "{backend}"'
f'Current .to_list is not supported for backend "{self.backend}"'
)


@@ -376,7 +376,7 @@ def map(self, *args, **kwargs):
else:
# If the backend is not Hugging Face, raise a NotImplementedError
raise NotImplementedError(
f'Currently .map is not supported for backend "{backend}"'
f'Currently .map is not supported for backend "{self.backend}"'
)
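
For context on the repeated change above (bare `backend` replaced by `self.backend`): inside these methods the bare name is undefined, so the old error path raised a NameError instead of the intended NotImplementedError. A minimal standalone sketch (not LMFlow code) of the failure mode and the fix:

class Demo:
    def __init__(self, backend: str = "custom"):
        self.backend = backend

    def unsupported_old(self):
        # Bug: `backend` is not defined in this scope, so evaluating the
        # f-string raises NameError before the error message is ever built.
        raise NotImplementedError(f'Not supported for backend "{backend}"')

    def unsupported_new(self):
        # Fix: reference the instance attribute instead.
        raise NotImplementedError(f'Not supported for backend "{self.backend}"')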

