fix cache
inference.py (11 lines changed)
@@ -73,7 +73,8 @@ class Inference:
 
     def generate(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, str]:
         with torch.inference_mode():
-            return self.generate_incremental_2(input_ids)
+            with torch.no_grad():
+                return self.generate_incremental_2(input_ids)
 
 
     def generate_batch(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, str]:
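A note on this hunk: torch.no_grad() nested inside torch.inference_mode() does not change behaviour, because inference mode already disables gradient tracking. The snippet below is a small standalone sketch (not part of this repository) that shows the overlap:

import torch

x = torch.ones(3)

with torch.inference_mode():
    # inference_mode already disables gradient tracking ...
    assert not torch.is_grad_enabled()
    assert torch.is_inference_mode_enabled()

    with torch.no_grad():
        # ... so the nested no_grad() is harmless but redundant
        assert not torch.is_grad_enabled()
        y = x * 2

# tensors created under inference_mode never require grad
assert y.is_inference() and not y.requires_grad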
@@ -96,24 +97,22 @@ class Inference:
     def generate_incremental_2(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, str]:
         generated_tokens = input_ids
 
-        # past_key_values = DynamicCache()
-        past_key_values = StaticCache(config=self.model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)
+        past_key_values = DynamicCache()
 
         # n = 0
         try:
             while True:
                 outputs = self.model.generate(
                     generated_tokens,  # **inputs, inputs["input_ids"]
-                    max_new_tokens=10,  # like streaming
+                    max_new_tokens=5,  # like streaming
                     pad_token_id=self.tokenizer.pad_token_id,
                     eos_token_id=self.tokenizer.eos_token_id,
                     do_sample=True,
                     num_return_sequences=1,
                     num_beams=1,
-                    use_cache=True,
+                    # use_cache=True,
                     past_key_values=past_key_values
                 )
-                # past_key_values = outputs.past_key_values
 
                 # Get the next token (the last token from the generated sequence)
                 # next_token = outputs.argmax(dim=-1)[:, -1]
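This hunk drops the pre-allocated StaticCache (fixed max_cache_len=1024 on CUDA) in favour of a DynamicCache that grows with the sequence, shrinks the per-call chunk from 10 to 5 new tokens, and relies on passing past_key_values into generate() rather than setting use_cache explicitly or reassigning the cache from the outputs. Below is a rough, self-contained sketch of the incremental pattern generate_incremental_2 appears to follow; the function signature, stopping condition, and length limit are assumptions for illustration, not the committed code:

import torch
from transformers import DynamicCache

def generate_incremental(model, tokenizer, input_ids: torch.Tensor,
                         chunk_size: int = 5, max_len: int = 1024) -> tuple[torch.Tensor, str]:
    """Generate a few tokens at a time while reusing one DynamicCache across generate() calls."""
    generated_tokens = input_ids
    past_key_values = DynamicCache()  # grows as needed, unlike a fixed-size StaticCache

    while generated_tokens.shape[1] < max_len:
        outputs = model.generate(
            generated_tokens,
            max_new_tokens=chunk_size,        # small chunks, "like streaming"
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=True,
            past_key_values=past_key_values,  # reused and extended in place on each call
        )
        new_tokens = outputs[:, generated_tokens.shape[1]:]
        generated_tokens = outputs
        # assumed stopping rule: break once the model emits EOS
        if (new_tokens == tokenizer.eos_token_id).any():
            break

    text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
    return generated_tokens, text

The usual trade-off: DynamicCache needs no upfront size and simply appends new key/value states each step, while StaticCache pre-allocates fixed-shape tensors, which mainly pays off when compiling the forward pass.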