fix cache
@@ -73,6 +73,7 @@ class Inference:
    def generate(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, str]:
        with torch.inference_mode():
        with torch.no_grad():
            return self.generate_incremental_2(input_ids)

@@ -96,24 +97,22 @@ class Inference:
    def generate_incremental_2(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, str]:
        generated_tokens = input_ids

        # past_key_values = DynamicCache()
        past_key_values = StaticCache(config=self.model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)
        past_key_values = DynamicCache()

        # n = 0
        try:
            while True:
                outputs = self.model.generate(
                    generated_tokens, # **inputs, inputs["input_ids"]
                    max_new_tokens=10, # like streaming
                    max_new_tokens=5, # like streaming
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    do_sample=True,
                    num_return_sequences=1,
                    num_beams=1,
                    use_cache=True,
                    # use_cache=True,
                    past_key_values=past_key_values
                )
                # past_key_values = outputs.past_key_values

                # Get the next token (the last token from the generated sequence)
                # next_token = outputs.argmax(dim=-1)[:, -1]
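
For context on the loop above: recent transformers releases let a Cache object (DynamicCache, StaticCache, ...) be passed to model.generate() via past_key_values, and that object is updated in place, so repeated small generate() calls keep extending one KV cache instead of re-encoding the whole prefix each time. The sketch below shows that pattern in isolation; it is not the repository's Inference class, and the model name ("gpt2"), the fixed chunk count, and the eos-as-pad choice are placeholder assumptions, as is the availability of the top-level DynamicCache import (transformers roughly 4.38 or newer).

    # Minimal sketch of KV-cache re-use across repeated generate() calls.
    # Assumptions (not from the commit): model "gpt2", 4 fixed chunks,
    # eos used as pad token, transformers new enough to export DynamicCache.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

    generated_tokens = tokenizer("The quick brown fox", return_tensors="pt").input_ids

    # Created once and grown in place by every generate() call, so tokens that
    # were already processed are not re-encoded on the next iteration.
    past_key_values = DynamicCache()

    with torch.inference_mode():
        for _ in range(4):  # bounded loop instead of `while True`, just for the sketch
            generated_tokens = model.generate(
                generated_tokens,                 # full sequence so far
                max_new_tokens=5,                 # small chunks, "like streaming"
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                use_cache=True,
                past_key_values=past_key_values,  # same cache object every call
            )
            print(tokenizer.decode(generated_tokens[0], skip_special_tokens=True))

On the two cache lines being swapped in the hunk: DynamicCache grows as tokens are generated, while StaticCache pre-allocates max_cache_len slots up front (the torch.compile-friendly option), so the change reads as trading the fixed-size buffer for the simpler default.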