Commit e41b3f96 authored by cboulanger

Fix training data generation. Refactoring

parent c43dfa9a
 *
 !example
-!editors
\ No newline at end of file
+!editors*
\ No newline at end of file
[Diffs for five more files are not shown in this view: four are collapsed, one source diff is too large to display.]
@@ -11,6 +11,7 @@ import json
 from sklearn.model_selection import train_test_split
 import numpy as np
 import textwrap
+import sys
 def download_website_content(url):
@@ -130,6 +131,8 @@ def create_training_file(instruction: str,
                          max_chars=2048 * 4, max_gt_items=10,
                          line_width=120, lines_before=1, lines_after=1, random_seed=None,
                          debug=True):
     # Load the input
     df = pd.read_csv(input_file)
@@ -160,16 +163,16 @@ def create_training_file(instruction: str,
         # Clean up data
         answer_df = group.drop(cols_to_remove, axis=1, errors='ignore').dropna(axis=1, how='all')
-        # Convert DataFrame rows to dictionaries, cleaning NaN values
+        # Convert DataFrame rows to YAML dictionaries, cleaning NaN values
         cleaned_rows = answer_df.apply(lambda row: clean_values(row.to_dict()), axis=1)
-        answer_yaml = yaml.dump(list(cleaned_rows), allow_unicode=True, sort_keys=False)
+        answer = yaml.dump(list(cleaned_rows), allow_unicode=True, sort_keys=False)
         # wrap content to decrease token window
         lines = [line for line in wrap_content_generator(content, width=line_width)]
         wrapped_content = '\n'.join(lines)
         # Determine how many characters are available for the content
-        max_chars_for_content = max_chars - len(instruction) - len(answer_yaml)
+        max_chars_for_content = max_chars - len(instruction) - len(answer)
         # Determine which keywords to filter by
         keyword_list = [x for x in group[column_to_filter_by].tolist() if pd.notnull(x)]
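
For orientation, here is a minimal, self-contained sketch of the budgeting step in this hunk. The real clean_values and the grouped CSV rows live elsewhere in the file, so the stand-ins below are assumptions:

import math
import yaml
import pandas as pd

def clean_values(d):
    # Toy stand-in for the repository's clean_values(): drop NaN entries
    # so they do not show up as 'null' values in the YAML answer.
    return {k: v for k, v in d.items() if not (isinstance(v, float) and math.isnan(v))}

# Hypothetical grouped rows for one journal (the real ones come from the input CSV).
group = pd.DataFrame({"volume": [1, 2], "year": [2001.0, float("nan")]})

cleaned_rows = group.apply(lambda row: clean_values(row.to_dict()), axis=1)
answer = yaml.dump(list(cleaned_rows), allow_unicode=True, sort_keys=False)

instruction = "Extract volume and year for each issue listed on the page."
max_chars = 2048 * 4
# Whatever the instruction and the expected answer consume is no longer
# available for the (wrapped) page content.
max_chars_for_content = max_chars - len(instruction) - len(answer)
print(max_chars_for_content)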
@@ -185,17 +188,17 @@ def create_training_file(instruction: str,
             continue
         if write_debug_file:
-            prompt = [
+            sequence = [
                 '### JOURNAL', id,
                 '### URL', group['website'].tolist()[0],
                 '### CONTENT', filtered_content,
-                '### ANSWER', answer_yaml,
+                '### ANSWER', answer,
                 '-' * line_width,
                 ''
             ]
-            file.write('\n\n'.join(prompt))
+            file.write('\n\n'.join(sequence))
         else:
-            sequence = template_func(f'{instruction}\n\n{filtered_content}', answer_yaml)
+            sequence = template_func(instruction, filtered_content, answer)
             train_json = {
                 "id": id,
                 "text": sequence
@@ -211,10 +214,10 @@ def create_training_file(instruction: str,
     # Debug file for better readability of the result
     if debug:
-        instruction = instruction.replace('\n', '\n\n')
-        instruction = "\n".join(textwrap.wrap(instruction, width=line_width))
+        debug_instruction = instruction.replace('\n', '\n\n')
+        debug_instruction = "\n".join(textwrap.wrap(debug_instruction, width=line_width))
         with open(f'{output_dir}/debug.txt', 'w', encoding='utf-8') as debug_file:
-            debug_file.write(f'### INSTRUCTION\n\n{instruction}\n\n{"=" * line_width}\n\n')
+            debug_file.write(f'### INSTRUCTION\n\n{debug_instruction}\n\n{"=" * line_width}\n\n')
         process_and_write_data(df, debug_file, write_debug_file=True)
     # Write train, test and validation files
...
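
The last hunk fixes a shadowing bug: the debug-only rewrapping used to overwrite instruction itself, so any later reader of that variable (for example the code that subtracts len(instruction) from the character budget) saw the reformatted text. A small sketch of the fixed behavior, with a made-up instruction string:

import textwrap

instruction = "Extract all volumes.\nReturn the result as YAML."
line_width = 120

# The debug copy now lives in its own variable; before this commit the
# wrapped text was assigned back to `instruction`, so later uses of the
# variable saw the reformatted string instead of the original.
debug_instruction = instruction.replace('\n', '\n\n')
debug_instruction = "\n".join(textwrap.wrap(debug_instruction, width=line_width))

assert instruction == "Extract all volumes.\nReturn the result as YAML."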