From d8c3e48eb51b21b8e27e8183cbafbde0954a8c13 Mon Sep 17 00:00:00 2001
From: cboulanger <info@bibliograph.org>
Date: Mon, 4 Mar 2024 17:50:48 +0100
Subject: [PATCH] Updates

---
 mlx/lora/.gitignore                          |   1 +
 mlx/lora/finetune-experiments.ipynb          | 295 ++++++++++++++++---
 mlx/lora/finetune-reference-extraction.ipynb |  65 ++++
 mlx/lora/lib/reference_extraction.py         | 184 ++++++++++++
 4 files changed, 501 insertions(+), 44 deletions(-)
 create mode 100644 mlx/lora/finetune-reference-extraction.ipynb
 create mode 100644 mlx/lora/lib/reference_extraction.py

diff --git a/mlx/lora/.gitignore b/mlx/lora/.gitignore
index b859ba7..a9edd12 100644
--- a/mlx/lora/.gitignore
+++ b/mlx/lora/.gitignore
@@ -1 +1,2 @@
 mlx_models
+*.npz
\ No newline at end of file
diff --git a/mlx/lora/finetune-experiments.ipynb b/mlx/lora/finetune-experiments.ipynb
index 7d7ab0e..00872e0 100644
--- a/mlx/lora/finetune-experiments.ipynb
+++ b/mlx/lora/finetune-experiments.ipynb
@@ -126,8 +126,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2024-03-01T08:54:54.448616Z",
-     "start_time": "2024-03-01T08:54:54.441906Z"
+     "end_time": "2024-03-02T14:33:04.365287Z",
+     "start_time": "2024-03-02T14:33:04.358090Z"
     }
    },
    "id": "b4be7c0872d2fd34"
@@ -223,25 +223,25 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 20,
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Length of generated sequences:\n",
-      " - max: 5550\n",
-      " - avg: 2259.182608695652\n",
+      " - max: 5446\n",
+      " - avg: 2202.035087719298\n",
       "Longest sequences:\n",
-      "DivRuW: 5550\n",
-      "JurBüro: 5051\n",
-      "AVR: 4366\n",
-      "APR: 4350\n",
-      "AusR: 4244\n",
+      "FoR: 5446\n",
+      "DÖD: 4559\n",
+      "GLJ: 4153\n",
       "BKK: 4078\n",
-      "DÖD: 3818\n",
-      "EuZW: 3786\n",
+      "AcP: 3960\n",
+      "AuA: 3656\n",
       "HRN: 3467\n",
+      "DSB: 3433\n",
+      "DivRuW: 3360\n",
       "AuAS: 3272\n"
      ]
     }
@@ -277,7 +277,8 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "start_time": "2024-03-01T08:54:56.039305Z"
+     "end_time": "2024-03-01T18:09:22.948780Z",
+     "start_time": "2024-03-01T18:09:21.794696Z"
     }
    },
    "id": "31a2389404720256"
@@ -406,7 +407,7 @@
   {
    "cell_type": "markdown",
    "source": [
-    "### Test the model with adapter"
+    "### Testing"
    ],
    "metadata": {
     "collapsed": false
@@ -646,7 +647,9 @@
    "source": [
     "## mlx-community/quantized-gemma-7b-it\n",
     "\n",
-    "This model can be directly downloaded from HF, no conversion necessary"
+    "This model can be downloaded directly from HF; no conversion is necessary.\n",
+    "\n",
+    "Based on https://gist.github.com/alexweberk/635431b5c5773efd6d1755801020429f"
    ],
    "metadata": {
     "collapsed": false
@@ -744,9 +747,9 @@
     "\n",
     "os.environ['TOKENIZERS_PARALLELISM'] = 'false'\n",
     "prompt = f\"\"\"\n",
-    "#### instructions\n",
+    "#### Instructions\n",
     "{system_message}\n",
-    "### user\n",
+    "### User\n",
     "{instruction}\n",
     "{example}\n",
     "{epilog}\n",
@@ -773,9 +776,8 @@
   {
    "cell_type": "markdown",
    "source": [
-    "### Generate training, testing and validation files\n",
-    "\n",
-    "based on https://gist.github.com/alexweberk/635431b5c5773efd6d1755801020429f"
+    "### Generate dataset\n",
+    "\n"
    ],
    "metadata": {
     "collapsed": false
@@ -784,47 +786,47 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 21,
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Length of generated sequences:\n",
-      " - max: 5107\n",
-      " - avg: 1976.0964912280701\n",
+      " - max: 4816\n",
+      " - avg: 2273.424778761062\n",
       "Longest sequences:\n",
-      "FoR: 5107\n",
-      "DivRuW: 5097\n",
-      "AfP: 4519\n",
-      "StAZ: 4418\n",
-      "DÖD: 4220\n",
-      "ECFR: 3519\n",
-      "APR: 3445\n",
-      "CB: 3387\n",
-      "AuA: 3317\n",
-      "HRN: 3128\n"
+      "JurBüro: 4816\n",
+      "AusR: 4697\n",
+      "AcP: 4010\n",
+      "AW-Prax: 3880\n",
+      "DÖD: 3870\n",
+      "DivRuW: 3867\n",
+      "AuA: 3706\n",
+      "StAZ: 3601\n",
+      "ANA-ZAR: 3361\n",
+      "AuAS: 3322\n"
      ]
     }
    ],
    "source": [
     "from lib.prepare_training_data import create_training_file\n",
     "\n",
-    "prompt = f\"\"\"\n",
-    "# instructions\n",
+    "gemma_instruction = f\"\"\"\n",
+    "# Instructions\n",
     "{system_message}\n",
-    "# user\n",
+    "# User\n",
     "{instruction}\n",
     "{epilog}\n",
     "\"\"\".strip()\n",
     "\n",
-    "def template_fn(prompt: str, answer: str):\n",
-    "    return f'<bos><start_of_turn>user\\n{prompt}<end_of_turn>\\n<start_of_turn>model\\n{answer}<end_of_turn><eos>'\n",
+    "def template_fn(instruction: str, content: str, answer: str):\n",
+    "    return f'<bos><start_of_turn>user\\n{instruction}\\n\\n{content}<end_of_turn>\\n<start_of_turn>model\\n{answer}<end_of_turn><eos>'\n",
     "\n",
-    "create_training_file(instruction=instruction,\n",
+    "create_training_file(instruction=gemma_instruction,\n",
     "                     template_func=template_fn,\n",
     "                     input_file='data/editors/editors.csv',\n",
-    "                     output_dir='data/editors-gemma',\n",
+    "                     output_dir='data/editors/gemma',\n",
     "                     content_dir='data/editors/website-data',\n",
     "                     max_chars=6000, max_gt_items=5,\n",
     "                     record_identifier_col=\"journal_abbr\",\n",
@@ -835,21 +837,226 @@
    "metadata": {
     "collapsed": false,
     "ExecuteTime": {
-     "end_time": "2024-03-01T08:11:37.452870Z",
-     "start_time": "2024-03-01T08:11:32.919339Z"
+     "end_time": "2024-03-01T18:10:08.910617Z",
+     "start_time": "2024-03-01T18:10:07.794209Z"
     }
    },
    "id": "8d61e8cf63aa5965"
   },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "### Finetuning"
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "55ac22c3a4e1305e"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Loading pretrained model\r\n",
+      "Fetching 8 files: 100%|██████████████████████████| 8/8 [00:00<00:00, 389.29it/s]\r\n",
+      "Total parameters 1998.171M\r\n",
+      "Trainable parameters 0.459M\r\n",
+      "Loading datasets\r\n",
+      "Training\r\n",
+      "Starting training..., iters: 600\r\n",
+      "Iter 1: Val loss 6.064, Val took 144.925s\r\n",
+      "Iter 10: Train loss 5.106, Learning Rate 1.000e-05, It/sec 0.085, Tokens/sec 60.820, Trained Tokens 7127\r\n",
+      "Iter 20: Train loss 4.203, Learning Rate 1.000e-05, It/sec 0.063, Tokens/sec 49.259, Trained Tokens 14895\r\n",
+      "Iter 30: Train loss 3.695, Learning Rate 1.000e-05, It/sec 0.086, Tokens/sec 51.924, Trained Tokens 20904\r\n",
+      "Iter 40: Train loss 3.638, Learning Rate 1.000e-05, It/sec 0.125, Tokens/sec 60.230, Trained Tokens 25739\r\n",
+      "Iter 50: Train loss 3.091, Learning Rate 1.000e-05, It/sec 0.077, Tokens/sec 51.838, Trained Tokens 32501\r\n",
+      "Iter 60: Train loss 2.691, Learning Rate 1.000e-05, It/sec 0.086, Tokens/sec 56.203, Trained Tokens 39053\r\n",
+      "Iter 70: Train loss 2.525, Learning Rate 1.000e-05, It/sec 0.061, Tokens/sec 45.969, Trained Tokens 46546\r\n",
+      "Iter 80: Train loss 2.313, Learning Rate 1.000e-05, It/sec 0.096, Tokens/sec 55.637, Trained Tokens 52342\r\n",
+      "Iter 90: Train loss 1.986, Learning Rate 1.000e-05, It/sec 0.077, Tokens/sec 51.038, Trained Tokens 58944\r\n",
+      "Iter 100: Train loss 1.932, Learning Rate 1.000e-05, It/sec 0.088, Tokens/sec 56.251, Trained Tokens 65370\r\n",
+      "Iter 100: Saved adapter weights to checkpoints/100_editors.npz.\r\n",
+      "Iter 110: Train loss 1.745, Learning Rate 1.000e-05, It/sec 0.073, Tokens/sec 49.590, Trained Tokens 72183\r\n",
+      "Iter 120: Train loss 1.531, Learning Rate 1.000e-05, It/sec 0.108, Tokens/sec 57.670, Trained Tokens 77506\r\n",
+      "Iter 130: Train loss 1.817, Learning Rate 1.000e-05, It/sec 0.090, Tokens/sec 56.202, Trained Tokens 83737\r\n",
+      "Iter 140: Train loss 1.358, Learning Rate 1.000e-05, It/sec 0.099, Tokens/sec 59.363, Trained Tokens 89733\r\n",
+      "Iter 150: Train loss 1.517, Learning Rate 1.000e-05, It/sec 0.073, Tokens/sec 49.477, Trained Tokens 96513\r\n",
+      "Iter 160: Train loss 1.524, Learning Rate 1.000e-05, It/sec 0.077, Tokens/sec 49.295, Trained Tokens 102887\r\n",
+      "Iter 170: Train loss 1.316, Learning Rate 1.000e-05, It/sec 0.071, Tokens/sec 50.801, Trained Tokens 109993\r\n",
+      "Iter 180: Train loss 1.440, Learning Rate 1.000e-05, It/sec 0.092, Tokens/sec 55.542, Trained Tokens 116060\r\n",
+      "Iter 190: Train loss 1.402, Learning Rate 1.000e-05, It/sec 0.077, Tokens/sec 53.860, Trained Tokens 123028\r\n",
+      "Iter 200: Train loss 1.578, Learning Rate 1.000e-05, It/sec 0.069, Tokens/sec 53.572, Trained Tokens 130758\r\n",
+      "Iter 200: Val loss 1.449, Val took 136.326s\r\n",
+      "Iter 200: Saved adapter weights to checkpoints/200_editors.npz.\r\n",
+      "Iter 210: Train loss 1.403, Learning Rate 1.000e-05, It/sec 0.068, Tokens/sec 48.232, Trained Tokens 137872\r\n",
+      "Iter 220: Train loss 1.596, Learning Rate 1.000e-05, It/sec 0.079, Tokens/sec 55.729, Trained Tokens 144886\r\n",
+      "Iter 230: Train loss 1.324, Learning Rate 1.000e-05, It/sec 0.097, Tokens/sec 57.306, Trained Tokens 150820\r\n",
+      "Iter 240: Train loss 1.308, Learning Rate 1.000e-05, It/sec 0.090, Tokens/sec 55.030, Trained Tokens 156927\r\n",
+      "Iter 250: Train loss 1.394, Learning Rate 1.000e-05, It/sec 0.087, Tokens/sec 58.270, Trained Tokens 163638\r\n",
+      "Iter 260: Train loss 1.424, Learning Rate 1.000e-05, It/sec 0.085, Tokens/sec 57.969, Trained Tokens 170473\r\n",
+      "Iter 270: Train loss 1.299, Learning Rate 1.000e-05, It/sec 0.081, Tokens/sec 42.263, Trained Tokens 175687\r\n",
+      "Iter 280: Train loss 1.489, Learning Rate 1.000e-05, It/sec 0.055, Tokens/sec 45.052, Trained Tokens 183919\r\n",
+      "Iter 290: Train loss 1.248, Learning Rate 1.000e-05, It/sec 0.081, Tokens/sec 53.325, Trained Tokens 190540\r\n",
+      "Iter 300: Train loss 1.277, Learning Rate 1.000e-05, It/sec 0.077, Tokens/sec 52.184, Trained Tokens 197354\r\n",
+      "Iter 300: Saved adapter weights to checkpoints/300_editors.npz.\r\n",
+      "Iter 310: Train loss 1.346, Learning Rate 1.000e-05, It/sec 0.081, Tokens/sec 57.205, Trained Tokens 204429\r\n",
+      "Iter 320: Train loss 1.311, Learning Rate 1.000e-05, It/sec 0.087, Tokens/sec 51.852, Trained Tokens 210397\r\n",
+      "Iter 330: Train loss 1.398, Learning Rate 1.000e-05, It/sec 0.085, Tokens/sec 54.920, Trained Tokens 216829\r\n",
+      "Iter 340: Train loss 1.370, Learning Rate 1.000e-05, It/sec 0.091, Tokens/sec 57.985, Trained Tokens 223184\r\n",
+      "Iter 350: Train loss 1.099, Learning Rate 1.000e-05, It/sec 0.088, Tokens/sec 55.375, Trained Tokens 229472\r\n",
+      "Iter 360: Train loss 1.325, Learning Rate 1.000e-05, It/sec 0.061, Tokens/sec 44.221, Trained Tokens 236775\r\n",
+      "Iter 370: Train loss 1.140, Learning Rate 1.000e-05, It/sec 0.095, Tokens/sec 56.221, Trained Tokens 242704\r\n",
+      "Iter 380: Train loss 1.195, Learning Rate 1.000e-05, It/sec 0.091, Tokens/sec 53.397, Trained Tokens 248601\r\n",
+      "Iter 390: Train loss 1.655, Learning Rate 1.000e-05, It/sec 0.057, Tokens/sec 49.824, Trained Tokens 257341\r\n",
+      "Iter 400: Train loss 1.144, Learning Rate 1.000e-05, It/sec 0.112, Tokens/sec 59.800, Trained Tokens 262675\r\n",
+      "Iter 400: Val loss 1.406, Val took 141.987s\r\n",
+      "Iter 400: Saved adapter weights to checkpoints/400_editors.npz.\r\n",
+      "Iter 410: Train loss 1.220, Learning Rate 1.000e-05, It/sec 0.090, Tokens/sec 55.037, Trained Tokens 268819\r\n",
+      "Iter 420: Train loss 1.454, Learning Rate 1.000e-05, It/sec 0.073, Tokens/sec 51.010, Trained Tokens 275828\r\n",
+      "Iter 430: Train loss 1.195, Learning Rate 1.000e-05, It/sec 0.081, Tokens/sec 55.343, Trained Tokens 282670\r\n",
+      "Iter 440: Train loss 1.251, Learning Rate 1.000e-05, It/sec 0.078, Tokens/sec 49.881, Trained Tokens 289097\r\n",
+      "Iter 450: Train loss 1.176, Learning Rate 1.000e-05, It/sec 0.080, Tokens/sec 50.116, Trained Tokens 295391\r\n",
+      "Iter 460: Train loss 1.249, Learning Rate 1.000e-05, It/sec 0.070, Tokens/sec 48.412, Trained Tokens 302328\r\n",
+      "Iter 470: Train loss 1.134, Learning Rate 1.000e-05, It/sec 0.105, Tokens/sec 59.645, Trained Tokens 308020\r\n",
+      "Iter 480: Train loss 1.295, Learning Rate 1.000e-05, It/sec 0.075, Tokens/sec 52.581, Trained Tokens 315053\r\n",
+      "Iter 490: Train loss 1.182, Learning Rate 1.000e-05, It/sec 0.093, Tokens/sec 58.488, Trained Tokens 321354\r\n",
+      "Iter 500: Train loss 1.344, Learning Rate 1.000e-05, It/sec 0.082, Tokens/sec 52.577, Trained Tokens 327745\r\n",
+      "Iter 500: Saved adapter weights to checkpoints/500_editors.npz.\r\n",
+      "Iter 510: Train loss 1.183, Learning Rate 1.000e-05, It/sec 0.073, Tokens/sec 52.934, Trained Tokens 335029\r\n",
+      "Iter 520: Train loss 1.449, Learning Rate 1.000e-05, It/sec 0.067, Tokens/sec 51.610, Trained Tokens 342699\r\n",
+      "Iter 530: Train loss 1.332, Learning Rate 1.000e-05, It/sec 0.078, Tokens/sec 50.195, Trained Tokens 349121\r\n",
+      "Iter 540: Train loss 1.293, Learning Rate 1.000e-05, It/sec 0.091, Tokens/sec 54.443, Trained Tokens 355075\r\n",
+      "Iter 550: Train loss 1.444, Learning Rate 1.000e-05, It/sec 0.063, Tokens/sec 49.679, Trained Tokens 363018\r\n",
+      "Iter 560: Train loss 1.197, Learning Rate 1.000e-05, It/sec 0.073, Tokens/sec 50.547, Trained Tokens 369947\r\n",
+      "Iter 570: Train loss 1.257, Learning Rate 1.000e-05, It/sec 0.057, Tokens/sec 44.008, Trained Tokens 377730\r\n",
+      "Iter 580: Train loss 1.160, Learning Rate 1.000e-05, It/sec 0.075, Tokens/sec 50.617, Trained Tokens 384486\r\n",
+      "Iter 590: Train loss 1.214, Learning Rate 1.000e-05, It/sec 0.102, Tokens/sec 55.789, Trained Tokens 389956\r\n",
+      "Iter 600: Train loss 0.879, Learning Rate 1.000e-05, It/sec 0.137, Tokens/sec 61.512, Trained Tokens 394433\r\n",
+      "Iter 600: Val loss 1.367, Val took 137.833s\r\n",
+      "Iter 600: Saved adapter weights to checkpoints/600_editors.npz.\r\n",
+      "Saved final adapter weights to editors.npz.\r\n"
+     ]
+    }
+   ],
+   "source": [
+    "from mlx_lm.utils import get_model_path\n",
+    "import os\n",
+    "os.environ['MODEL_NAME'] = model_name = 'mlx-community/quantized-gemma-7b-it'\n",
+    "\n",
+    "!python -m mlx_lm.lora \\\n",
+    "    --model \"$MODEL_NAME\" \\\n",
+    "    --adapter-file \"editors.npz\" \\\n",
+    "    --train \\\n",
+    "    --iters 600 --batch-size 1 --lora-layers 4 \\\n",
+    "    --data data/editors/gemma"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "ExecuteTime": {
+     "end_time": "2024-03-02T19:40:14.555110Z",
+     "start_time": "2024-03-02T15:14:52.368721Z"
+    }
+   },
+   "id": "a9da1ee7b6cc7997"
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "Iter 600: Val loss 1.367, Val took 137.833s"
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "7d5ea07d81750fed"
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "### Testing"
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "57f675d7cdfd2965"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "outputs": [
+    {
+     "data": {
+      "text/plain": "Fetching 8 files:   0%|          | 0/8 [00:00<?, ?it/s]",
+      "application/vnd.jupyter.widget-view+json": {
+       "version_major": 2,
+       "version_minor": 0,
+       "model_id": "eceea48ae37a4b75be5f13369be5ff9f"
+      }
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Loading pretrained model\r\n",
+      "Fetching 8 files: 100%|████████████████████████| 8/8 [00:00<00:00, 42799.02it/s]\r\n",
+      "Total parameters 1999.547M\r\n",
+      "Trainable parameters 1.835M\r\n",
+      "Loading datasets\r\n",
+      "Testing\r\n",
+      "Test loss 1.395, Test ppl 4.035.\r\n"
+     ]
+    }
+   ],
+   "source": [
+    "os.environ['TOKENIZERS_PARALLELISM'] = 'false'\n",
+    "os.environ['MODEL_PATH'] = str(get_model_path(model_name))\n",
+    "!python -m mlx_lm.lora \\\n",
+    "    --model \"$MODEL_NAME\" \\\n",
+    "    --adapter-file \"editors.npz\" \\\n",
+    "    --data data/editors/gemma \\\n",
+    "    --test"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "ExecuteTime": {
+     "end_time": "2024-03-03T10:30:03.579158Z",
+     "start_time": "2024-03-03T10:21:41.778749Z"
+    }
+   },
+   "id": "de978facc2c3c978"
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "Test loss 1.395, Test ppl 4.035."
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "a39b638de1705a79"
+  },
   {
    "cell_type": "code",
    "execution_count": null,
    "outputs": [],
-   "source": [],
+   "source": [
+    "# Load the fine-tuned model with LoRA weights\n",
+    "model_lora, _ = load(\n",
+    "    \"mlx-community/quantized-gemma-7b-it\",\n",
+    "    adapter_file=\"./editors.npz\",  # editors.npz holds the final adapter weights saved at the end of training\n",
+    ")"
+   ],
    "metadata": {
     "collapsed": false
    },
-   "id": "db51ef32ff18dff3"
+   "id": "7a0a937fa126c7ad"
   }
  ],
  "metadata": {
diff --git a/mlx/lora/finetune-reference-extraction.ipynb b/mlx/lora/finetune-reference-extraction.ipynb
new file mode 100644
index 0000000..af40444
--- /dev/null
+++ b/mlx/lora/finetune-reference-extraction.ipynb
@@ -0,0 +1,65 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "initial_id",
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "# Reference extraction with Google Gemma\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "outputs": [],
+   "source": [
+    "from lib.reference_extraction import load_xmls_from_directory\n",
+    "\n"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "ExecuteTime": {
+     "end_time": "2024-03-02T15:26:03.329751Z",
+     "start_time": "2024-03-02T15:25:58.279349Z"
+    }
+   },
+   "id": "6e7ba19d6cf1d66b"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "92cd78ba5d9944c4"
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 2
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython2",
+   "version": "2.7.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/mlx/lora/lib/reference_extraction.py b/mlx/lora/lib/reference_extraction.py
new file mode 100644
index 0000000..aa5073a
--- /dev/null
+++ b/mlx/lora/lib/reference_extraction.py
@@ -0,0 +1,184 @@
+# generated by GPT-4
+import re
+import pandas as pd
+import os
+import yaml
+import json
+from sklearn.model_selection import train_test_split
+import numpy as np
+import textwrap
+from lxml import etree
+
+# wrap_content_generator and filter_content_with_context are used further below but
+# are not defined in this file; they are assumed to live in lib/prepare_training_data.py,
+# the sibling helper module that the notebooks already import from.
+from lib.prepare_training_data import wrap_content_generator, filter_content_with_context
+
+
+def xml_to_dataframe(xml_path):
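+    """Parse an XML file of annotated sequences and return a DataFrame with one
+    row per <sequence> element and one column per child tag."""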
+    # Parse the XML file
+    tree = etree.parse(xml_path)
+    root = tree.getroot()
+
+    # Initialize an empty list to store row data
+    rows = []
+
+    # Iterate through each sequence in the XML
+    for sequence in root.findall('sequence'):
+        # Initialize an empty dictionary for each sequence
+        row_data = {}
+
+        # Iterate through each child element in the sequence
+        for element in sequence:
+            # Use tag name as column name and text as the value in the row_data dictionary
+            row_data[element.tag] = element.text
+
+        # Append the dictionary to the rows list
+        rows.append(row_data)
+
+    # Convert the list of dictionaries to a DataFrame
+    df = pd.DataFrame(rows)
+
+    # Return the DataFrame
+    return df
+
+
+def load_xmls_from_directory(directory_path):
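+    """Load every .xml file in directory_path via xml_to_dataframe and return
+    the rows of all files combined into a single DataFrame."""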
+    # Initialize an empty DataFrame to store all data
+    all_data = pd.DataFrame()
+
+    # Iterate over every file in the directory
+    for filename in os.listdir(directory_path):
+        # Construct the full file path
+        file_path = os.path.join(directory_path, filename)
+
+        # Check if the file is an XML file
+        if os.path.isfile(file_path) and file_path.endswith('.xml'):
+            # Load the XML file into a DataFrame
+            df = xml_to_dataframe(file_path)
+
+            # Append the data to the all_data DataFrame
+            all_data = pd.concat([all_data, df], ignore_index=True)
+
+    # Return the combined DataFrame
+    return all_data
+
+
+
+def clean_values(row):
+    """Return a new dictionary with no NaN, None, or empty/whitespace-only values."""
+    # pd.notna() already rejects None; str() guards .strip() against non-string values.
+    return {k: v for k, v in row.items() if pd.notna(v) and str(v).strip()}
+
+def create_training_files(instruction: str,
+                         template_func: callable,
+                         input_file: str, output_dir: str, content_dir: str,
+                         cols_to_remove: list, column_to_filter_by: str,
+                         record_identifier_col: str,
+                         max_chars=2048 * 4, max_gt_items=10,
+                         line_width=120, lines_before=1, lines_after=1, random_seed=None,
+                         debug=True):
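+    """Build JSONL training, test and validation files from the records in input_file.
+
+    Each record is paired with its cached website content from content_dir, filtered
+    by the keywords in column_to_filter_by, truncated to max_chars, and rendered
+    through template_func together with the instruction and a YAML answer block.
+    The split files are written to output_dir; when debug is True, a human-readable
+    debug.txt is also written and sequence-length statistics are printed."""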
+
+
+    # Load the input
+    df = pd.read_csv(input_file)
+
+    # Set or generate a random seed
+    if random_seed is None:
+        random_seed = np.random.randint(0, 10000)
+
+    # Prepare the output directory
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Keep track of the longest prompts
+    id_value_pairs = {}
+
+    def process_and_write_data(grouped_df, file, write_debug_file=False):
+        for id, group in grouped_df.groupby(record_identifier_col):
+
+            # use only the first max_gt_items ground truth items so that the training record does not become too large
+            group = group.head(max_gt_items)
+
+            # load website content from the cache
+            filename = f"{content_dir}/{re.sub(r'[. ]', '', str(id))}.txt"
+            try:
+                with open(filename, 'r', encoding='utf-8') as content_file:
+                    content = content_file.read()
+            except FileNotFoundError:
+                continue
+
+            # Clean up data
+            answer_df = group.drop(cols_to_remove, axis=1, errors='ignore').dropna(axis=1, how='all')
+
+            # Convert DataFrame rows to YAML dictionaries, cleaning NaN values
+            cleaned_rows = answer_df.apply(lambda row: clean_values(row.to_dict()), axis=1)
+            answer = yaml.dump(list(cleaned_rows), allow_unicode=True, sort_keys=False)
+
+            # wrap content to decrease token window
+            lines = [line for line in wrap_content_generator(content, width=line_width)]
+            wrapped_content = '\n'.join(lines)
+
+            # Determine how many characters are available for the content
+            max_chars_for_content = max_chars - len(instruction) - len(answer)
+
+            # Determine which keywords to filter by
+            keyword_list = [x for x in group[column_to_filter_by].tolist() if pd.notnull(x)]
+
+            # Filter rows, using keywords and a context window
+            filtered_content = filter_content_with_context(wrapped_content,
+                                                           keywords=keyword_list,
+                                                           lines_before=lines_before, lines_after=lines_after,
+                                                           max_chars=max_chars_for_content)
+
+            # Ignore if nothing was found
+            if filtered_content.strip() == '':
+                continue
+
+            if write_debug_file:
+                sequence = [
+                    '### JOURNAL', id,
+                    '### URL', group['website'].tolist()[0],
+                    '### CONTENT', filtered_content,
+                    '### ANSWER', answer,
+                    '-' * line_width,
+                    ''
+                ]
+                file.write('\n\n'.join(sequence))
+            else:
+                sequence = template_func(instruction, filtered_content, answer)
+                train_json = {
+                    "id": id,
+                    "text": sequence
+                }
+                file.write(json.dumps(train_json) + '\n')
+                id_value_pairs[id] = len(sequence)
+
+    # Split the DataFrame into training (80%) and test+validation (20%)
+    train_df, test_valid_df = train_test_split(df, test_size=0.2, random_state=random_seed)
+
+    # Further split the test+validation into test and validation (50% each of the 20%)
+    test_df, valid_df = train_test_split(test_valid_df, test_size=0.5, random_state=random_seed)
+
+    # Debug file for better readability of the result
+    if debug:
+        debug_instruction = instruction.replace('\n', '\n\n')
+        debug_instruction = "\n".join(textwrap.wrap(debug_instruction, width=line_width))
+        with open(f'{output_dir}/debug.txt', 'w', encoding='utf-8') as debug_file:
+            debug_file.write(f'### INSTRUCTION\n\n{debug_instruction}\n\n{"=" * line_width}\n\n')
+            process_and_write_data(df, debug_file, write_debug_file=True)
+
+    # Write train, test and validation files
+    with open(f'{output_dir}/train.jsonl', 'w', encoding='utf-8') as train_file, \
+            open(f'{output_dir}/test.jsonl', 'w', encoding='utf-8') as test_file, \
+            open(f'{output_dir}/valid.jsonl', 'w', encoding='utf-8') as valid_file:
+
+        # Process and write data to each file
+        process_and_write_data(train_df, train_file)
+        process_and_write_data(test_df, test_file)
+        process_and_write_data(valid_df, valid_file)
+
+    if debug:
+        print("Length of generated sequences:")
+        print(f" - max: {max(id_value_pairs.values())}")
+        print(f" - avg: {sum(id_value_pairs.values()) / len(id_value_pairs)}")
+
+        sorted_pairs = sorted(id_value_pairs.items(), key=lambda x: x[1], reverse=True)
+        highest_10_values = sorted_pairs[:10]
+        print("Longest sequences:")
+        for _id, value in highest_10_values:
+            print(f"{_id}: {value}")
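+
+
+# Minimal usage sketch for this module. The directory path below is only an
+# illustrative assumption; point it at the folder that holds the annotated
+# reference XML files.
+if __name__ == '__main__':
+    refs = load_xmls_from_directory('data/references/xml')
+    print(f"Loaded {len(refs)} annotated sequences")
+    print(f"Columns: {list(refs.columns)}")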
-- 
GitLab