[Feature] vllm inferencer and memory safe vllm inferencer #860
Conversation
Supporting vLLM is important for accelerating the inference component in different algorithms. Some modifications may be needed before merging into the main branch.
requirements.txt
- [Feature] line 8: `deepspeed <= 0.14.0`, to ensure backward compatibility.
src/lmflow/args.py
- [Style] line 15: we mostly sort imported packages alphabetically. Moving this import to line 16 would be better.
- [Architecture] line 99, 318: the implication of the argument `load_on_init` seems confusing to users.
- [Architecture] line 318-335: these arguments belong to `Inferencer`, not `Model`. They should move to `InferencerArguments`. If the model needs these arguments, they can be passed in as `**kwargs`.
- [Style] line 949-1001: if these options are for vLLM only, better to prepend a `vllm_` prefix. Alternatively, implement the features corresponding to those arguments.
- [Feature] line 976: better to automatically detect `os.environ["CUDA_VISIBLE_DEVICES"]` (see the sketch after this list).
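A minimal sketch of both suggestions, with hypothetical field names (`vllm_gpu_memory_utilization` and `vllm_tensor_parallel_size` are illustrative, not the actual LMFlow arguments): group vLLM-only options under a `vllm_` prefix inside `InferencerArguments`, and derive the GPU count from `CUDA_VISIBLE_DEVICES` when it is not given explicitly.

```python
import os
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class InferencerArguments:
    # hypothetical fields, for illustration only
    use_vllm: bool = field(default=False)
    vllm_gpu_memory_utilization: float = field(default=0.9)
    vllm_tensor_parallel_size: Optional[int] = field(default=None)

    def __post_init__(self):
        if self.use_vllm and self.vllm_tensor_parallel_size is None:
            # default to all GPUs listed in CUDA_VISIBLE_DEVICES, else a single GPU
            visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
            self.vllm_tensor_parallel_size = len(visible.split(",")) if visible else 1
```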
src/lmflow/models/hf_decoder_model.py
- [Architecture] line 377, 429, 471: add an argument `use_vllm`, which is passed from `Inferencer`.
src/lmflow/models/hf_model_mixin.py
- [Architecture] line 111: pass this from `Inferencer`; specify it as an extra argument for `__init__`.
- [Architecture] line 368-419: the indentation level is too high now; consider wrapping this part of the code in a separate function.
- [Architecture] line 453: the `LLM` should not be `self.backend_model`; it should go in another variable, such as `self.backend_model_for_inference`, otherwise it will interfere with other usages of `self.backend_model` (see the sketch after this list).
- [Question] line 453: does vLLM support dynamic model change during inference?
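A rough sketch of the suggested separation, assuming simplified loading logic (this is not the actual `HFModelMixin` code): the transformers model keeps its usual attribute, while the vLLM engine gets its own.

```python
from transformers import AutoModelForCausalLM
from vllm import LLM


class HFModelMixin:
    def __init__(self, model_name_or_path: str, use_vllm: bool = False, **kwargs):
        # the transformers model keeps its usual place, so training and
        # evaluation code that relies on self.backend_model is unaffected
        self.backend_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
        # the vLLM engine lives in a dedicated attribute instead of
        # overwriting self.backend_model
        self.backend_model_for_inference = (
            LLM(model=model_name_or_path) if use_vllm else None
        )
```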
src/lmflow/pipeline/inferencerv2.py
- We can rename it to `vllm_inferencer.py`, which matches the class name. Also, "v2" is vague and confusing.
src/lmflow/pipeline/utils/collections.py
- [Style] Better to rename it; the name `collections` is vague and confusing.
- [Architecture] line 15: this is a util function for models, move it to `src/lmflow/utils/model.py`. `src/lmflow/pipeline/utils/` is mainly for customized training classes such as `raft_trainer`.
- [Architecture] line 28: this is a util function for datasets, move it to `src/lmflow/utils/dataset.py`.
src/lmflow/pipeline/utils/memory_safe_vllm_inference.py
- [Architecture] Move it to `examples/memory_safe_vllm_inference.py`, or make it a special mode of the common inference, e.g. a mode that can be activated by providing a single `--use_vllm` option.
src/lmflow/utils/collections.py
- [Architecture] Move the content to `src/lmflow/utils/dataset.py`.

tests/pipeline/test_memory_safe_vllm_inferencer.py
- [Style] line 16, 23, 34: there are absolute paths; consider uploading the dataset and using Hugging Face model names.
Changes made; tests to be done.
We can record the TODO features in a roadmap issue. The rest looks good to me.
src/lmflow/utils/args.py
- [Architecture] The name `args.py` is not preferred, as it is usually used for command-line arguments.
LGTM 👍
Description
We implement a vLLM inferencer and a memory-safe vLLM inferencer, which will benefit the online RLHF process.
`MemorySafeVLLMInferencer` runs `lmflow/pipeline/utils/memory_safe_vllm_inference.py` using a Python `subprocess`, since it is currently not possible to offload the model or release the GPU memory that vLLM occupies within a Python script using `del`, `model.to('cpu')`, or other approaches (see this issue).
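A rough sketch of the memory-safe pattern described above; the CLI flags, model name, and paths below are illustrative placeholders, not the exact interface of `memory_safe_vllm_inference.py`.

```python
import subprocess
import sys

# run vLLM inference in a child process so that all GPU memory it holds
# is reclaimed by the OS when the process exits
cmd = [
    sys.executable,
    "lmflow/pipeline/utils/memory_safe_vllm_inference.py",
    "--model_name_or_path", "meta-llama/Llama-2-7b-hf",  # assumed example model
    "--dataset_path", "data/alpaca/test",                # assumed example dataset
    "--output_dir", "output/vllm_inference",
]

result = subprocess.run(cmd, check=False)
if result.returncode != 0:
    raise RuntimeError(f"memory-safe vLLM inference failed with code {result.returncode}")
```

Because the vLLM engine lives entirely in the child process, its GPU memory is released as soon as the process exits, which is what makes this pattern usable inside a long-running RLHF loop.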
Tests
`MemorySafeVLLMInferencer`
- runtime: (screenshot)
- test result: (screenshot)
Compatibility
- `run_reward_modeling.sh`: (screenshot)
- `run_finetune.sh`: (screenshot)
- `run_finetune_with_lora.sh`: (screenshot)