From d8c3e48eb51b21b8e27e8183cbafbde0954a8c13 Mon Sep 17 00:00:00 2001 From: cboulanger <info@bibliograph.org> Date: Mon, 4 Mar 2024 17:50:48 +0100 Subject: [PATCH] Updates --- mlx/lora/.gitignore | 1 + mlx/lora/finetune-experiments.ipynb | 295 ++++++++++++++++--- mlx/lora/finetune-reference-extraction.ipynb | 65 ++++ mlx/lora/lib/reference_extraction.py | 184 ++++++++++++ 4 files changed, 501 insertions(+), 44 deletions(-) create mode 100644 mlx/lora/finetune-reference-extraction.ipynb create mode 100644 mlx/lora/lib/reference_extraction.py diff --git a/mlx/lora/.gitignore b/mlx/lora/.gitignore index b859ba7..a9edd12 100644 --- a/mlx/lora/.gitignore +++ b/mlx/lora/.gitignore @@ -1 +1,2 @@ mlx_models +*.npz \ No newline at end of file diff --git a/mlx/lora/finetune-experiments.ipynb b/mlx/lora/finetune-experiments.ipynb index 7d7ab0e..00872e0 100644 --- a/mlx/lora/finetune-experiments.ipynb +++ b/mlx/lora/finetune-experiments.ipynb @@ -126,8 +126,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-01T08:54:54.448616Z", - "start_time": "2024-03-01T08:54:54.441906Z" + "end_time": "2024-03-02T14:33:04.365287Z", + "start_time": "2024-03-02T14:33:04.358090Z" } }, "id": "b4be7c0872d2fd34" @@ -223,25 +223,25 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 20, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Length of generated sequences:\n", - " - max: 5550\n", - " - avg: 2259.182608695652\n", + " - max: 5446\n", + " - avg: 2202.035087719298\n", "Longest sequences:\n", - "DivRuW: 5550\n", - "JurBüro: 5051\n", - "AVR: 4366\n", - "APR: 4350\n", - "AusR: 4244\n", + "FoR: 5446\n", + "DÖD: 4559\n", + "GLJ: 4153\n", "BKK: 4078\n", - "DÖD: 3818\n", - "EuZW: 3786\n", + "AcP: 3960\n", + "AuA: 3656\n", "HRN: 3467\n", + "DSB: 3433\n", + "DivRuW: 3360\n", "AuAS: 3272\n" ] } @@ -277,7 +277,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "start_time": "2024-03-01T08:54:56.039305Z" + "end_time": "2024-03-01T18:09:22.948780Z", + "start_time": "2024-03-01T18:09:21.794696Z" } }, "id": "31a2389404720256" @@ -406,7 +407,7 @@ { "cell_type": "markdown", "source": [ - "### Test the model with adapter" + "### Testing" ], "metadata": { "collapsed": false @@ -646,7 +647,9 @@ "source": [ "## mlx-community/quantized-gemma-7b-it\n", "\n", - "This model can be directly downloaded from HF, no conversion necessary" + "This model can be directly downloaded from HF, no conversion necessary\n", + "\n", + "based on https://gist.github.com/alexweberk/635431b5c5773efd6d1755801020429f" ], "metadata": { "collapsed": false @@ -744,9 +747,9 @@ "\n", "os.environ['TOKENIZERS_PARALLELISM'] = 'false'\n", "prompt = f\"\"\"\n", - "#### instructions\n", + "#### Instructions\n", "{system_message}\n", - "### user\n", + "### User\n", "{instruction}\n", "{example}\n", "{epilog}\n", @@ -773,9 +776,8 @@ { "cell_type": "markdown", "source": [ - "### Generate training, testing and validation files\n", - "\n", - "based on https://gist.github.com/alexweberk/635431b5c5773efd6d1755801020429f" + "### Generate dataset\n", + "\n" ], "metadata": { "collapsed": false @@ -784,47 +786,47 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 21, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Length of generated sequences:\n", - " - max: 5107\n", - " - avg: 1976.0964912280701\n", + " - max: 4816\n", + " - avg: 2273.424778761062\n", "Longest sequences:\n", - "FoR: 5107\n", - "DivRuW: 5097\n", - "AfP: 4519\n", - "StAZ: 4418\n", - "DÖD: 
4220\n", - "ECFR: 3519\n", - "APR: 3445\n", - "CB: 3387\n", - "AuA: 3317\n", - "HRN: 3128\n" + "JurBüro: 4816\n", + "AusR: 4697\n", + "AcP: 4010\n", + "AW-Prax: 3880\n", + "DÖD: 3870\n", + "DivRuW: 3867\n", + "AuA: 3706\n", + "StAZ: 3601\n", + "ANA-ZAR: 3361\n", + "AuAS: 3322\n" ] } ], "source": [ "from lib.prepare_training_data import create_training_file\n", "\n", - "prompt = f\"\"\"\n", - "# instructions\n", + "gemma_instruction = f\"\"\"\n", + "# Instructions\n", "{system_message}\n", - "# user\n", + "# User\n", "{instruction}\n", "{epilog}'\n", "\"\"\".strip()\n", "\n", - "def template_fn(prompt: str, answer: str):\n", - " return f'<bos><start_of_turn>user\\n{prompt}<end_of_turn>\\n<start_of_turn>model\\n{answer}<end_of_turn><eos>'\n", + "def template_fn(instruction: str, content: str, answer: str):\n", + " return f'<bos><start_of_turn>user\\n{instruction}\\n\\n{content}<end_of_turn>\\n<start_of_turn>model\\n{answer}<end_of_turn><eos>'\n", "\n", - "create_training_file(instruction=instruction,\n", + "create_training_file(instruction=gemma_instruction,\n", " template_func=template_fn,\n", " input_file='data/editors/editors.csv',\n", - " output_dir='data/editors-gemma',\n", + " output_dir='data/editors/gemma',\n", " content_dir='data/editors/website-data',\n", " max_chars=6000, max_gt_items=5,\n", " record_identifier_col=\"journal_abbr\",\n", @@ -835,21 +837,226 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-01T08:11:37.452870Z", - "start_time": "2024-03-01T08:11:32.919339Z" + "end_time": "2024-03-01T18:10:08.910617Z", + "start_time": "2024-03-01T18:10:07.794209Z" } }, "id": "8d61e8cf63aa5965" }, + { + "cell_type": "markdown", + "source": [ + "### Finetuning" + ], + "metadata": { + "collapsed": false + }, + "id": "55ac22c3a4e1305e" + }, + { + "cell_type": "code", + "execution_count": 4, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading pretrained model\r\n", + "Fetching 8 files: 100%|██████████████████████████| 8/8 [00:00<00:00, 389.29it/s]\r\n", + "Total parameters 1998.171M\r\n", + "Trainable parameters 0.459M\r\n", + "Loading datasets\r\n", + "Training\r\n", + "Starting training..., iters: 600\r\n", + "Iter 1: Val loss 6.064, Val took 144.925s\r\n", + "Iter 10: Train loss 5.106, Learning Rate 1.000e-05, It/sec 0.085, Tokens/sec 60.820, Trained Tokens 7127\r\n", + "Iter 20: Train loss 4.203, Learning Rate 1.000e-05, It/sec 0.063, Tokens/sec 49.259, Trained Tokens 14895\r\n", + "Iter 30: Train loss 3.695, Learning Rate 1.000e-05, It/sec 0.086, Tokens/sec 51.924, Trained Tokens 20904\r\n", + "Iter 40: Train loss 3.638, Learning Rate 1.000e-05, It/sec 0.125, Tokens/sec 60.230, Trained Tokens 25739\r\n", + "Iter 50: Train loss 3.091, Learning Rate 1.000e-05, It/sec 0.077, Tokens/sec 51.838, Trained Tokens 32501\r\n", + "Iter 60: Train loss 2.691, Learning Rate 1.000e-05, It/sec 0.086, Tokens/sec 56.203, Trained Tokens 39053\r\n", + "Iter 70: Train loss 2.525, Learning Rate 1.000e-05, It/sec 0.061, Tokens/sec 45.969, Trained Tokens 46546\r\n", + "Iter 80: Train loss 2.313, Learning Rate 1.000e-05, It/sec 0.096, Tokens/sec 55.637, Trained Tokens 52342\r\n", + "Iter 90: Train loss 1.986, Learning Rate 1.000e-05, It/sec 0.077, Tokens/sec 51.038, Trained Tokens 58944\r\n", + "Iter 100: Train loss 1.932, Learning Rate 1.000e-05, It/sec 0.088, Tokens/sec 56.251, Trained Tokens 65370\r\n", + "Iter 100: Saved adapter weights to checkpoints/100_editors.npz.\r\n", + "Iter 110: Train loss 1.745, Learning Rate 1.000e-05, It/sec 
0.073, Tokens/sec 49.590, Trained Tokens 72183\r\n", + "Iter 120: Train loss 1.531, Learning Rate 1.000e-05, It/sec 0.108, Tokens/sec 57.670, Trained Tokens 77506\r\n", + "Iter 130: Train loss 1.817, Learning Rate 1.000e-05, It/sec 0.090, Tokens/sec 56.202, Trained Tokens 83737\r\n", + "Iter 140: Train loss 1.358, Learning Rate 1.000e-05, It/sec 0.099, Tokens/sec 59.363, Trained Tokens 89733\r\n", + "Iter 150: Train loss 1.517, Learning Rate 1.000e-05, It/sec 0.073, Tokens/sec 49.477, Trained Tokens 96513\r\n", + "Iter 160: Train loss 1.524, Learning Rate 1.000e-05, It/sec 0.077, Tokens/sec 49.295, Trained Tokens 102887\r\n", + "Iter 170: Train loss 1.316, Learning Rate 1.000e-05, It/sec 0.071, Tokens/sec 50.801, Trained Tokens 109993\r\n", + "Iter 180: Train loss 1.440, Learning Rate 1.000e-05, It/sec 0.092, Tokens/sec 55.542, Trained Tokens 116060\r\n", + "Iter 190: Train loss 1.402, Learning Rate 1.000e-05, It/sec 0.077, Tokens/sec 53.860, Trained Tokens 123028\r\n", + "Iter 200: Train loss 1.578, Learning Rate 1.000e-05, It/sec 0.069, Tokens/sec 53.572, Trained Tokens 130758\r\n", + "Iter 200: Val loss 1.449, Val took 136.326s\r\n", + "Iter 200: Saved adapter weights to checkpoints/200_editors.npz.\r\n", + "Iter 210: Train loss 1.403, Learning Rate 1.000e-05, It/sec 0.068, Tokens/sec 48.232, Trained Tokens 137872\r\n", + "Iter 220: Train loss 1.596, Learning Rate 1.000e-05, It/sec 0.079, Tokens/sec 55.729, Trained Tokens 144886\r\n", + "Iter 230: Train loss 1.324, Learning Rate 1.000e-05, It/sec 0.097, Tokens/sec 57.306, Trained Tokens 150820\r\n", + "Iter 240: Train loss 1.308, Learning Rate 1.000e-05, It/sec 0.090, Tokens/sec 55.030, Trained Tokens 156927\r\n", + "Iter 250: Train loss 1.394, Learning Rate 1.000e-05, It/sec 0.087, Tokens/sec 58.270, Trained Tokens 163638\r\n", + "Iter 260: Train loss 1.424, Learning Rate 1.000e-05, It/sec 0.085, Tokens/sec 57.969, Trained Tokens 170473\r\n", + "Iter 270: Train loss 1.299, Learning Rate 1.000e-05, It/sec 0.081, Tokens/sec 42.263, Trained Tokens 175687\r\n", + "Iter 280: Train loss 1.489, Learning Rate 1.000e-05, It/sec 0.055, Tokens/sec 45.052, Trained Tokens 183919\r\n", + "Iter 290: Train loss 1.248, Learning Rate 1.000e-05, It/sec 0.081, Tokens/sec 53.325, Trained Tokens 190540\r\n", + "Iter 300: Train loss 1.277, Learning Rate 1.000e-05, It/sec 0.077, Tokens/sec 52.184, Trained Tokens 197354\r\n", + "Iter 300: Saved adapter weights to checkpoints/300_editors.npz.\r\n", + "Iter 310: Train loss 1.346, Learning Rate 1.000e-05, It/sec 0.081, Tokens/sec 57.205, Trained Tokens 204429\r\n", + "Iter 320: Train loss 1.311, Learning Rate 1.000e-05, It/sec 0.087, Tokens/sec 51.852, Trained Tokens 210397\r\n", + "Iter 330: Train loss 1.398, Learning Rate 1.000e-05, It/sec 0.085, Tokens/sec 54.920, Trained Tokens 216829\r\n", + "Iter 340: Train loss 1.370, Learning Rate 1.000e-05, It/sec 0.091, Tokens/sec 57.985, Trained Tokens 223184\r\n", + "Iter 350: Train loss 1.099, Learning Rate 1.000e-05, It/sec 0.088, Tokens/sec 55.375, Trained Tokens 229472\r\n", + "Iter 360: Train loss 1.325, Learning Rate 1.000e-05, It/sec 0.061, Tokens/sec 44.221, Trained Tokens 236775\r\n", + "Iter 370: Train loss 1.140, Learning Rate 1.000e-05, It/sec 0.095, Tokens/sec 56.221, Trained Tokens 242704\r\n", + "Iter 380: Train loss 1.195, Learning Rate 1.000e-05, It/sec 0.091, Tokens/sec 53.397, Trained Tokens 248601\r\n", + "Iter 390: Train loss 1.655, Learning Rate 1.000e-05, It/sec 0.057, Tokens/sec 49.824, Trained Tokens 257341\r\n", + "Iter 400: Train loss 1.144, 
Learning Rate 1.000e-05, It/sec 0.112, Tokens/sec 59.800, Trained Tokens 262675\r\n", + "Iter 400: Val loss 1.406, Val took 141.987s\r\n", + "Iter 400: Saved adapter weights to checkpoints/400_editors.npz.\r\n", + "Iter 410: Train loss 1.220, Learning Rate 1.000e-05, It/sec 0.090, Tokens/sec 55.037, Trained Tokens 268819\r\n", + "Iter 420: Train loss 1.454, Learning Rate 1.000e-05, It/sec 0.073, Tokens/sec 51.010, Trained Tokens 275828\r\n", + "Iter 430: Train loss 1.195, Learning Rate 1.000e-05, It/sec 0.081, Tokens/sec 55.343, Trained Tokens 282670\r\n", + "Iter 440: Train loss 1.251, Learning Rate 1.000e-05, It/sec 0.078, Tokens/sec 49.881, Trained Tokens 289097\r\n", + "Iter 450: Train loss 1.176, Learning Rate 1.000e-05, It/sec 0.080, Tokens/sec 50.116, Trained Tokens 295391\r\n", + "Iter 460: Train loss 1.249, Learning Rate 1.000e-05, It/sec 0.070, Tokens/sec 48.412, Trained Tokens 302328\r\n", + "Iter 470: Train loss 1.134, Learning Rate 1.000e-05, It/sec 0.105, Tokens/sec 59.645, Trained Tokens 308020\r\n", + "Iter 480: Train loss 1.295, Learning Rate 1.000e-05, It/sec 0.075, Tokens/sec 52.581, Trained Tokens 315053\r\n", + "Iter 490: Train loss 1.182, Learning Rate 1.000e-05, It/sec 0.093, Tokens/sec 58.488, Trained Tokens 321354\r\n", + "Iter 500: Train loss 1.344, Learning Rate 1.000e-05, It/sec 0.082, Tokens/sec 52.577, Trained Tokens 327745\r\n", + "Iter 500: Saved adapter weights to checkpoints/500_editors.npz.\r\n", + "Iter 510: Train loss 1.183, Learning Rate 1.000e-05, It/sec 0.073, Tokens/sec 52.934, Trained Tokens 335029\r\n", + "Iter 520: Train loss 1.449, Learning Rate 1.000e-05, It/sec 0.067, Tokens/sec 51.610, Trained Tokens 342699\r\n", + "Iter 530: Train loss 1.332, Learning Rate 1.000e-05, It/sec 0.078, Tokens/sec 50.195, Trained Tokens 349121\r\n", + "Iter 540: Train loss 1.293, Learning Rate 1.000e-05, It/sec 0.091, Tokens/sec 54.443, Trained Tokens 355075\r\n", + "Iter 550: Train loss 1.444, Learning Rate 1.000e-05, It/sec 0.063, Tokens/sec 49.679, Trained Tokens 363018\r\n", + "Iter 560: Train loss 1.197, Learning Rate 1.000e-05, It/sec 0.073, Tokens/sec 50.547, Trained Tokens 369947\r\n", + "Iter 570: Train loss 1.257, Learning Rate 1.000e-05, It/sec 0.057, Tokens/sec 44.008, Trained Tokens 377730\r\n", + "Iter 580: Train loss 1.160, Learning Rate 1.000e-05, It/sec 0.075, Tokens/sec 50.617, Trained Tokens 384486\r\n", + "Iter 590: Train loss 1.214, Learning Rate 1.000e-05, It/sec 0.102, Tokens/sec 55.789, Trained Tokens 389956\r\n", + "Iter 600: Train loss 0.879, Learning Rate 1.000e-05, It/sec 0.137, Tokens/sec 61.512, Trained Tokens 394433\r\n", + "Iter 600: Val loss 1.367, Val took 137.833s\r\n", + "Iter 600: Saved adapter weights to checkpoints/600_editors.npz.\r\n", + "Saved final adapter weights to editors.npz.\r\n" + ] + } + ], + "source": [ + "from mlx_lm.utils import get_model_path\n", + "import os\n", + "os.environ['MODEL_NAME'] = model_name = 'mlx-community/quantized-gemma-7b-it'\n", + "\n", + " \n", + "!python -m mlx_lm.lora \\\n", + " --model \"$MODEL_NAME\" \\\n", + " --adapter-file \"editors.npz\" \\\n", + " --train \\\n", + " --iters 600 --batch-size 1 --lora-layers 4 \\\n", + " --data data/editors/gemma" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-03-02T19:40:14.555110Z", + "start_time": "2024-03-02T15:14:52.368721Z" + } + }, + "id": "a9da1ee7b6cc7997" + }, + { + "cell_type": "markdown", + "source": [ + "Iter 600: Val loss 1.367, Val took 137.833s" + ], + "metadata": { + "collapsed": false + }, + "id": 
"7d5ea07d81750fed" + }, + { + "cell_type": "markdown", + "source": [ + "### Testing" + ], + "metadata": { + "collapsed": false + }, + "id": "57f675d7cdfd2965" + }, + { + "cell_type": "code", + "execution_count": 7, + "outputs": [ + { + "data": { + "text/plain": "Fetching 8 files: 0%| | 0/8 [00:00<?, ?it/s]", + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "eceea48ae37a4b75be5f13369be5ff9f" + } + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading pretrained model\r\n", + "Fetching 8 files: 100%|████████████████████████| 8/8 [00:00<00:00, 42799.02it/s]\r\n", + "Total parameters 1999.547M\r\n", + "Trainable parameters 1.835M\r\n", + "Loading datasets\r\n", + "Testing\r\n", + "Test loss 1.395, Test ppl 4.035.\r\n" + ] + } + ], + "source": [ + "os.environ['TOKENIZERS_PARALLELISM'] = 'false'\n", + "os.environ['MODEL_PATH'] = str(get_model_path(model_name))\n", + "!python -m mlx_lm.lora \\\n", + " --model \"$MODEL_NAME\" \\\n", + " --adapter-file \"editors.npz\" \\\n", + " --data data/editors/gemma \\\n", + " --test" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-03-03T10:30:03.579158Z", + "start_time": "2024-03-03T10:21:41.778749Z" + } + }, + "id": "de978facc2c3c978" + }, + { + "cell_type": "markdown", + "source": [ + "Test loss 1.395, Test ppl 4.035." + ], + "metadata": { + "collapsed": false + }, + "id": "a39b638de1705a79" + }, { "cell_type": "code", "execution_count": null, "outputs": [], - "source": [], + "source": [ + "# Load the fine-tuned model with LoRA weights\n", + "model_lora, _ = load(\n", + " \"mlx-community/quantized-gemma-7b-it\",\n", + " adapter_file=\"./editors.npz\", # adapters.npz is the final checkpoint saved at the end of training\n", + ")" + ], "metadata": { "collapsed": false }, - "id": "db51ef32ff18dff3" + "id": "7a0a937fa126c7ad" } ], "metadata": { diff --git a/mlx/lora/finetune-reference-extraction.ipynb b/mlx/lora/finetune-reference-extraction.ipynb new file mode 100644 index 0000000..af40444 --- /dev/null +++ b/mlx/lora/finetune-reference-extraction.ipynb @@ -0,0 +1,65 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "initial_id", + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# Reference extraction with Google Gemma\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "outputs": [], + "source": [ + "from lib.reference_extraction import load_xmls_from_directory\n", + "\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-03-02T15:26:03.329751Z", + "start_time": "2024-03-02T15:25:58.279349Z" + } + }, + "id": "6e7ba19d6cf1d66b" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + }, + "id": "92cd78ba5d9944c4" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/mlx/lora/lib/reference_extraction.py b/mlx/lora/lib/reference_extraction.py new file mode 100644 index 0000000..aa5073a --- /dev/null +++ b/mlx/lora/lib/reference_extraction.py @@ 
-0,0 +1,184 @@
+# generated by GPT-4
+import re
+import pandas as pd
+import os
+import yaml
+import json
+from sklearn.model_selection import train_test_split
+import numpy as np
+import textwrap
+from lxml import etree
+
+
+def xml_to_dataframe(xml_path):
+    # Parse the XML file
+    tree = etree.parse(xml_path)
+    root = tree.getroot()
+
+    # Initialize an empty list to store row data
+    rows = []
+
+    # Iterate through each sequence in the XML
+    for sequence in root.findall('sequence'):
+        # Initialize an empty dictionary for each sequence
+        row_data = {}
+
+        # Iterate through each child element in the sequence
+        for element in sequence:
+            # Use tag name as column name and text as the value in the row_data dictionary
+            row_data[element.tag] = element.text
+
+        # Append the dictionary to the rows list
+        rows.append(row_data)
+
+    # Convert the list of dictionaries to a DataFrame
+    df = pd.DataFrame(rows)
+
+    # Return the DataFrame
+    return df
+
+import os
+
+def load_xmls_from_directory(directory_path):
+    # Initialize an empty DataFrame to store all data
+    all_data = pd.DataFrame()
+
+    # Iterate over every file in the directory
+    for filename in os.listdir(directory_path):
+        # Construct the full file path
+        file_path = os.path.join(directory_path, filename)
+
+        # Check if the file is an XML file
+        if os.path.isfile(file_path) and file_path.endswith('.xml'):
+            # Load the XML file into a DataFrame
+            df = xml_to_dataframe(file_path)
+
+            # Append the data to the all_data DataFrame
+            all_data = pd.concat([all_data, df], ignore_index=True)
+
+    # Return the combined DataFrame
+    return all_data
+
+
+
+def clean_values(row):
+    """Return a new dictionary with no NaN, None, or empty/whitespace-only values."""
+    return {k: v for k, v in row.items() if pd.notna(v) and v is not None and v.strip()}
+
+def create_training_files(instruction: str,
+                          template_func: callable,
+                          input_file: str, output_dir: str, content_dir: str,
+                          cols_to_remove: list, column_to_filter_by: str,
+                          record_identifier_col: str,
+                          max_chars=2048 * 4, max_gt_items=10,
+                          line_width=120, lines_before=1, lines_after=1, random_seed=None,
+                          debug=True):
+
+
+    # Load the input
+    df = pd.read_csv(input_file)
+
+    # Set or generate a random seed
+    if random_seed is None:
+        random_seed = np.random.randint(0, 10000)
+
+    # Prepare the output directory
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Keep track of the longest prompts
+    id_value_pairs = {}
+
+    def process_and_write_data(grouped_df, file, write_debug_file=False):
+        for id, group in grouped_df.groupby(record_identifier_col):
+
+            # use only the first 10 ground truth items so that the training record does not become too large
+            group = group.head(max_gt_items)
+
+            # load website content from the cache
+            filename = f"{content_dir}/{re.sub(r'[. ]', '', str(id))}.txt"
+            try:
+                with open(filename, 'r', encoding='utf-8') as content_file:
+                    content = content_file.read()
+            except FileNotFoundError:
+                continue
+
+            # Clean up data
+            answer_df = group.drop(cols_to_remove, axis=1, errors='ignore').dropna(axis=1, how='all')
+
+            # Convert DataFrame rows to YAML dictionaries, cleaning NaN values
+            cleaned_rows = answer_df.apply(lambda row: clean_values(row.to_dict()), axis=1)
+            answer = yaml.dump(list(cleaned_rows), allow_unicode=True, sort_keys=False)
+
+            # wrap content to decrease token window
+            lines = [line for line in wrap_content_generator(content, width=line_width)]
+            wrapped_content = '\n'.join(lines)
+
+            # Determine how many characters are available for the content
+            max_chars_for_content = max_chars - len(instruction) - len(answer)
+
+            # Determine which keywords to filter by
+            keyword_list = [x for x in group[column_to_filter_by].tolist() if pd.notnull(x)]
+
+            # Filter rows, using keywords and a context window
+            filtered_content = filter_content_with_context(wrapped_content,
+                                                           keywords=keyword_list,
+                                                           lines_before=lines_before, lines_after=lines_after,
+                                                           max_chars=max_chars_for_content)
+
+            # Ignore if nothing was found
+            if filtered_content.strip() == '':
+                continue
+
+            if write_debug_file:
+                sequence = [
+                    '### JOURNAL', id,
+                    '### URL', group['website'].tolist()[0],
+                    '### CONTENT', filtered_content,
+                    '### ANSWER', answer,
+                    '-' * line_width,
+                    ''
+                ]
+                file.write('\n\n'.join(sequence))
+            else:
+                sequence = template_func(instruction, filtered_content, answer)
+                train_json = {
+                    "id": id,
+                    "text": sequence
+                }
+                file.write(json.dumps(train_json) + '\n')
+                id_value_pairs[id] = len(sequence)
+
+    # Split the DataFrame into training (80%) and test+validation (20%)
+    train_df, test_valid_df = train_test_split(df, test_size=0.2, random_state=random_seed)
+
+    # Further split the test+validation into test and validation (50% each of the 20%)
+    test_df, valid_df = train_test_split(test_valid_df, test_size=0.5, random_state=random_seed)
+
+    # Debug file for better readability of the result
+    if debug:
+        debug_instruction = instruction.replace('\n', '\n\n')
+        debug_instruction = "\n".join(textwrap.wrap(debug_instruction, width=line_width))
+        with open(f'{output_dir}/debug.txt', 'w', encoding='utf-8') as debug_file:
+            debug_file.write(f'### INSTRUCTION\n\n{debug_instruction}\n\n{"=" * line_width}\n\n')
+            process_and_write_data(df, debug_file, write_debug_file=True)
+
+    # Write train, test and validation files
+    with open(f'{output_dir}/train.jsonl', 'w', encoding='utf-8') as train_file, \
+         open(f'{output_dir}/test.jsonl', 'w', encoding='utf-8') as test_file, \
+         open(f'{output_dir}/valid.jsonl', 'w', encoding='utf-8') as valid_file:
+
+        # Process and write data to each file
+        process_and_write_data(train_df, train_file)
+        process_and_write_data(test_df, test_file)
+        process_and_write_data(valid_df, valid_file)
+
+    if debug:
+        print("Length of generated sequences:")
+        print(f" - max: {max(id_value_pairs.values())}")
+        print(f" - avg: {sum(id_value_pairs.values()) / len(id_value_pairs)}")
+
+        sorted_pairs = sorted(id_value_pairs.items(), key=lambda x: x[1], reverse=True)
+        highest_10_values = sorted_pairs[:10]
+        print("Longest sequences:")
+        for _id, value in highest_10_values:
+            print(f"{_id}: {value}")
--
GitLab
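
Note on the new lib/reference_extraction.py: create_training_files() calls two helpers, wrap_content_generator() and filter_content_with_context(), that are not defined in this file (they presumably come from lib/prepare_training_data.py, which the notebooks already import from). The sketch below is only an assumption about their behaviour, inferred from the call sites above; the real implementations may differ.

import textwrap

def wrap_content_generator(content: str, width: int = 120):
    # Yield the scraped website text re-wrapped to a fixed line width,
    # paragraph by paragraph, skipping empty lines.
    for paragraph in content.splitlines():
        if not paragraph.strip():
            continue
        for line in textwrap.wrap(paragraph, width=width):
            yield line

def filter_content_with_context(wrapped_content: str, keywords: list,
                                lines_before: int = 1, lines_after: int = 1,
                                max_chars: int = 4000) -> str:
    # Keep only lines that mention one of the keywords, plus a small window of
    # surrounding lines, and stop once the character budget is exhausted.
    lines = wrapped_content.split('\n')
    lowered = [kw.lower() for kw in keywords if isinstance(kw, str)]
    keep = set()
    for i, line in enumerate(lines):
        if any(kw in line.lower() for kw in lowered):
            keep.update(range(max(0, i - lines_before),
                              min(len(lines), i + lines_after + 1)))
    selected, total = [], 0
    for i in sorted(keep):
        if total + len(lines[i]) + 1 > max_chars:
            break
        selected.append(lines[i])
        total += len(lines[i]) + 1
    return '\n'.join(selected)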
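
The new finetune-reference-extraction.ipynb so far only imports load_xmls_from_directory(). A hypothetical next cell could look like the following; the directory and output paths are placeholders, and the CSV hand-off to create_training_files() is an assumption about the intended pipeline, not something the patch specifies.

from lib.reference_extraction import load_xmls_from_directory

# Collect all annotated <sequence> records from a directory of XML files
# into a single DataFrame (path is a placeholder).
refs_df = load_xmls_from_directory("data/references/xml")
print(refs_df.shape, list(refs_df.columns))

# Persist as CSV so it could serve as the input_file of create_training_files().
refs_df.to_csv("data/references/references.csv", index=False)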
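
Finally, the last cell of finetune-experiments.ipynb loads quantized Gemma together with the editors.npz adapter but does not yet generate anything. A minimal sketch of how the fine-tuned model could be queried, assuming the mlx_lm version used in this patch exposes load()/generate() as in the earlier cells; the prompt text, placeholder website content and token budget are illustrative, not part of the patch.

from mlx_lm import load, generate

# Load the base model plus the LoRA adapter written by `mlx_lm.lora --train`.
model, tokenizer = load(
    "mlx-community/quantized-gemma-7b-it",
    adapter_file="editors.npz",
)

website_text = "..."  # content prepared the same way as in create_training_file

# Mirror the <start_of_turn> template used by template_fn during training;
# the instruction wording is a placeholder.
prompt = (
    "<start_of_turn>user\n"
    "# Instructions\n<system message goes here>\n"
    "# User\n<instruction and epilog go here>\n\n"
    f"{website_text}<end_of_turn>\n"
    "<start_of_turn>model\n"
)

answer = generate(model, tokenizer, prompt=prompt, temp=0.0, max_tokens=400)
print(answer)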