From 91c7ee993d07a05951e19ac12e8be2c3c9442b2d Mon Sep 17 00:00:00 2001
From: Florin Tobler
Date: Thu, 2 Jan 2025 04:17:46 +0100
Subject: [PATCH] add prefix extension

Allow forcing the start of the assistant reply by appending an optional
assistant prefix after the chat template. In the tokenized branch the
prefix is encoded with add_special_tokens=False so no BOS/special tokens
are injected mid-sequence; in the plain-text branch the prefix is spliced
in behind an explicit assistant header.
---
 inference.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/inference.py b/inference.py
index 41b418c..022b451 100644
--- a/inference.py
+++ b/inference.py
@@ -121,10 +121,23 @@ class Inference:
         return generated_tokens, full_output
 
-    def tokenize(self, messages: list[dict], tokenize: bool) -> str | torch.Tensor:
+    def tokenize(self, messages: list[dict], tokenize: bool, assistant_prefix: str | None = None) -> str | torch.Tensor:
         if tokenize:
             inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=True, return_dict=True, add_generation_prompt=True) #continue_final_message=True,
             inputs = {key: value.to(self.model.device) for key, value in inputs.items()}
-            return inputs["input_ids"]
+            input_ids = inputs["input_ids"]
+
+            # Append the assistant prefix if provided. add_special_tokens=False
+            # prevents the tokenizer from injecting BOS/special tokens into the
+            # middle of the already-templated sequence.
+            if assistant_prefix:
+                prefix_ids = self.tokenizer(assistant_prefix, return_tensors="pt", add_special_tokens=False)["input_ids"]
+                input_ids = torch.cat([input_ids, prefix_ids.to(self.model.device)], dim=-1)
+
+            return input_ids
         else:
+            # only plain text generation
             message = self.tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=False, add_generation_prompt=False)
+
+            # Append the assistant prefix to raw text if provided
+            if assistant_prefix:
+                message += f"<|im_start|>assistant\n{assistant_prefix}"
+
             return message