initial commit
This commit is contained in:
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
61
utils/conversation.py
Normal file
61
utils/conversation.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
|
||||
|
||||
def load_conversation(model_id):
    """Return the most recently saved pickled conversation for *model_id*.

    Conversations are looked up in ``conversations/<model_id>`` (with every
    ``/`` in *model_id* replaced by ``_``), relative to the current working
    directory. The folder is created when missing. Among the regular files
    ending in ``.pickle`` the lexicographically greatest name is loaded.

    Returns:
        The unpickled object, or ``[]`` when the folder holds no pickle file.
    """
    folder = 'conversations/' + model_id.replace('/', '_')
    # Create the folder that is actually listed below, parents included.
    os.makedirs(folder, exist_ok=True)

    candidates = [
        name for name in os.listdir(folder)
        if name.endswith('.pickle') and os.path.isfile(folder + '/' + name)
    ]
    if not candidates:
        return []

    pickle_filename = folder + '/' + max(candidates)
    print('Loading last conversation from ' + pickle_filename)
    # NOTE: unpickling can execute arbitrary code; only files this project
    # wrote itself should ever be placed in the conversations folder.
    with open(pickle_filename, 'rb') as file:
        return pickle.load(file)
|
||||
|
||||
|
||||
def save_conversation(model_id, messages):
    """Pickle *messages* into a new timestamped file for *model_id*.

    The file is written to ``conversations/<model_id>/<timestamp>.pickle``
    (every ``/`` in *model_id* replaced by ``_``), relative to the current
    working directory. The timestamp is the current UTC time formatted as
    ``%Y%m%d%H%M%S``, so file names sort chronologically. Two saves within
    the same second share a name and the later one overwrites the earlier.
    """
    folder = 'conversations/' + model_id.replace('/', '_')
    # Create the folder that is actually written to, parents included.
    os.makedirs(folder, exist_ok=True)

    # Aware UTC "now"; formats exactly like the deprecated utcnow().
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    timestamp = now_utc.strftime('%Y%m%d%H%M%S')

    pickle_filename = folder + '/' + timestamp + '.pickle'
    with open(pickle_filename, 'wb') as file:
        pickle.dump(messages, file)
|
||||
|
||||
|
||||
def load_conversation_json(model_id):
    """Return the most recently saved JSON conversation for *model_id*.

    Conversations are looked up in ``conversations/<model_id>`` (with every
    ``/`` in *model_id* replaced by ``_``), relative to the current working
    directory. The folder is created when missing. Among the regular files
    ending in ``.json`` the lexicographically greatest name is loaded.

    Returns:
        The decoded JSON value, or ``[]`` when the folder holds no JSON file.
    """
    folder = 'conversations/' + model_id.replace('/', '_')
    # Create the folder that is actually listed below, parents included.
    os.makedirs(folder, exist_ok=True)

    candidates = [
        name for name in os.listdir(folder)
        if name.endswith('.json') and os.path.isfile(folder + '/' + name)
    ]
    if not candidates:
        return []

    json_filename = folder + '/' + max(candidates)
    print('Loading last conversation from ' + json_filename)
    with open(json_filename, 'r', encoding='utf-8') as file:
        return json.load(file)
|
||||
|
||||
|
||||
def save_conversation_json(model_id, messages):
    """Write *messages* as JSON into a new timestamped file for *model_id*.

    The file is written to ``conversations/<model_id>/<timestamp>.json``
    (every ``/`` in *model_id* replaced by ``_``), relative to the current
    working directory. The timestamp is the current UTC time formatted as
    ``%Y%m%d%H%M%S``, so file names sort chronologically. Two saves within
    the same second share a name and the later one overwrites the earlier.

    Raises:
        TypeError: if *messages* contains a value ``json`` cannot serialize.
    """
    folder = 'conversations/' + model_id.replace('/', '_')
    # Create the folder that is actually written to, parents included.
    os.makedirs(folder, exist_ok=True)

    # Aware UTC "now"; formats exactly like the deprecated utcnow().
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    timestamp = now_utc.strftime('%Y%m%d%H%M%S')

    json_filename = folder + '/' + timestamp + '.json'
    with open(json_filename, 'w', encoding='utf-8') as file:
        json.dump(messages, file)
|
||||
|
||||
|
||||
def mkdir(path):
    """Create directory *path*, including missing parents, if it is absent.

    Calling it for a directory that already exists is a no-op.

    Raises:
        FileExistsError: if *path* exists but is not a directory.
    """
    # exist_ok avoids the check-then-create race of an isdir() test.
    os.makedirs(path, exist_ok=True)
|
||||
12
utils/download_dataset.py
Normal file
12
utils/download_dataset.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import pickle
|
||||
from conversation import mkdir
|
||||
from datasets import load_dataset
|
||||
|
||||
# Identifier handed to datasets.load_dataset().
dataset_id = 'OpenAssistant/oasst2'

# The directory that is created must be the one the pickle is written into,
# so both paths are derived from the same variable.
datasets_dir = './datasets'
mkdir(datasets_dir)
pickle_filename = datasets_dir + '/' + dataset_id.replace('/', '_') + '.pickle'

dataset = load_dataset(dataset_id)
with open(pickle_filename, 'wb') as file:
    pickle.dump(dataset, file)
print('Saved as pickle to ' + pickle_filename)
|
||||
18
utils/download_model.py
Normal file
18
utils/download_model.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer
|
||||
|
||||
from conversation import mkdir
|
||||
|
||||
# Identifier passed to from_pretrained(); swap in the commented line below
# to fetch the larger variant instead.
model_id = 'Qwen/Qwen1.5-0.5B-Chat'
# model_id = 'Qwen/Qwen1.5-1.8B-Chat'

# Target directory: ./models/<model_id with '/' replaced by '_'>
save_dir = './models/' + model_id.replace('/', '_')

print('Downloading ' + model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype='auto',
    device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
print('Downloaded')

# The Trainer is constructed only so that its save_model() can be used to
# write the loaded model to disk; no training happens in this script.
# NOTE(review): presumably save_model() also stores the tokenizer - confirm
# against the transformers documentation.
trainer = Trainer(model=model, tokenizer=tokenizer)
mkdir('models')
trainer.save_model(save_dir)
|
||||
3
utils/fix_cuda.sh
Executable file
3
utils/fix_cuda.sh
Executable file
@@ -0,0 +1,3 @@
|
||||
#!/bin/sh
# Reload the nvidia_uvm kernel module (unload, then load again) as a
# workaround for a CUDA initialization error. Requires sudo.
# Background: https://medium.com/@Spritan/dealing-with-cuda-initialization-error-aa7c88d021e4
sudo rmmod nvidia_uvm
sudo modprobe nvidia_uvm
|
||||
22
utils/pickle2json.py
Normal file
22
utils/pickle2json.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import json
|
||||
import pickle
|
||||
import sys
|
||||
|
||||
def convert_pickle_to_json(pickle_filename):
    """Convert one ``.pickle`` file into a sibling ``.json`` file.

    The unpickled object is printed, then serialized to JSON and written next
    to the input with the ``.pickle`` suffix replaced by ``.json``.

    Returns:
        True when the JSON file was written, False when serialization or
        writing failed (the error is printed and no partial file is left by
        a serialization failure).
    """
    # NOTE: unpickling can execute arbitrary code; only run this on pickle
    # files you created yourself.
    with open(pickle_filename, 'rb') as file:
        obj = pickle.load(file)
    print(obj)

    json_filename = pickle_filename[:-len('.pickle')] + '.json'
    try:
        # Serialize completely in memory first: dumping straight into the
        # open file would leave a truncated .json behind when the object
        # turns out not to be JSON-serializable.
        text = json.dumps(obj)
        with open(json_filename, 'w') as file:
            file.write(text)
    except (TypeError, ValueError, RecursionError, OSError) as e:
        print(e)
        return False
    return True


def main(argv):
    """Convert every ``.pickle`` path in *argv*; other paths are skipped."""
    files = list(argv)
    print(files)

    for pickle_filename in files:
        if not pickle_filename.endswith('.pickle'):
            print(pickle_filename + ' is not a pickle. ignoring')
            continue
        convert_pickle_to_json(pickle_filename)


if __name__ == '__main__':
    main(sys.argv[1:])
|
||||
14
utils/prompt.py
Normal file
14
utils/prompt.py
Normal file
@@ -0,0 +1,14 @@
|
||||
def prompt(prompt):
    """Show *prompt*, read one line from standard input and return it.

    End-of-file (e.g. Ctrl-D) and Ctrl-C both print a newline and terminate
    the program with exit status 0. Input that cannot be decoded is reported
    and the user is asked again.

    Raises:
        SystemExit: with code 0 on EOF or keyboard interrupt.
    """
    while True:
        try:
            return input(prompt)
        except (EOFError, KeyboardInterrupt):
            # Finish the interrupted prompt line before leaving.
            print()
            # exit() is a helper injected by the site module for interactive
            # sessions; raising SystemExit works in every interpreter setup.
            raise SystemExit(0) from None
        # in case: UnicodeDecodeError: 'utf-8' codec can't decode byte 0xc3 in position 11: invalid continuation byte
        except UnicodeDecodeError as e:
            print(e)
            print('prompt ignored')
|
||||
25
utils/split_shuffle_dataset.py
Normal file
25
utils/split_shuffle_dataset.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
def split_shuffle(original, no_dataset):
    """Shuffle the lines of *original* and split them into *no_dataset* files.

    The parts are written next to *original*. If the file name (without
    ``.jsonl``) contains ``_all``, that marker is replaced by ``_<index>``;
    otherwise ``_<index>`` is appended, so a part never overwrites the input
    file or another part.

    Returns:
        The list of written file paths, in part order.
    """
    out_dir = os.path.dirname(os.path.abspath(original))
    stem = os.path.basename(original)[:-len('.jsonl')]

    with open(original, 'r') as f:
        lines = f.readlines()

    # A final line without a newline would be glued to whichever line
    # follows it after shuffling, corrupting two records.
    if lines and not lines[-1].endswith('\n'):
        lines[-1] += '\n'

    random.shuffle(lines)

    written = []
    for i in range(no_dataset):
        # Integer arithmetic keeps the slice bounds exact for any size.
        lower = i * len(lines) // no_dataset
        upper = (i + 1) * len(lines) // no_dataset
        if '_all' in stem:
            part_stem = stem.replace('_all', '_' + str(i))
        else:
            part_stem = stem + '_' + str(i)
        out_path = out_dir + '/' + part_stem + '.jsonl'
        with open(out_path, 'w') as f:
            f.writelines(lines[lower:upper])
        written.append(out_path)
    return written


def main(argv):
    """Command-line entry point: ``<file.jsonl> <number of parts>``.

    Returns the process exit status (0 on success, 1 on bad arguments).
    """
    if len(argv) < 2:
        print('Usage: split_shuffle_dataset.py <file.jsonl> <number of parts>')
        return 1

    original = argv[0]
    no_dataset = int(argv[1])

    if not original.endswith('.jsonl') or not os.path.isfile(original):
        print('Not a jsonl file')
        return 1

    split_shuffle(original, no_dataset)
    return 0


if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))
|
||||
Reference in New Issue
Block a user