From 30cce3842ea5fc17d6ffaed435df708a1b5f5a24 Mon Sep 17 00:00:00 2001
From: wea_ondara
Date: Thu, 18 Apr 2024 15:51:28 +0200
Subject: [PATCH] also clear cuda cache on error

---
 chat_qwen.py | 34 ++++++++++++++++++----------------
 1 file changed, 18 insertions(+), 16 deletions(-)

diff --git a/chat_qwen.py b/chat_qwen.py
index 6cbdae5..d691275 100644
--- a/chat_qwen.py
+++ b/chat_qwen.py
@@ -26,25 +26,27 @@ class ChatQwen:
         print('Loaded')
 
     def generate(self, messages):
-        # prepare
-        messages = [m for m in messages if m['role'] != 'system']
-        input_messages = [self.default_instruction] + messages
+        try:
+            # prepare
+            messages = [m for m in messages if m['role'] != 'system']
+            input_messages = [self.default_instruction] + messages
 
-        # generate
-        text = self.tokenizer.apply_chat_template(input_messages, tokenize=False, add_generation_prompt=True)
-        model_inputs = self.tokenizer([text], return_tensors='pt').to(self.default_device)
-        generated_ids = self.model.generate(model_inputs.input_ids, max_new_tokens=100)
-        generated_ids = [
-            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
-        ]
-        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+            # generate
+            text = self.tokenizer.apply_chat_template(input_messages, tokenize=False, add_generation_prompt=True)
+            model_inputs = self.tokenizer([text], return_tensors='pt').to(self.default_device)
+            generated_ids = self.model.generate(model_inputs.input_ids, max_new_tokens=300)
+            generated_ids = [
+                output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+            ]
+            response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
-        # add response and save conversation
-        messages.append({'role': 'assistant', 'content': response})
-        self.record_conversation(input_messages, {'role': 'assistant', 'content': response})
+            # add response and save conversation
+            messages.append({'role': 'assistant', 'content': response})
+            self.record_conversation(input_messages, {'role': 'assistant', 'content': response})
 
-        torch.cuda.empty_cache()  # clear cache or the gpu mem will be used a lot
-        return messages
+            return messages
+        finally:
+            torch.cuda.empty_cache()  # clear cache or the gpu mem will be used a lot
 
     def record_conversation(self, messages, response):
         messages = messages + [response]