gpt2 tests
This commit is contained in:
24
gpt2.py
Normal file
24
gpt2.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import cProfile
|
||||
import pstats
|
||||
from transformers import pipeline
|
||||
import time
|
||||
|
||||
import torch
|
||||
import os

# Use every available core for CPU inference instead of a hard-coded count;
# fall back to the original value of 24 if the core count can't be detected.
torch.set_num_threads(os.cpu_count() or 24)
|
||||
|
||||
# Initialize a GPT-2 text-generation pipeline, explicitly pinned to the CPU
# (device_map="cpu"); downloads/loads the 'gpt2' weights on first use.
generator = pipeline('text-generation', model='gpt2', device_map="cpu")  # gpt2
|
||||
|
||||
def run_inference(prompt="below is a simple python function to extract email addresses from a string:",
                  max_length=500, num_return_sequences=1):
    """Run one text-generation pass, print the output and elapsed time.

    The defaults reproduce the original hard-coded behavior, so the
    zero-argument call used by ``cProfile.run('run_inference()')`` is
    unchanged.

    Args:
        prompt: Seed text handed to the generation pipeline.
        max_length: Maximum token length of each generated sequence.
        num_return_sequences: Number of candidate generations to request.

    Returns:
        The generated text of the first returned sequence (also printed).
    """
    t_start = time.time()
    # Generate text
    generated = generator(prompt, max_length=max_length,
                          num_return_sequences=num_return_sequences)
    # Print (and return) the first candidate's text
    text = generated[0]['generated_text']
    print(text)
    print("took %.3fs" % (time.time() - t_start))
    return text
|
||||
|
||||
# Profile a single inference pass and write raw stats to a file.
cProfile.run('run_inference()', 'profile_output.prof')

# Load the profile and show the top 30 functions by cumulative time.
# (The original comment claimed "top 10" while the code prints 30.)
p = pstats.Stats('profile_output.prof')
p.sort_stats('cumulative').print_stats(30)
|
Reference in New Issue
Block a user