Batched BM25 search in PySpark


I have a large dataset of documents (average length of 35 words). I want to find the top k nearest neighbors of each of these documents using BM25. Every document needs to be compared with every other document in the dataset, and the top k neighbors have to be populated. I need this to be parallelised because the dataset contains about 40M documents.

I tried using multiprocessing, but the compute of a single server is the bottleneck since it cannot scale out to multiple nodes the way PySpark can.

I'm using this code, but it gets stuck indefinitely when I try to show the resulting DataFrame:

import math
from collections import Counter

from pyspark.sql import Row
from pyspark.sql import functions as F


def calculate_bm(row, documents):
    row_dict = row.asDict()
    query = row_dict["sentence"]
    document_count = len(documents)
    avg_document_length = sum(len(doc) for doc in documents) / document_count

    # document frequency: in how many documents each term appears (needed for IDF)
    term_counts = Counter()
    for document in documents:
        term_counts.update(set(document))

    k1 = 1.2
    b = 0.75

    query_terms = Counter(query)

    document_scores = []
    for document in documents:
        score = 0.0
        document_length = len(document)

        for term in query_terms:
            if term not in term_counts:
                continue

            document_with_term_count = term_counts[term]
            idf = math.log((document_count - document_with_term_count + 0.5) / (document_with_term_count + 0.5))

            term_frequency = document.count(term)
            numerator = term_frequency * (k1 + 1)
            denominator = term_frequency + k1 * (1 - b + b * (document_length / avg_document_length))
            score += idf * (numerator / denominator)

        document_scores.append((document, score))

    ranked_documents = sorted(document_scores, key=lambda x: x[1], reverse=True)

    for idx, entry in enumerate(ranked_documents):
        row_dict[f"NN_{idx}"] = entry[0]
        row_dict[f"D_{idx}"] = entry[1]

    return Row(**row_dict)


all_documents = temp_df.select(F.collect_list('sentence')).first()[0]

dist = df.rdd.map(lambda row: calculate_bm(row, all_documents))
dist_df = sqlContext.createDataFrame(dist)

My df looks like this (40M rows, ~60 columns)

+-----+---------------------------------------------------+
| idx |                     sentence                      |
+-----+---------------------------------------------------+
|  1  | [column 1 name, value, column 2 name, value, ...] |
|  2  | [column 1 name, value, column 2 name, value, ...] |
|  3  | [column 1 name, value, column 2 name, value, ...] |
|  4  | [column 1 name, value, column 2 name, value, ...] |
+-----+---------------------------------------------------+

The output I wish for is this:

+-----+--------------+-----------------+
| idx |  neighbours  |    distances    |
+-----+--------------+-----------------+
|  1  | [2,4,61,...] | [0.01,0.02,...] |
+-----+--------------+-----------------+

I was able to find this implementation, but it's in Scala and sadly I'm not able to translate it to PySpark.


Answer from darked89:

This is not a full answer, just a few suggestions:

  1. Split your calculate_bm function

Things like the number of tokens in a sentence, or token frequency counts over the whole dataset, should in my opinion be computed up front and not bundled into the most problematic, CPU-intensive part: the BM25 scoring itself.
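
Something like this (untested; it assumes your DataFrame is called df, has an idx column and a sentence column holding an array of string tokens, and the variable names are just for illustration):

from pyspark.sql import functions as F

# Per-document length, computed once up front
doc_stats = df.select("idx", F.size("sentence").alias("doc_len"))
avg_doc_len = doc_stats.agg(F.avg("doc_len")).first()[0]
n_docs = df.count()

# Document frequency per token: in how many documents it appears.
# F.array_distinct needs Spark 2.4+.
doc_freq = (
    df.select(F.explode(F.array_distinct("sentence")).alias("token"))
      .groupBy("token")
      .count()
      .withColumnRenamed("count", "df")
)

# The BM25 IDF can then be precomputed as a column instead of being
# recomputed inside the scoring loop for every row
idf = doc_freq.withColumn(
    "idf",
    F.log((F.lit(n_docs) - F.col("df") + 0.5) / (F.col("df") + 0.5))
)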

  2. Tokens as integers

No idea if you have done this already, but your df:

|  1  | [column 1 name, value, column 2 name, value, ...] |

is a bit cryptic to me. Since it is not [token_a, token_b, ...], it seems to mix different kinds of values.

Rather than just giving tokens sequential numbers, one can collect all the tokens from the set of sentences (the documents in your code?), compute their frequencies over the whole dataset, and sort in descending order, getting an overall token-frequency data frame:

┌──────────┬────────┬───────┐
│ token_id ┆ tokens ┆ count │
│ ---      ┆ ---    ┆ ---   │
│ u32      ┆ str    ┆ u32   │
╞══════════╪════════╪═══════╡
│ 0        ┆ the    ┆ 3193  │
│ 1        ┆ of     ┆ 1850  │
│ 2        ┆ and    ┆ 1317  │
│ 3        ┆ in     ┆ 1304  │
│ …        ┆ …      ┆ …     │
│ 6        ┆ et     ┆ 581   │

Then the original string tokens can be translated to token_id values and sorted, producing:

┌─────────────┬──────────────────┐
│ sentence_nr ┆ tokens           │
│ ---         ┆ ---              │
│ u32         ┆ list[u32]        │
╞═════════════╪══════════════════╡
│ 0           ┆ [0, 1, … 524]    │
│ 1           ┆ [0, 0, … 3132]   │
│ 2           ┆ [0, 1, … 4701]   │
│ 3           ┆ [0, 9, … 5150]   │

Less RAM used, faster look-ups compared to strings.
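
Something along these lines (again untested, same assumed df schema as above, names are just placeholders) could build the mapping and encode the sentences in PySpark:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Token frequencies over the whole dataset
token_freq = (
    df.select(F.explode("sentence").alias("token"))
      .groupBy("token")
      .count()
)

# Assign integer ids in descending frequency order.
# Note: a window without partitionBy pulls everything into a single
# partition, which should still be fine for a vocabulary-sized table.
w = Window.orderBy(F.desc("count"))
token_ids = token_freq.withColumn("token_id", F.row_number().over(w) - 1)

# Translate each sentence from string tokens to integer token_ids,
# using posexplode + sort_array(struct(pos, ...)) to keep token order
exploded = df.select("idx", F.posexplode("sentence").alias("pos", "token"))
encoded = (
    exploded.join(token_ids.select("token", "token_id"), on="token", how="left")
            .groupBy("idx")
            .agg(F.sort_array(F.collect_list(F.struct("pos", "token_id"))).alias("tmp"))
            .select("idx", F.col("tmp.token_id").alias("tokens"))
)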

  3. Save the intermediate data frames

If possible, I would use the Parquet format, which is readable by Spark and many other tools.

That way you can skip steps that are already done and verify that they make sense.
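
For example (reusing the names from the sketch above; the paths are placeholders and spark is assumed to be your SparkSession):

# Write intermediate results once ...
token_ids.write.mode("overwrite").parquet("/data/bm25/token_ids")
encoded.write.mode("overwrite").parquet("/data/bm25/encoded_sentences")

# ... and in a later step or session read them back instead of recomputing
encoded = spark.read.parquet("/data/bm25/encoded_sentences")
encoded.show(5, truncate=False)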

Hope it helps a bit.