Fine-tuning a BERT model for search applications

How to ensure training and serving encoding compatibility

There are cases where the inputs to your Transformer model are pairs of sentences, but your application requires you to process each sentence of the pair at a different time. Search applications are one example.

The search use case

Search applications involve a large collection of documents that can be pre-processed and stored before a search action is required. A query, on the other hand, triggers a search action, and we can only process it in real-time. The goal of a search app is to return the most relevant documents to the query as quickly as possible. By applying the tokenizer to the documents as soon as we feed them to the application, we only need to tokenize the query when a search action is required, saving precious time.

In addition to applying the tokenizer at different times, you also want control over how the pair of sentences is encoded. For search, you might want a joint input vector of length 128 where the query, which is usually shorter than the document, contributes 32 tokens while the document can take up to 96 tokens.
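
For the 32/96 split used in this post, the joint input looks as sketched below (derived from the encoding code later in this post; [CLS], [SEP], and [PAD] are BERT's special tokens):

[CLS] + up to 30 query tokens + [SEP]    <- 32 positions for the query
      + up to 95 doc tokens   + [SEP]    <- 96 positions for the document
      + [PAD] as needed                  <- fixed total length of 128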

Training and serving compatibility

When training a Transformer model for search, you want to ensure that the training data follows the same pattern used by the search engine serving the final model. I have written a blog post on how to get started with BERT model fine-tuning using the transformers library. This piece adapts that training routine with a custom encoding, based on separate query and document encodings, to reproduce how a Vespa application would serve the model once deployed.

Create independent BERT encodings

The only change required is simple but essential. In my previous post, we discussed the vanilla case where we applied the tokenizer directly to the pairs of queries and documents.

from transformers import BertTokenizerFast

model_name = "google/bert_uncased_L-4_H-512_A-8"
tokenizer = BertTokenizerFast.from_pretrained(model_name)

# encode each query/document pair jointly into one 128-token input
train_encodings = tokenizer(
    train_queries, train_docs, truncation=True,
    padding='max_length', max_length=128
)
val_encodings = tokenizer(
    val_queries, val_docs, truncation=True,
    padding='max_length', max_length=128
)
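
For this model, the tokenizer output contains the three inputs a BERT model expects; our custom function below needs to produce the same three fields. Inspecting the keys should show:

print(train_encodings.keys())
# dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])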

In the search case, we create a create_bert_encodings function that applies the tokenizer separately to the queries and the documents. In addition to allowing different query and document max_length values, we also need to set add_special_tokens=False and skip padding, since the special tokens and padding are added by our custom code when joining the two token sequences.

def create_bert_encodings(
    queries, docs, tokenizer, query_input_size, doc_input_size
):
    # reserve 2 positions on the query side for [CLS] and [SEP]
    queries_encodings = tokenizer(
        queries, truncation=True,
        max_length=query_input_size-2, add_special_tokens=False
    )
    # reserve 1 position on the document side for the trailing [SEP]
    docs_encodings = tokenizer(
        docs, truncation=True,
        max_length=doc_input_size-1, add_special_tokens=False
    )

    TOKEN_NONE = 0   # [PAD]
    TOKEN_CLS = 101  # [CLS]
    TOKEN_SEP = 102  # [SEP]

    max_input_size = query_input_size + doc_input_size

    input_ids = []
    token_type_ids = []
    attention_mask = []
    for query_input_ids, doc_input_ids in zip(
        queries_encodings["input_ids"], docs_encodings["input_ids"]
    ):
        # create input id: [CLS] query [SEP] document [SEP] + padding
        input_id = (
            [TOKEN_CLS] + query_input_ids + [TOKEN_SEP] +
            doc_input_ids + [TOKEN_SEP]
        )
        number_tokens = len(input_id)
        padding_length = max(max_input_size - number_tokens, 0)
        input_id = input_id + [TOKEN_NONE] * padding_length
        input_ids.append(input_id)
        # create token type id: 0 for the query segment, 1 for the document
        token_type_id = (
            [0] * len([TOKEN_CLS] + query_input_ids + [TOKEN_SEP]) +
            [1] * len(doc_input_ids + [TOKEN_SEP]) +
            [TOKEN_NONE] * padding_length
        )
        token_type_ids.append(token_type_id)
        # create attention mask: 1 for real tokens, 0 for padding
        attention_mask.append(
            [1] * number_tokens + [TOKEN_NONE] * padding_length
        )

    encodings = {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask
    }
    return encodings
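
Before plugging the function into the training routine, a quick sanity check on a toy pair helps confirm the layout. The example strings below are made up, and the exact token ids depend on the vocabulary:

toy_encodings = create_bert_encodings(
    queries=["coronavirus origin"],
    docs=["The origin of the SARS-CoV-2 virus is under investigation."],
    tokenizer=tokenizer,
    query_input_size=32,
    doc_input_size=96
)
print(len(toy_encodings["input_ids"][0]))       # 128: the fixed joint length
print(toy_encodings["input_ids"][0][0])         # 101: the [CLS] token
print(sum(toy_encodings["attention_mask"][0]))  # number of non-padding tokens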

We then create the train_encodings and val_encodings required by the training routine. Everything else in the training routine works just the same.

from transformers import BertTokenizerFast

model_name = "google/bert_uncased_L-4_H-512_A-8"
tokenizer = BertTokenizerFast.from_pretrained(model_name)

train_encodings = create_bert_encodings(
    queries=train_queries, 
    docs=train_docs, 
    tokenizer=tokenizer, 
    query_input_size=32, 
    doc_input_size=96
)

val_encodings = create_bert_encodings(
    queries=val_queries, 
    docs=val_docs, 
    tokenizer=tokenizer, 
    query_input_size=32, 
    doc_input_size=96
)

Conclusion and future work

Training a model to deploy in a search application requires us to ensure that the training encodings are compatible with the encodings used at serving time. Document encodings are generated offline, when we feed the documents to the search engine, while query encodings are created at run-time, when the query arrives. It is often relevant to use different maximum lengths for queries and documents, among other possible configurations.

We showed how to customize BERT model encodings to ensure this training and serving compatibility. However, a better approach is to build tools that bridge the gap between training and serving by allowing users to request training data that, by default, respects the encodings used when serving the model. pyvespa will include such integration to make it easier for Vespa users to train BERT models without having to adjust the encoding generation manually as we did above.

How to evaluate Vespa ranking functions from python

Using pyvespa to evaluate cord19 search application ranking functions currently in production.

This is the second in a series of blog posts that will show you how to improve a text search application, from downloading data to fine-tuning BERT models.

The previous post showed how to download and parse TREC-COVID data. This one will focus on evaluating two query models available in the cord19 search application. Those models will serve as baselines for future improvements.

You can also run the steps contained here from Google Colab.

Download processed data

We can start by downloading the data that we have processed before.

import requests, json
from pandas import read_csv

topics = json.loads(requests.get(
  "https://thigm85.github.io/data/cord19/topics.json").text
)
relevance_data = read_csv(
  "https://thigm85.github.io/data/cord19/relevance_data.csv"
)

topics contains data about the 50 topics available, including the query, question, and narrative.

topics["1"]
{'query': 'coronavirus origin',
 'question': 'what is the origin of COVID-19',
 'narrative': "seeking range of information about the SARS-CoV-2 virus's origin, including its evolution, animal source, and first transmission into humans"}

relevance_data contains the relevance judgments for each of the 50 topics.

relevance_data.head(5)

Install pyvespa

We are going to use pyvespa to evaluate ranking functions from python.

!pip install pyvespa

pyvespa provides a python API to Vespa. It allows us to create, modify, deploy, and interact with running Vespa instances. The library’s main goal is to allow for faster prototyping and facilitate Machine Learning experiments for Vespa applications.

Format the labeled data into expected pyvespa format

pyvespa expects labeled data to follow the format illustrated below. It is a list of dicts, where each dict represents a query and contains query_id, query, and a list of relevant_docs. Each relevant document contains a required id key and an optional score key.

labeled_data = [
    {
        'query_id': 1,
        'query': 'coronavirus origin',
        'relevant_docs': [{'id': '005b2j4b', 'score': 2}, {'id': '00fmeepz', 'score': 1}]
    },
    {
        'query_id': 2,
        'query': 'coronavirus response to weather changes',
        'relevant_docs': [{'id': '01goni72', 'score': 2}, {'id': '03h85lvy', 'score': 2}]
    }
]

We can create labeled_data from the topics and relevance_data that we downloaded before. We are only going to include documents with a relevance score > 0 in the final list.

labeled_data = [
    {
        "query_id": int(topic_id),
        "query": topics[topic_id]["query"],
        "relevant_docs": [
            {
                "id": row["cord_uid"],
                "score": row["relevancy"]
            }
            for idx, row in relevance_data[
                relevance_data.topic_id == int(topic_id)
            ].iterrows()
            if row["relevancy"] > 0
        ]
    }
    for topic_id in topics.keys()
]
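
A quick inspection of the result (the exact counts depend on the data):

print(len(labeled_data))                      # 50, one entry per topic
print(labeled_data[0]["query"])               # 'coronavirus origin' for topic 1
print(len(labeled_data[0]["relevant_docs"]))  # number of relevant docs for topic 1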

Define query models to be evaluated

We are going to define two query models to be evaluated here. Both will match all the documents that share at least one term with the query. This is defined by setting match_phase = OR().

The difference between the query models happens in the ranking phase. The or_default model will rank documents based on nativeRank, while the or_bm25 model will rank documents based on BM25. Discussion about those two types of ranking is out of the scope of this tutorial. It is enough to know that they rank documents according to two different formulas.

Those ranking profiles were defined by the team behind the cord19 app and can be found here.

from vespa.query import Query, RankProfile, OR

query_models = {
    "or_default": Query(
        match_phase = OR(),
        rank_profile = RankProfile(name="default")
    ),
    "or_bm25": Query(
        match_phase = OR(),
        rank_profile = RankProfile(name="bm25t5")
    )
}

Define metrics to be used in the evaluation

We want to compute the following metrics (a code sketch of two of them follows the list):

  • The percentage of documents matched by the query
  • Recall @ 10
  • Reciprocal rank @ 10
  • NDCG @ 10
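
For reference, here is a plain-Python sketch of one common formulation of reciprocal rank and NDCG. pyvespa's exact implementations may differ in detail; for instance, some NDCG variants use 2^rel - 1 instead of the raw relevance score in the numerator:

import math

def reciprocal_rank(ranked_ids, relevant_ids, at=10):
    # 1 / position of the first relevant document; 0 if none in the top `at`
    for position, doc_id in enumerate(ranked_ids[:at], start=1):
        if doc_id in relevant_ids:
            return 1.0 / position
    return 0.0

def ndcg(ranked_scores, ideal_scores, at=10):
    # discounted cumulative gain of the ranking divided by the ideal DCG
    def dcg(scores):
        return sum(s / math.log2(i + 2) for i, s in enumerate(scores[:at]))
    ideal = dcg(sorted(ideal_scores, reverse=True))
    return dcg(ranked_scores) / ideal if ideal > 0 else 0.0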

from vespa.evaluation import (
    MatchRatio,
    Recall,
    ReciprocalRank,
    NormalizedDiscountedCumulativeGain
)

eval_metrics = [
    MatchRatio(), 
    Recall(at=10), 
    ReciprocalRank(at=10), 
    NormalizedDiscountedCumulativeGain(at=10)
]

Evaluate

Connect to a running Vespa instance:

from vespa.application import Vespa

app = Vespa(url = "https://api.cord19.vespa.ai")

Compute the metrics defined above for each query model and store the results in a dictionary.

evaluations = {}
for query_model in query_models:
    evaluations[query_model] = app.evaluate(
        labeled_data = labeled_data,
        eval_metrics = eval_metrics,
        query_model = query_models[query_model],
        id_field = "cord_uid",
        hits = 10
    )

Analyze results

Let’s first combine the data into one DataFrame in a format to facilitate a comparison between query models.

import pandas as pd

metric_values = []
for query_model in query_models:
    for metric in eval_metrics:
        metric_values.append(
            pd.DataFrame(
                data={
                    "query_model": query_model, 
                    "metric": metric.name, 
                    "value": evaluations[query_model][metric.name + "_value"].to_list()
                }
            )
        )
metric_values = pd.concat(metric_values, ignore_index=True)
metric_values.head()

We can see below that the query model based on BM25 is superior across all metrics considered here.

metric_values.groupby(['query_model', 'metric']).mean()

We can also visualize the metrics’ distribution across the queries to get a better picture of the results.

import plotly.express as px

fig = px.box(
    metric_values[metric_values.metric == "ndcg_10"], 
    x="query_model", 
    y="value", 
    title="Ndgc @ 10",
    points="all"
)
fig.show()

Download and parse TREC-COVID data

Your first step to improve the cord19 search application.

This is the first in a series of blog posts that will show you how to improve a text search application, from downloading data to fine-tuning BERT models.

You can also run the steps contained here from Google Colab.

The team behind vespa.ai has built and open-sourced a CORD-19 search engine. Thanks to advanced Vespa features such as Approximate Nearest Neighbor Search and Transformers support via ONNX, it comes with the most advanced NLP methodology applied to search that is currently available.

Our first step is to download relevance judgments to be able to evaluate current query models deployed in the application and to train better ones to replace those already there.

Download the data

The files used in this section can be found at https://ir.nist.gov/covidSubmit/data.html. We will download both the topics and the relevance judgments data. Do not worry about what they are just yet; we will explore them soon.

!wget https://ir.nist.gov/covidSubmit/data/topics-rnd5.xml
!wget https://ir.nist.gov/covidSubmit/data/qrels-covid_d5_j0.5-5.txt
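
If you are not running this from a notebook, the same files can be downloaded with plain Python instead of wget:

import requests

for url in [
    "https://ir.nist.gov/covidSubmit/data/topics-rnd5.xml",
    "https://ir.nist.gov/covidSubmit/data/qrels-covid_d5_j0.5-5.txt",
]:
    response = requests.get(url)
    response.raise_for_status()  # stop early if a download fails
    with open(url.split("/")[-1], "wb") as file:
        file.write(response.content)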

Parse the data

Topics

The topics file is in XML format. We can parse it and store it in a dictionary called topics. We want to extract the query, the question, and the narrative from each topic.

import xml.etree.ElementTree as ET

topics = {}
root = ET.parse("topics-rnd5.xml").getroot()
for topic in root.findall("topic"):
    topic_number = topic.attrib["number"]
    topics[topic_number] = {}
    for query in topic.findall("query"):
        topics[topic_number]["query"] = query.text
    for question in topic.findall("question"):
        topics[topic_number]["question"] = question.text        
    for narrative in topic.findall("narrative"):
        topics[topic_number]["narrative"] = narrative.text

There are a total of 50 topics. For example, we can see the first topic below:

topics["1"]
{'query': 'coronavirus origin',
 'question': 'what is the origin of COVID-19',
 'narrative': "seeking range of information about the SARS-CoV-2 virus's origin, including its evolution, animal source, and first transmission into humans"}

Each topic has many relevance judgments associated with it.

Relevance judgments

We can load the relevance judgment data directly into a pandas DataFrame.

import pandas as pd

relevance_data = pd.read_csv("qrels-covid_d5_j0.5-5.txt", sep=" ", header=None)
relevance_data.columns = ["topic_id", "round_id", "cord_uid", "relevancy"]

The relevance data contains all the relevance judgments made throughout the 5 rounds of the competition. A relevancy of 0 means irrelevant, 1 means relevant, and 2 means highly relevant.

relevance_data.head()

We are going to remove two rows that have relevancy equal to -1, which I am assuming is an error.

relevance_data[relevance_data.relevancy == -1]
relevance_data = relevance_data[relevance_data.relevancy >= 0]
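
To see how the remaining judgments are distributed over the three grades, we can count them (the exact numbers depend on the data snapshot):

relevance_data.relevancy.value_counts()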

The plot below shows that there are quite a few relevance judgments for each topic and the number of relevant documents varies quite a lot across topics.

import plotly.express as px

fig = px.histogram(relevance_data, x="topic_id", color = "relevancy")
fig.show()

Next, we will discuss how we can use this data to evaluate and improve the cord19 search app.