
fix cache

master
Florin Tobler committed 5 months ago
commit adcb172da4

1 changed file, 11 changed lines

inference.py
@@ -73,7 +73,8 @@ class Inference:
     def generate(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, str]:
-        with torch.inference_mode():
-            return self.generate_incremental_2(input_ids)
+        with torch.no_grad():
+            return self.generate_incremental_2(input_ids)

     def generate_batch(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, str]:
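Note on the first hunk: the switch from torch.inference_mode() to torch.no_grad() is presumably part of the cache fix (an assumption, the commit message only says "fix cache"). Tensors created under inference_mode are inference tensors and cannot be updated in place outside that context, which breaks a KV cache that is kept alive across calls; plain no_grad tensors have no such restriction. A minimal sketch of the failure mode:

import torch

with torch.inference_mode():
    kv_buffer = torch.zeros(4)      # stands in for a cached key/value tensor

with torch.no_grad():
    try:
        kv_buffer[0] = 1.0          # in-place update outside inference mode
    except RuntimeError as err:
        # raises: "Inplace update to inference tensor outside InferenceMode is not allowed"
        print(err)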
@@ -96,24 +97,22 @@ class Inference:
     def generate_incremental_2(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, str]:
         generated_tokens = input_ids
-        # past_key_values = DynamicCache()
-        past_key_values = StaticCache(config=self.model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)
+        past_key_values = DynamicCache()
         # n = 0
         try:
             while True:
                 outputs = self.model.generate(
                     generated_tokens, # **inputs, inputs["input_ids"]
-                    max_new_tokens=10, # like streaming
+                    max_new_tokens=5, # like streaming
                     pad_token_id=self.tokenizer.pad_token_id,
                     eos_token_id=self.tokenizer.eos_token_id,
                     do_sample=True,
                     num_return_sequences=1,
                     num_beams = 1,
-                    use_cache=True,
+                    # use_cache=True,
                     past_key_values=past_key_values
                 )
                 # past_key_values = outputs.past_key_values
                 # Get the next token (the last token from the generated sequence)
                 # next_token = outputs.argmax(dim=-1)[:, -1]
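For context, here is a minimal, self-contained sketch of the pattern generate_incremental_2 relies on after this commit: a DynamicCache handed to repeated model.generate() calls, so each call only has to process the tokens added since the previous call. The model name, prompt, and chunk count are placeholders, not taken from this repository, and a CUDA device is assumed (as in the removed StaticCache line).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model_name = "Qwen/Qwen2.5-0.5B-Instruct"   # placeholder; any causal LM should do
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="cuda")

generated = tokenizer("Hello, my name is", return_tensors="pt").input_ids.to("cuda")
past_key_values = DynamicCache()            # grows with every generated token

with torch.no_grad():                       # no_grad, as in the commit above
    for _ in range(4):                      # four "streaming" chunks of 5 tokens
        generated = model.generate(
            generated,                      # full sequence so far
            max_new_tokens=5,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            past_key_values=past_key_values,  # reused cache keeps this incremental
        )
        print(tokenizer.decode(generated[0, -5:]), flush=True)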
