diff --git a/src/lighteval/models/transformers/transformers_model.py b/src/lighteval/models/transformers/transformers_model.py index cec71f53d..07ab313b1 100644 --- a/src/lighteval/models/transformers/transformers_model.py +++ b/src/lighteval/models/transformers/transformers_model.py @@ -22,7 +22,7 @@ import logging import os -from typing import Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -254,14 +254,15 @@ def from_model( # Instanciate the object without using __init__ self = cls.__new__(cls) - self.config = config self.transformers_config = model.config - self.generation_config_dict = config.generation_parameters.to_transformers_dict() + self.config = config if config is not None else TransformersModelConfig(model_name=model.config.name_or_path) + if config is not None: + self.generation_config_dict = config.generation_parameters.to_transformers_dict() self._max_length = self._init_max_length() self._tokenizer = self._create_auto_tokenizer() - self.batch_size = config.batch_size + self.batch_size = getattr(config, "batch_size", None) self.model_name = _simplify_name(model.name_or_path) - self.model_sha = config.get_model_sha() + self.model_sha = self.config.get_model_sha() # If model_parallel is not set we compare the number of processes with the number of GPUs self.model = model @@ -508,7 +509,114 @@ def greedy_until_multi_turn( # noqa: C901 ) -> GenerativeMultiturnResponse: raise NotImplementedError("This method is not implemented for this model") - def greedy_until( + def _continious_greedy_until( + self, + requests: list[GreedyUntilRequest], + ) -> list[GenerativeResponse]: + """ + Generates responses using a greedy decoding strategy until certain ending conditions are met. + + Args: + requests (list[Request]): list of requests containing the context and ending conditions. + override_bs (int, optional): Override the batch size for generation. Defaults to None. + + Returns: + list[GenerateReturn]: list of generated responses. + """ + for request in requests: + request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token] + request.tokenized_context = self.tok_encode(request.context) + + dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS) + results = [] + + for split in tqdm( + dataset.splits_iterator(), + total=dataset.num_dataset_splits, + desc="Splits", + position=0, + disable=False, # self.disable_tqdm, + ): + # For chat models, generation stops with EOS token, so we don't need to specify stop tokens + if self.use_chat_template: + stop_tokens = [] + else: + # NOTE: we are assuming all items in a batch behave similarly (same + # stop_tokens and max_tokens genrated) which is not necessarily + # the case! Because of that we only use batch size of 1 + stop_tokens = split[0].stop_sequence + + max_new_tokens = self.config.generation_parameters.max_new_tokens or split[0].generation_size + returns_logits = split[0].use_logits + num_samples = split[0].num_samples + + context = [sample.context for sample in split] + tokenized = self.tokenizer(context, add_special_tokens=self.add_special_tokens) + + # The main question for this step is the following: + # Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk + # of losing some meaning, or have some generations that are exceedingly short? + # The choice we go for here is to avoid truncating the prompt if we can, since it + # should have been managed by the prompt creator/few shot manager if requested by the user. + inputs = tokenized["input_ids"] + context_size = len(inputs[0]) + + # left truncate the inputs to the maximum length + if max_new_tokens is not None: + if context_size + max_new_tokens > self.max_length: + logger.warning( + f"{context_size + max_new_tokens=} which is greater than {self.max_length=}. Truncating context to {self.max_length - max_new_tokens} tokens." + ) + context_size = self.max_length - max_new_tokens + if context_size < 0: + logger.critical( + f"{context_size=} is less than 0, either reduce the max_new_tokens or increase model max length." + ) + raise ValueError("Context size is less than 0.") + inputs = [input[-context_size:] for input in inputs] + else: + if context_size > self.max_length: + logger.warning( + f"{context_size=} which is greater than {self.max_length=}. Truncating context to {self.max_length} tokens." + ) + context_size = self.max_length + inputs = [input[-context_size:] for input in inputs] + + _outputs = self._generate( + inputs=inputs, + max_new_tokens=max_new_tokens, + stop_tokens=stop_tokens, + returns_logits=returns_logits, + num_samples=num_samples, + ) + + for req_id, _output in _outputs.items(): + output_token_ids = [] + logprobs_raw = [] + result = [] + + # for output in _output.outputs: + output_token_ids.append(_output.static_outputs) + # logprobs_raw.append(output.logprobs) + result.append(self.tokenizer.decode(_output.static_outputs)) + + if logprobs_raw and output_token_ids and False: + logprobs = [logprobs_raw[0][token_id].logprob for token_id in output_token_ids[0]] + else: + logprobs = [] + + input_token_ids = _output.full_prompt_ids + cur_response = GenerativeResponse( + result=result, + logits=logprobs, + generated_tokens=output_token_ids, + input_tokens=input_token_ids, + ) + results.append(cur_response) + + return dataset.get_original_order(results) + + def _padded_greedy_until( self, requests: list[GreedyUntilRequest], ) -> list[GenerativeResponse]: @@ -625,12 +733,41 @@ def greedy_until( returns_logits=returns_logits, num_samples=num_samples, do_sample=do_sample, + use_fast=False, ) results.extend(cur_reponses) return dataset.get_original_order(results) - def _generate( + def greedy_until( + self, + requests: list[GreedyUntilRequest], + use_fast: bool = True, + ) -> list[GenerativeResponse]: + if use_fast: + return self._continious_greedy_until(requests) + else: + return self._padded_greedy_until(requests) + + def _generate_fast( + self, + inputs: list[list[int]], + max_new_tokens: Optional[int] = None, + stop_tokens: Optional[list[str]] = None, + returns_logits: Optional[bool] = False, + num_samples: int = 1, + generate: bool = True, + ) -> Dict[str, GenerativeResponse]: + # Compute model generation + batch_outputs = self.model.generate_batch( + inputs=inputs, + generation_config=self.model.generation_config, + # You can pass request-specific overrides here, e.g., max_new_tokens=100 + ) + + return batch_outputs + + def _generate_padded( self, batch: Batch, max_new_tokens: int, @@ -711,6 +848,16 @@ def _generate( return all_responses + def _generate( + self, + use_fast: bool = True, + **kwargs, + ) -> list[GenerativeResponse]: + if use_fast: + return self._generate_fast(**kwargs) + else: + return self._generate_padded(**kwargs) + def loglikelihood( self, requests: list[LoglikelihoodRequest],