add prefix extension
inference.py | 17
@@ -121,13 +121,26 @@ class Inference:
         return generated_tokens, full_output
 
 
-    def tokenize(self, messages: list[dict], tokenize: bool) -> str | torch.Tensor:
+    def tokenize(self, messages: list[dict], tokenize: bool, assistant_prefix: str = None) -> str | torch.Tensor:
         if tokenize:
             inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=True, return_dict=True, add_generation_prompt=True)  # continue_final_message=True,
             inputs = {key: value.to(self.model.device) for key, value in inputs.items()}
-            return inputs["input_ids"]
+            input_ids = inputs["input_ids"]
+
+            # Append the assistant prefix if provided
+            if assistant_prefix:
+                prefix_ids = self.tokenizer(assistant_prefix, return_tensors="pt")["input_ids"]
+                input_ids = torch.cat([input_ids, prefix_ids.to(self.model.device)], dim=-1)
+
+            return input_ids
         else:
+            # only plain text generation
             message = self.tokenizer.apply_chat_template(messages, return_tensors="pt", tokenize=False, add_generation_prompt=False)
+
+            # Append the assistant prefix to raw text if provided
+            if assistant_prefix:
+                message += f"<|im_start|>assistant\n{assistant_prefix}"
+
             return message
 
 
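For reference, a minimal standalone sketch of the prefix-appending mechanics this commit introduces, mirroring both code paths of the new assistant_prefix argument. The tokenizer checkpoint, the messages, and the prefix string are placeholder assumptions for illustration, not part of the commit; any tokenizer with a ChatML-style template (<|im_start|> ... <|im_end|>) behaves the same way.

import torch
from transformers import AutoTokenizer

# Placeholder checkpoint (assumption); chosen because its chat template uses the
# same <|im_start|>assistant header that the raw-text path in the diff appends.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

messages = [{"role": "user", "content": "List three prime numbers."}]
assistant_prefix = "Sure, here are"

# Tokenized path: build the prompt ids, then concatenate the prefix tokens so
# generation continues from the prefix instead of starting a fresh reply.
inputs = tokenizer.apply_chat_template(
    messages, return_tensors="pt", tokenize=True, return_dict=True, add_generation_prompt=True
)
prefix_ids = tokenizer(assistant_prefix, return_tensors="pt")["input_ids"]
input_ids = torch.cat([inputs["input_ids"], prefix_ids], dim=-1)

# Plain-text path: append the prefix after an assistant header, as in the diff.
message = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
message += f"<|im_start|>assistant\n{assistant_prefix}"

One caveat with this approach: for some tokenizers a bare tokenizer(assistant_prefix, ...) call prepends special tokens such as BOS; passing add_special_tokens=False keeps the prefix ids clean in that case.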
|