from datetime import timedelta
import nltk
import pytest
from util.lin_reg_plot_helper import LinRegPredictor, plot_regression
from util.text_similarity.lsh_min_hash.lsh_min_hash import LSHMinHash
from util.text_similarity.lsh_min_hash.shingles_generator import (
MultipleShinglesGenerator,
)
from util.text_similarity.lsh_min_hash.time_analysis import (
MeasureLSHTimeComplexity,
MeasurementParams,
RandomTextsGenerator,
)
# Ensure the NLTK "words" corpus is available locally; RandomTextsGenerator
# draws its random vocabulary from it. Downloads once, then is a cached no-op.
nltk.download("words")
@pytest.mark.parametrize(
    "texts,expected_pairs",
    [
        # Completely identical texts
        (["abc abc abc", "abc abc abc"], [(0, 1)]),
        # Very similar texts with minor differences
        (["The quick brown fox jumps", "The quick brown fox leaps"], [(0, 1)]),
        # Very similar texts with minor differences
        (
            [
                "I like to learn for my university courses",
                "My university is big. I like to learn for my university courses",
            ],
            [(0, 1)],
        ),
        # Shared vocabulary but large length difference -> below threshold
        (
            [
                "I like to learn for my university courses",
                "This text is much longer, but has many of the same words! My university is big. I like to learn for my university courses",
            ],
            [],
        ),
        # Completely different texts
        (["apple pie", "orange juice", "banana split"], []),
        # Multiple near duplicates
        (
            [
                "hello world",
                "hello world!",
                "hello world.",
                "this is different!",
            ],
            [(0, 1), (0, 2), (1, 2)],
        ),
        # Edge case - single text
        (["lorem ipsum dolor sit amet"], []),
        # Empty list
        ([], []),
    ],
)
def test_get_similar_pairs(texts, expected_pairs):
    """LSHMinHash.get_similar_pairs returns exactly the expected index pairs.

    Covers identical texts, near-duplicates, disjoint texts, multi-duplicate
    clusters, a single text, and an empty input. Pair order is not part of
    the contract, so both sides are sorted before comparison.
    """
    lsh = LSHMinHash(
        threshold=0.6, shingles_generator=MultipleShinglesGenerator(ngram_sizes=[3, 5])
    )
    result = lsh.get_similar_pairs(texts)
    # Compare as sorted lists: the API makes no ordering guarantee.
    assert sorted(result) == sorted(expected_pairs)
@pytest.mark.parametrize(
    "threshold,texts,expected_pairs",
    [
        # Same text pair: a strict threshold rejects it, a lax one accepts it.
        (0.9, ["The cat in the hat", "The cat in the bag"], []),
        (0.5, ["The cat in the hat", "The cat in the bag"], [(0, 1)]),
    ],
)
def test_threshold_variation(threshold, texts, expected_pairs):
    """The similarity threshold controls whether a borderline pair matches.

    The same near-duplicate pair must be excluded at threshold 0.9 and
    included at threshold 0.5.
    """
    lsh = LSHMinHash(threshold=threshold)
    result = lsh.get_similar_pairs(texts)
    # Sorted comparison: pair ordering is not part of the contract.
    assert sorted(result) == sorted(expected_pairs)
def test_large_text_set():
    """A duplicated text planted in a larger random corpus is detected.

    Generates 30 random texts, copies index 15 over index 10, and asserts
    that the high-threshold (0.9) LSH reports exactly that identical pair.
    Note: the generated corpus is random, but identical texts always have
    Jaccard similarity 1.0, so the planted pair must be found.
    """
    num_texts = 30
    same_index_1 = 10
    same_index_2 = 15
    random_texts = RandomTextsGenerator(average_words=10)
    texts = random_texts.generate_random_texts(num_texts)
    # Plant an exact duplicate so one pair is guaranteed above any threshold.
    texts[same_index_1] = texts[same_index_2]
    lsh = LSHMinHash(threshold=0.9)
    result = lsh.get_similar_pairs(texts)
    assert (same_index_1, same_index_2) in result
@pytest.mark.expensive
def test_absolute_time_should_be_small():
    """Extrapolated runtime for 1e5 texts stays under a 6-hour budget.

    Measures LSH runtime at a few small corpus sizes, fits a regression
    (MeasureLSHTimeComplexity.get_poly_coeffs), then extrapolates to 1e5
    texts with LinRegPredictor and asserts the prediction is below 6 hours.
    Marked `expensive` because it actually times LSH runs and renders a plot.
    """
    time_calc = MeasureLSHTimeComplexity(
        lsh=LSHMinHash(
            threshold=0.6,
            shingles_generator=MultipleShinglesGenerator(ngram_sizes=[3, 5]),
            num_perm=128,
        ),
        random_text_generator=RandomTextsGenerator(
            average_words=25, number_total_available_words=4
        ),
    )
    # Two measurement points (10 and 50 texts) keep the expensive run short.
    # noinspection PyArgumentEqualDefault
    coeffs, x, y = time_calc.get_poly_coeffs(
        MeasurementParams(start=10, factor=5, num_points=2, iterations=1)
    )
    print(f"coeffs: {[f'{coeff:.4f}' for coeff in coeffs]}")
    plot_regression(x, y, coeffs)
    prediction_x = 1e5
    prediction_y = LinRegPredictor(coeffs).predict(prediction_x)
    amount_time = timedelta(seconds=prediction_y)
    print(f"Prediction for {prediction_x} texts: {amount_time}")
    assert amount_time < timedelta(hours=6)