
Training Classifier Starter Guide

This page teaches you how to train high-precision classifiers that judge the context relevance, answer relevance, and answer faithfulness of RAG outputs.


Training Classifier Configuration

The classifier_config dictionary is a configuration object that sets up ARES to train a classifier on a labeled dataset. A training classifier configuration has the following structure:

from ares import ARES

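classifier_config = {
    "classification_dataset": ["<classification_dataset_filepath>"],  # list of labeled TSV paths (or a single path)
    "test_set_selection": "<test_set_selection_filepath>",  # held-out evaluation set
    "label_column": ["<labels>"],  # column(s) containing the target labels
    "model_choice": "microsoft/deberta-v3-large",  # Default model is "microsoft/deberta-v3-large"
    "num_epochs": 10,  # number of passes over the training data
    "patience_value": 3,  # early-stopping patience, in epochs
    "learning_rate": 5e-6  # initial learning rate for the optimizer
}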

ares = ARES(classifier_model=classifier_config)
results = ares.train_classifier()
print(results)
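Before passing a configuration to ARES, it can help to sanity-check it locally. The helper below is a minimal sketch, assuming that classification_dataset and label_column are the keys you always need; check_classifier_config is a hypothetical name and is not part of the ARES API. Only the model_choice default comes from this page.

```python
from typing import Any

# Hypothetical helper for illustration only; it is NOT part of the ARES API.
# Only the model_choice default is documented on this page; the required keys are assumed.
DEFAULT_MODEL = "microsoft/deberta-v3-large"
ASSUMED_REQUIRED_KEYS = ("classification_dataset", "label_column")


def check_classifier_config(config: dict[str, Any]) -> dict[str, Any]:
    """Return a normalized copy of config, raising if it looks malformed."""
    missing = [key for key in ASSUMED_REQUIRED_KEYS if key not in config]
    if missing:
        raise KeyError(f"classifier_config is missing: {', '.join(missing)}")

    checked = dict(config)
    # The dataset may be given as one path or as a list of paths; normalize to a list.
    if isinstance(checked["classification_dataset"], str):
        checked["classification_dataset"] = [checked["classification_dataset"]]
    if not isinstance(checked["label_column"], list):
        raise TypeError("label_column should be a list of column names")

    # Fall back to the documented default model when none is given.
    checked.setdefault("model_choice", DEFAULT_MODEL)

    # Numeric hyperparameters, when present, must be positive.
    for key in ("num_epochs", "patience_value", "learning_rate"):
        if key in checked and not checked[key] > 0:
            raise ValueError(f"{key} must be positive, got {checked[key]!r}")
    return checked
```

For example, check_classifier_config({"classification_dataset": "output/synthetic_queries_1.tsv", "label_column": ["Context_Relevance_Label"]}) returns a copy with the dataset wrapped in a list and model_choice set to "microsoft/deberta-v3-large".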

Classification Dataset

Provide a list of file paths, or a single file path, pointing to the labeled dataset used to train the classifier. This dataset is typically the output of the ARES synthetic generator. It should include the text data and the corresponding labels for supervised learning.

"classification_dataset": ["output/synthetic_queries_1.tsv"],
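To confirm that a training file actually contains the labels you plan to train on, you can read it with Python's standard csv module. This is a minimal sketch, assuming tab-separated files with a header row; load_labeled_rows is a hypothetical helper, and the Query, Document, and Context_Relevance_Label column names in the demo are assumptions rather than a guaranteed ARES schema. Rows with an empty label are skipped, which is a design choice for the sketch.

```python
import csv
import os
import tempfile
from pathlib import Path


def load_labeled_rows(paths, label_columns):
    """Read one or more TSV files and keep only rows where every label is filled in."""
    if isinstance(paths, (str, Path)):
        paths = [paths]  # accept a single path as well as a list of paths
    rows = []
    for path in paths:
        with open(path, newline="", encoding="utf-8") as handle:
            reader = csv.DictReader(handle, delimiter="\t")
            absent = [c for c in label_columns if c not in (reader.fieldnames or [])]
            if absent:
                raise ValueError(f"{path} has no column(s): {absent}")
            for row in reader:
                if all(row[c] not in ("", None) for c in label_columns):
                    rows.append(row)
    return rows


if __name__ == "__main__":
    # Demo on a tiny, made-up TSV file (hypothetical column names).
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "synthetic_queries_demo.tsv")
        with open(path, "w", encoding="utf-8", newline="") as f:
            f.write("Query\tDocument\tContext_Relevance_Label\n")
            f.write("who wrote hamlet\tHamlet is a tragedy by Shakespeare.\t1\n")
            f.write("capital of peru\tThe Nile flows north.\t\n")
        print(len(load_labeled_rows(path, ["Context_Relevance_Label"])))
```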

Test Set Selection

Provide the file path to your test set for evaluating the classifier's performance. This should be separate from the training data to ensure an unbiased assessment.

"test_set_selection": "/data/datasets_v2/nq/nq_ratio_0.6_.tsv"

An example of the test set selection file used here is available in the ARES GitHub repository.
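Because the test set should be separate from the training data, a quick overlap check can catch accidental leakage. The sketch below works on rows represented as dictionaries (for example, rows read from the TSV files); find_overlap is a hypothetical helper, and matching on the Query and Document columns after lowercasing and trimming whitespace is an assumption you may want to adjust to your own schema.

```python
def find_overlap(train_rows, test_rows, key_columns=("Query", "Document")):
    """Return the sorted keys that appear in both the training rows and the test rows."""

    def key(row):
        # Normalize case and surrounding whitespace so trivial variants still match.
        return tuple(row[c].strip().lower() for c in key_columns)

    train_keys = {key(row) for row in train_rows}
    test_keys = {key(row) for row in test_rows}
    return sorted(train_keys & test_keys)


if __name__ == "__main__":
    train = [{"Query": "Who wrote Hamlet?", "Document": "Shakespeare"}]
    test = [
        {"Query": " who wrote hamlet? ", "Document": "shakespeare"},
        {"Query": "Capital of Peru?", "Document": "Lima"},
    ]
    print(find_overlap(train, test))
```

An empty result means no training example reappears in the test set under this matching rule.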

Label Column(s)

List the column name(s) in your dataset that contain the label(s). These are the targets your classifier will predict.

"label_column": ["Context_Relevance_Label"], 
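Before training, it is worth checking how the values in each label column are distributed, since a heavily imbalanced column can make a classifier look accurate while it mostly predicts one class. This is a minimal sketch using collections.Counter; label_distribution is a hypothetical helper, and the string labels "0" and "1" in the example are an assumption about how labels are stored.

```python
from collections import Counter


def label_distribution(rows, label_columns):
    """Count how often each label value occurs in each label column."""
    return {col: Counter(row[col] for row in rows) for col in label_columns}


if __name__ == "__main__":
    rows = [
        {"Context_Relevance_Label": "1"},
        {"Context_Relevance_Label": "0"},
        {"Context_Relevance_Label": "1"},
    ]
    counts = label_distribution(rows, ["Context_Relevance_Label"])
    # Fraction of positive labels in the column.
    col = counts["Context_Relevance_Label"]
    print(col["1"] / sum(col.values()))
```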

Model Choice

Specifies the pre-trained language model to fine-tune for classification. By default, ARES uses "microsoft/deberta-v3-large". You can replace this with any Hugging Face model suitable for your task.

"model_choice": "google/flan-t5-xxl",

Num Epochs

Determines the number of training epochs, which is the number of times the learning algorithm will work through the entire training dataset.

"num_epochs": 10, 
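Since one epoch is one full pass over the training data, the number of weight updates grows linearly with num_epochs. The sketch below makes that arithmetic concrete; count_updates is a hypothetical helper, and batch_size is an assumed parameter that does not appear in the configuration on this page.

```python
import math


def count_updates(num_examples, batch_size, num_epochs):
    """Number of optimizer steps when every epoch visits each example exactly once."""
    steps_per_epoch = math.ceil(num_examples / batch_size)  # last batch may be partial
    return steps_per_epoch * num_epochs


if __name__ == "__main__":
    # 1,000 training rows, a hypothetical batch size of 32, and 10 epochs.
    print(count_updates(1000, 32, 10))
```

With these numbers, each epoch takes 32 steps (31 full batches plus one partial batch), so 10 epochs perform 320 updates. Early stopping, described next, can end training before all epochs run.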

Patience Value

This value controls early stopping, which helps prevent overfitting: training stops once the validation score has failed to improve for this many consecutive epochs.

"patience_value": 3, 
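The early-stopping rule can be sketched as a loop over per-epoch validation scores. This is an illustration of the general technique, not ARES's internal code: early_stopping is a hypothetical helper, it assumes higher scores are better, and the scores in the example are made up.

```python
def early_stopping(val_scores, patience):
    """Return (stop_epoch, best_epoch), both 1-based, for per-epoch validation scores."""
    best_score = float("-inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, score in enumerate(val_scores, start=1):
        if score > best_score:
            # New best result: remember it and reset the patience counter.
            best_score, best_epoch = score, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch  # patience exhausted, stop here
    return len(val_scores), best_epoch  # ran every epoch without stopping early


if __name__ == "__main__":
    scores = [0.70, 0.75, 0.74, 0.76, 0.75, 0.75, 0.74, 0.80, 0.81, 0.82]
    print(early_stopping(scores, patience=3))
```

With patience_value set to 3, training stops after epoch 7 because epochs 5, 6, and 7 did not beat the epoch 4 score, even though later epochs would have improved. A larger patience tolerates longer plateaus at the cost of extra training time.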

Learning Rate

Sets the initial learning rate for the optimizer. This crucial hyperparameter controls how large a step the optimizer takes each time it updates the model weights: too high and training can diverge, too low and training progresses slowly.

"learning_rate": 5e-6
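The effect of the learning rate on step size can be seen on a toy problem. The sketch below runs plain gradient descent on f(w) = (w - 3)^2, whose minimum is at w = 3. It is only an illustration: gradient_descent is a hypothetical helper, and fine-tuning a large pre-trained model typically uses an adaptive optimizer, so a value like 5e-6 is not directly comparable to the toy values here.

```python
def gradient_descent(learning_rate, steps, w0=0.0):
    """Minimize f(w) = (w - 3)**2 with plain gradient descent and return the final w."""
    w = w0
    for _ in range(steps):
        grad = 2 * (w - 3)  # derivative of (w - 3)**2
        w = w - learning_rate * grad  # the learning rate scales every update
    return w


if __name__ == "__main__":
    for lr in (5e-6, 0.1, 1.1):
        print(lr, gradient_descent(lr, steps=50))
```

After 50 steps, a rate of 0.1 lands very close to 3, a rate of 1.1 overshoots further on every step and diverges, and a rate of 5e-6 barely moves from the starting point on this toy problem.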

Training Classifier Configuration: Full Example

from ares import ARES

classifier_config = {
    "classification_dataset": ["output/synthetic_queries_1.tsv"], 
    "validation_set": "./datasets_v2/nq/ratio_0.5_reformatted_full_articles_False_validation_with_negatives.tsv",
    "label_column": ["Context_Relevance_Label"], 
    "model_choice": "microsoft/deberta-v3-large",
    "num_epochs": 10, 
    "patience_value": 3, 
    "learning_rate": 5e-6
}

ares = ARES(classifier_model=classifier_config)
results = ares.train_classifier()
print(results)