From adcb172da4c819e8d801a9c225c202820e600fa9 Mon Sep 17 00:00:00 2001
From: Florin Tobler
Date: Sat, 4 Jan 2025 16:07:31 +0100
Subject: [PATCH] fix cache

---
 inference.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/inference.py b/inference.py
index 7ff99a9..9871640 100644
--- a/inference.py
+++ b/inference.py
@@ -73,7 +73,8 @@ class Inference:
 
     def generate(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, str]:
         with torch.inference_mode():
-            return self.generate_incremental_2(input_ids)
+            with torch.no_grad():
+                return self.generate_incremental_2(input_ids)
 
 
     def generate_batch(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, str]:
@@ -96,24 +97,22 @@ class Inference:
 
     def generate_incremental_2(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, str]:
         generated_tokens = input_ids
-        # past_key_values = DynamicCache()
-        past_key_values = StaticCache(config=self.model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)
+        past_key_values = DynamicCache()
         # n = 0
         try:
             while True:
                 outputs = self.model.generate(
                     generated_tokens, # **inputs, inputs["input_ids"]
-                    max_new_tokens=10, # like streaming
+                    max_new_tokens=5, # like streaming
                     pad_token_id=self.tokenizer.pad_token_id,
                     eos_token_id=self.tokenizer.eos_token_id,
                     do_sample=True,
                     num_return_sequences=1,
                     num_beams = 1,
-                    use_cache=True,
+                    # use_cache=True,
                     past_key_values=past_key_values
                 )
-                # past_key_values = outputs.past_key_values
 
                 # Get the next token (the last token from the generated sequence)
                 # next_token = outputs.argmax(dim=-1)[:, -1]