diff --git a/train/prepare/helpsteer/helpsteer22jsonl.py b/train/prepare/helpsteer/helpsteer22jsonl.py index 248154b..3653d44 100644 --- a/train/prepare/helpsteer/helpsteer22jsonl.py +++ b/train/prepare/helpsteer/helpsteer22jsonl.py @@ -1,7 +1,9 @@ import json import os +import random -this_dir = os.path.dirname(os.path.abspath(__file__)) +user_names = ['Adam', 'Alice', 'Anne', 'Bob', 'Charlie', 'Cody', 'Corinna', 'Cynthia', 'Fred', 'Grace', 'Jane', 'Paul', + 'Rachel', 'Ramesh'] def mkdir(path): @@ -9,6 +11,7 @@ def mkdir(path): os.mkdir(path) +this_dir = os.path.dirname(os.path.abspath(__file__)) mkdir(this_dir + '/../../data') mkdir(this_dir + '/../../data/helpsteer') @@ -18,8 +21,10 @@ for filename in ['train.jsonl', 'validation.jsonl']: role_dict = {'prompt': 'user', 'response': 'assistant'} lines = [json.loads(line) for line in lines] - conversations = [{'messages': [{'role': 'user', 'content': line['prompt']}, - {'role': 'assistant', 'content': line['response']}]} for line in lines] + conversations = [{'messages': [ + {'role': 'user', 'name': user_names[random.randint(0, len(user_names) - 1)], 'content': line['prompt']}, + {'role': 'assistant', 'name': 'assistant', 'content': line['response']}] + } for line in lines] print(conversations[0]) diff --git a/train/prepare/oasst2/oasst22jsonl.py b/train/prepare/oasst2/oasst22jsonl.py index 6fde55b..f22748b 100644 --- a/train/prepare/oasst2/oasst22jsonl.py +++ b/train/prepare/oasst2/oasst22jsonl.py @@ -1,5 +1,7 @@ import json import os +import random +from typing import AnyStr # parsing OA data files with oasst_data helpers from oasst_data import read_message_trees, ExportMessageNode @@ -10,23 +12,28 @@ this_dir = os.path.dirname(os.path.abspath(__file__)) input_file_path = this_dir + '/2023-11-05_oasst2_all.trees.jsonl.gz' role_dict = {'prompter': 'user', 'assistant': 'assistant'} +user_names = ['Adam', 'Alice', 'Anne', 'Bob', 'Charlie', 'Cody', 'Corinna', 'Cynthia', 'Fred', 'Grace', 'Jane', 'Paul', + 'Rachel', 'Ramesh'] conversations = [] -def visit(node: ExportMessageNode, parents: [ExportMessageNode]): +def visit(node: ExportMessageNode, parents: [ExportMessageNode], user: AnyStr): new_parents = parents + [node] if not node.replies: # end of conversation - conversations.append({'messages': [{'role': role_dict[p.role], 'content': p.text} for p in new_parents]}) + conversations.append({'messages': [{'role': role_dict[p.role], + 'name': user if role_dict[p.role] != 'assistant' else 'assistant', + 'content': p.text + } for p in new_parents]}) else: for reply in node.replies: - visit(reply, new_parents) + visit(reply, new_parents, user) for tree in read_message_trees(input_file_path): if tree.prompt.lang not in ['en']: # filtering by language tag (optional) continue - visit(tree.prompt, []) + visit(tree.prompt, [], user_names[random.randint(0, len(user_names) - 1)]) print(conversations[0])