diff --git a/inference.py b/inference.py index 005c74d..1401f81 100644 --- a/inference.py +++ b/inference.py @@ -101,7 +101,7 @@ class Inference: if print_stdout: print(out_text) return outputs, out_text - + def generate_incremental_2(self, input_ids: torch.Tensor, print_stdout:bool=True) -> tuple[torch.Tensor, str]: generated_tokens = input_ids @@ -180,7 +180,7 @@ class Inference: while True: # Call the model with the current tokens outputs = self.model( - input_ids=generated_tokens, + input_ids=generated_tokens, use_cache=True, num_beams = 1 # past_key_values=past_key_values