{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
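   "source": "# Convert AnyStyle training data to a simple JSONL format\n\nReads every `in/*.xml` file and writes `out/<name>-simple.jsonl`, with one JSON object per `<sequence>` element.",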
   "id": "ae7e001161d678cc"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-12T15:57:59.317356Z",
     "start_time": "2024-07-12T15:57:59.198914Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import xml.etree.ElementTree as ET\n",
    "import json\n",
    "import regex as re\n",
    "import string\n",
    "import glob\n",
    "import os\n",
    "\n",
    "def xml_to_jsonl(input_xml_path, output_jsonl_path, tags):\n",
    "    tree = ET.parse(input_xml_path)\n",
    "    root = tree.getroot()\n",
    "\n",
    "    with open(output_jsonl_path, 'w', encoding='utf-8') as f:\n",
    "        for sequence in root.findall('sequence'):\n",
    "            output = []\n",
    "            for element in sequence:\n",
    "                for tag in tags:\n",
    "                    if type(tag) is tuple:\n",
    "                        tag, fn = tag\n",
    "                    if element.tag == tag:\n",
    "                        value = fn(element.text) if callable(fn) else element.text\n",
    "                        if len(output) == 0 or tag in output[-1]:\n",
    "                            output.append({}) \n",
    "                        output[-1][tag] = value\n",
    "            if len(output) > 0:\n",
    "                instance = {\n",
    "                    \"in\" : \" \".join(element.text.strip() if element.text else '' for element in sequence),\n",
    "                    \"out\" : output\n",
    "                }\n",
    "                f.write(json.dumps(instance) + '\\n')\n",
    "\n",
    "def remove_punctuation(text):\n",
    "    punctuation = set(string.punctuation)\n",
    "    start, end = 0, len(text)\n",
    "    while start < len(text) and text[start] in punctuation:\n",
    "        start += 1\n",
    "    while end > start and text[end - 1] in punctuation:\n",
    "        end -= 1\n",
    "    return text[start:end].strip()\n",
    "\n",
    "def clean_editor(text):\n",
    "    return remove_punctuation(re.sub(r'hrsg\\. v\\.|hg\\. v|hrsg|ed\\.|eds\\.|in:', '', text, flags=re.IGNORECASE))\n",
    "\n",
    "def clean_container(text):\n",
    "    return remove_punctuation(re.sub(r'in:|aus:|from:', '', text, flags=re.IGNORECASE))\n",
    "    \n",
    "def extract_year(text): \n",
    "    m = re.search( r'[12][0-9]{3}', text)\n",
    "    return m.group(0) if m else None\n",
    "\n",
    "for input_file in glob.glob('in/*.xml'):\n",
    "    base_name = os.path.basename(input_file)\n",
    "    output_file = f'out/{os.path.splitext(base_name)[0]}-simple.jsonl'\n",
    "    xml_to_jsonl(input_file, output_file, [\n",
    "        (\"author\", remove_punctuation),\n",
    "        (\"editor\", clean_editor),\n",
    "        (\"authority\", remove_punctuation),\n",
    "        (\"title\", remove_punctuation),\n",
    "        (\"container-title\", clean_container),\n",
    "        (\"journal\", clean_container),\n",
    "        (\"date\", extract_year)\n",
    "    ])\n",
    "\n"
   ],
   "id": "f101a4e2408d6313",
   "outputs": [],
   "execution_count": 30
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "43e2040fed89c0bd"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}