import torch
from transformers import GPTNeoForCausalLM, GPT2Tokenizer

# Interactive text generation with GPT-Neo-2.7B: read a prompt from stdin,
# sample a continuation, and print it.

# Select GPU when available; fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f'Using device {device}')

model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B")
# fp16 halves GPU memory, but many CPU ops have no half-precision kernels,
# so only cast when actually running on CUDA.
if device != "cpu":
    model = model.half()
model = model.to(device)
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")

# Use a distinct name so we don't shadow the builtin input() (the original
# rebound it, which would break any later call to input()).
prompt = input("prompt: ")
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Sampling hyperparameters (UPPER_SNAKE_CASE for module-level constants).
NUM_NEW_TOKENS = 500       # tokens to generate beyond the prompt
TEMPERATURE = 0.95         # softmax temperature; <1 sharpens, >1 flattens
TOP_P = 0.9                # nucleus-sampling cumulative-probability cutoff
REPETITION_PENALTY = 3.0   # >1 penalizes tokens already generated

gen_tokens = model.generate(
    input_ids,
    do_sample=True,
    temperature=TEMPERATURE,
    repetition_penalty=REPETITION_PENALTY,
    top_p=TOP_P,
    max_new_tokens=NUM_NEW_TOKENS,
    # GPT-2/Neo has no pad token; reusing EOS silences the generate() warning.
    pad_token_id=tokenizer.eos_token_id,
)

# Single prompt in, so the batch has exactly one decoded sequence.
print(tokenizer.batch_decode(gen_tokens)[0])