@@ -73,7 +73,8 @@ class Inference:
 
     def generate(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, str]:
-        with torch.inference_mode():
-            return self.generate_incremental_2(input_ids)
+        with torch.no_grad():
+            return self.generate_incremental_2(input_ids)
 
 
     def generate_batch(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, str]:
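
The first hunk swaps torch.inference_mode() for torch.no_grad(). The diff does not say why; a plausible reason is that tensors created under inference_mode are "inference tensors", which forbid in-place updates outside the mode — awkward once objects such as a KV cache outlive the with block, as they do in the second hunk. no_grad() merely disables gradient tracking. A minimal sketch of the difference (names illustrative, not from the diff):

    import torch

    lin = torch.nn.Linear(4, 4)
    x = torch.randn(1, 4)

    with torch.no_grad():
        y = lin(x)
    y += 1.0  # fine: y is an ordinary tensor with requires_grad=False

    with torch.inference_mode():
        z = lin(x)
    try:
        z += 1.0  # in-place update of an inference tensor outside the mode
    except RuntimeError as err:
        print(err)  # e.g. "Inplace update to inference tensor outside InferenceMode is not allowed..."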
|
|
@@ -96,24 +97,22 @@ class Inference:
     def generate_incremental_2(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, str]:
         generated_tokens = input_ids
 
-        # past_key_values = DynamicCache()
-        past_key_values = StaticCache(config=self.model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)
+        past_key_values = DynamicCache()
 
-        # n = 0
         try:
             while True:
                 outputs = self.model.generate(
                     generated_tokens,  # **inputs, inputs["input_ids"]
-                    max_new_tokens=10,  # like streaming
+                    max_new_tokens=5,  # like streaming
                     pad_token_id=self.tokenizer.pad_token_id,
                     eos_token_id=self.tokenizer.eos_token_id,
                     do_sample=True,
                     num_return_sequences=1,
                     num_beams=1,
-                    use_cache=True,
+                    # use_cache=True,
                     past_key_values=past_key_values
                 )
                 # past_key_values = outputs.past_key_values
 
                 # Get the next token (the last token from the generated sequence)
                 # next_token = outputs.argmax(dim=-1)[:, -1]
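
The pattern the second hunk converges on — one DynamicCache shared across repeated model.generate() calls, each emitting a small chunk of tokens — can be reproduced standalone roughly as below. This is a sketch, assuming a recent transformers release whose generate() accepts a Cache object via past_key_values, with gpt2 as a stand-in model and a fixed chunk count in place of a real streaming consumer:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.eval()

    generated = tokenizer("The capital of France is", return_tensors="pt").input_ids
    past_key_values = DynamicCache()  # mutated in place by each generate() call

    with torch.no_grad():
        for _ in range(4):  # four chunks of 5 tokens, "like streaming"
            generated = model.generate(
                generated,                 # full sequence so far
                max_new_tokens=5,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,  # gpt2 has no pad token
                past_key_values=past_key_values,
            )
            if generated[0, -1].item() == tokenizer.eos_token_id:
                break  # model finished early

    print(tokenizer.decode(generated[0], skip_special_tokens=True))

Although the full sequence is passed back in each time, only the tokens not yet present in the cache are run through the model, so the per-chunk cost stays roughly flat instead of growing with the prefix.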
|
|
|
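On the cache swap itself: StaticCache preallocates key/value buffers for max_cache_len positions at a fixed batch size, which keeps shapes static (and torch.compile-friendly) but fails once generation runs past the cap — a real hazard with the unbounded while True loop above, and a plausible reason the hunk falls back to DynamicCache, which grows with the sequence. A construction sketch mirroring the removed line's kwargs (exact parameter names vary across transformers versions):

    import torch
    from transformers import AutoConfig, StaticCache

    config = AutoConfig.from_pretrained("gpt2")  # stand-in for self.model.config
    cache = StaticCache(
        config=config,
        max_batch_size=1,    # one sequence at a time
        max_cache_len=1024,  # hard ceiling on prompt + generated length
        device="cpu",        # the diff used "cuda" with dtype=torch.bfloat16
        dtype=torch.float32,
    )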