Source code for util.text_similarity.max_independent_set_calc_test

import dataclasses
import datetime
import math
import random
import time

import numpy as np
import pytest
from matplotlib import pyplot as plt

from util.lin_reg_plot_helper import LinRegPredictor, plot_regression
from util.text_similarity.max_independent_set_calc import (
    ApproximateIndependentSetCalc,
    GreedyIndependentSetCalc,
    MaxIndependentSetCalc,
    OptimalIndependentSetCalc,
)

def PAIRS_LIN_LOG_FUNC(x):  # noqa: N802 - public name kept for existing callers
    """Return ``x * log2(x)`` computed as ``x * math.log(x, 2)`` (an n·log n curve)."""
    log_base_two = math.log(x, 2)
    return x * log_base_two


class CalcNumPairs:
    """Strategy interface deciding how many similar pairs (edges) a random graph gets.

    Concrete strategies override :meth:`get_num_pairs`.
    """

    def get_num_pairs(self, num_texts) -> int:
        """Return the number of pairs to generate for a graph of ``num_texts`` nodes.

        Raises:
            NotImplementedError: always; the base class has no strategy of its own
                (previously it silently returned ``None`` despite the ``int`` annotation).
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_num_pairs()"
        )
class StaticCalcNumPairs(CalcNumPairs):
    """Pair-count strategy that returns the same fixed count for every graph size."""

    def __init__(self, num_pairs):
        # The constant pair count handed back by get_num_pairs().
        self.num_pairs = num_pairs

    def get_num_pairs(self, num_texts):
        """Return the count given at construction; ``num_texts`` is ignored."""
        fixed_count = self.num_pairs
        return fixed_count
class LogCalcNumPairs(CalcNumPairs):
    """Pair-count strategy growing as ``n * log2(n)``, giving a sparse graph."""

    def get_num_pairs(self, num_texts):
        """Return ``ceil(num_texts * log2(num_texts))``, or 0 for fewer than two texts.

        The raw product is a float. Rounding it up yields an ``int``, as the
        :class:`CalcNumPairs` contract declares. It also gives the same stopping
        point as a ``len(pairs) < target`` loop compared against the float.
        """
        if num_texts < 2:
            # log2(0) is a math domain error, and 1 * log2(1) == 0 anyway.
            return 0
        # math.log2 is exact for powers of two, unlike math.log(x, 2).
        return math.ceil(num_texts * math.log2(num_texts))
class LinSquareRootCalcNumPairs(CalcNumPairs):
    """Pair-count strategy growing as ``n * sqrt(n)``, denser than n·log n."""

    def get_num_pairs(self, num_texts):
        """Return ``num_texts * sqrt(num_texts)`` truncated to an ``int``."""
        root = math.sqrt(num_texts)
        return int(root * num_texts)
@dataclasses.dataclass
class RandomGraphGenerator:
    """Generate random graphs as lists of distinct, sorted ``(u, v)`` node pairs.

    The number of pairs per graph is delegated to a :class:`CalcNumPairs` strategy.
    """

    # Pair-count strategy. A default_factory gives each generator its own
    # instance instead of one default object shared by every generator.
    _calc_num_pairs: CalcNumPairs = dataclasses.field(default_factory=LogCalcNumPairs)

    def num_pairs(self, num_texts) -> int:
        """Return the pair count the strategy requests for ``num_texts`` nodes."""
        return self._calc_num_pairs.get_num_pairs(num_texts)

    def _random_node(self, num_texts):
        """Return a uniformly random node id in ``[0, num_texts - 1]``."""
        return random.randint(0, num_texts - 1)

    # noinspection PyTypeChecker
    def _create_random_pair(self, num_texts) -> tuple[int, int]:
        """Return a random pair ``(u, v)`` with ``u <= v``.

        Both endpoints are drawn independently, so ``u == v`` is possible.
        """
        return tuple(
            sorted([self._random_node(num_texts), self._random_node(num_texts)])
        )

    def create_random_pairs(self, num_texts: int):
        """Return a list of distinct random pairs over ``num_texts`` nodes.

        The requested count is capped at the number of distinct pairs that can
        exist, so an oversized request cannot loop forever.
        """
        # Sorted pairs that may have u == v: n * (n + 1) / 2 distinct values.
        max_pairs = max(num_texts, 0) * (max(num_texts, 0) + 1) // 2
        # Ask the strategy once rather than on every iteration. ceil() keeps the
        # same stopping point as comparing len() against a float count.
        target = min(math.ceil(self.num_pairs(num_texts)), max_pairs)
        random_pairs = set()
        while len(random_pairs) < target:
            random_pairs.add(self._create_random_pair(num_texts))
        return list(random_pairs)
@dataclasses.dataclass
class MeasurementResult:
    """One (possibly averaged) measurement of an independent-set calculation.

    Supports field-wise ``+`` with another result, field-wise ``/`` by a number,
    and the builtin :func:`sum`, so that several runs can be averaged.
    After division the fields hold averages, which may be floats.
    """

    num_texts: int  # number of nodes in the measured graph
    time: float  # elapsed seconds spent in the calculation
    result_size: int  # size of the independent set that was found

    @staticmethod
    def zero():
        """Return the additive identity ``MeasurementResult(0, 0, 0)``."""
        return MeasurementResult(0, 0, 0)

    def __add__(self, other):
        if not isinstance(other, MeasurementResult):
            return NotImplemented
        return MeasurementResult(
            self.num_texts + other.num_texts,
            self.time + other.time,
            self.result_size + other.result_size,
        )

    def __radd__(self, other):
        # The builtin sum() starts from the int 0. Treat it as the identity so
        # that sum(results) works without an explicit start value. Previously
        # 0 was forwarded to __add__ and failed with an AttributeError.
        if isinstance(other, (int, float)) and other == 0:
            return dataclasses.replace(self)
        return self.__add__(other)

    def __truediv__(self, other):
        return MeasurementResult(
            self.num_texts / other, self.time / other, self.result_size / other
        )
@dataclasses.dataclass
class MeasureIndependentSetCalc:
    """Benchmark a :class:`MaxIndependentSetCalc` on randomly generated graphs."""

    calc: MaxIndependentSetCalc  # algorithm under test
    graph_generator: RandomGraphGenerator  # source of random similar-pair graphs

    def analyze_calc_run(self, num_texts) -> MeasurementResult:
        """Time one ``find_max_set`` call on a fresh random graph of ``num_texts`` nodes.

        The graph is generated before the clock starts, so only the algorithm
        itself is timed.
        """
        pairs = self.graph_generator.create_random_pairs(num_texts)
        # perf_counter() is monotonic and high-resolution. time.time() can jump
        # with system clock changes and is coarse on some platforms.
        start = time.perf_counter()
        set_of_nodes = self.calc.find_max_set(num_texts, pairs)
        elapsed = time.perf_counter() - start
        return MeasurementResult(num_texts, elapsed, len(set_of_nodes))

    def analyze_calc_runs(self, num_texts, iterations) -> MeasurementResult:
        """Return the field-wise average of ``iterations`` independent runs.

        ``iterations`` must be positive; 0 raises ZeroDivisionError.
        """
        total = sum(
            (self.analyze_calc_run(num_texts) for _ in range(iterations)),
            MeasurementResult.zero(),
        )
        return total / iterations

    def generate_measurements(self, list_num_texts, iterations):
        """Return one averaged :class:`MeasurementResult` per entry of ``list_num_texts``."""
        return [
            self.analyze_calc_runs(num_texts, iterations)
            for num_texts in list_num_texts
        ]
@pytest.fixture
def greedy_calc():
    """Provide a freshly constructed GreedyIndependentSetCalc."""
    calc = GreedyIndependentSetCalc()
    return calc
@pytest.fixture
def approx_calc():
    """Provide a freshly constructed ApproximateIndependentSetCalc."""
    calc = ApproximateIndependentSetCalc()
    return calc
@pytest.fixture
def optimal_calc():
    """Provide a freshly constructed OptimalIndependentSetCalc."""
    calc = OptimalIndependentSetCalc()
    return calc
@pytest.fixture
def random_graph_gen():
    """Provide a RandomGraphGenerator using its default pair-count strategy."""
    generator = RandomGraphGenerator()
    return generator
@pytest.fixture
def dense_graph_gen():
    """Provide a RandomGraphGenerator that draws ``int(n * sqrt(n))`` pairs per graph."""
    strategy = LinSquareRootCalcNumPairs()
    return RandomGraphGenerator(_calc_num_pairs=strategy)
@pytest.mark.parametrize(
    "num_texts, similar_pairs, acceptable_solutions",
    [
        # Case 1: No edges (all nodes are independent)
        (5, [], [{0, 1, 2, 3, 4}]),
        # Case 2: Fully connected graph (exactly one node fits; any one is acceptable)
        (4, [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)], [{0}, {1}, {2}, {3}]),
        # Case 3: Single vertex
        (1, [], [{0}]),
        # Case 4: Pair of vertices with one edge
        (2, [(0, 1)], [{0}, {1}]),
        # Case 5: Star graph centred on node 0 (all leaves are independent)
        (5, [(0, 1), (0, 2), (0, 3), (0, 4)], [{1, 2, 3, 4}]),
    ],
    ids=["no-edges", "complete-graph", "single-vertex", "single-edge", "star"],
)
def test_find_max_set(num_texts, similar_pairs, acceptable_solutions):
    """The greedy calculator returns one of the known maximum independent sets."""
    calc = GreedyIndependentSetCalc()
    result = calc.find_max_set(num_texts, similar_pairs)
    # The result must EQUAL one of the acceptable solutions (not merely be a
    # subset of one), so a smaller-than-maximum set fails this check.
    assert result in acceptable_solutions, (
        f"{sorted(result)} is not one of {acceptable_solutions}"
    )
def _format_duration(delta):
    """Return ``delta`` rounded to whole seconds, e.g. ``"0:04:12"`` or ``"1 day, 2:03:04"``.

    Replaces ``str(delta)[:8]``, which cut durations of a day or more down to
    ``"1 day, 2"`` and left a dangling ``"."`` on sub-second values.
    """
    return str(datetime.timedelta(seconds=round(delta.total_seconds())))


def test_time_complexity(greedy_calc, dense_graph_gen):
    """Extrapolate greedy run time to 10k/100k texts with a quadratic fit and bound it."""
    mis_calc = MeasureIndependentSetCalc(greedy_calc, dense_graph_gen)
    measurements = mis_calc.generate_measurements([10, 50, 100, 250], 1)
    # Add a (0, 0) sample: zero texts should take zero time.
    xs = np.array([0] + [m.num_texts for m in measurements])
    ys = np.array([0] + [m.time for m in measurements])
    coeffs = np.polyfit(xs, ys, 2)
    plot_regression(xs, ys, coeffs)
    predictor = LinRegPredictor(coeffs)
    prediction_10k = datetime.timedelta(seconds=predictor.predict(10_000))
    prediction_100k = datetime.timedelta(seconds=predictor.predict(100_000))
    print(
        f"The greedy algorithm will take {_format_duration(prediction_10k)} for 10k texts"
    )
    print(
        f"The greedy algorithm will take {_format_duration(prediction_100k)} for 100k texts"
    )
    assert prediction_10k < datetime.timedelta(minutes=5)
    assert prediction_100k < datetime.timedelta(hours=6)
    # 10x more texts may cost at most 200x the time: quadratic scaling would be
    # 100x, and the extra factor of 2 is slack for measurement noise.
    assert prediction_100k < prediction_10k * 100 * 2
@dataclasses.dataclass
class ScatterData:
    """One labelled series of points, drawn by ``plot_algos`` via ``scatter``."""

    # np.ndarray is the array type; the previous annotation np.array is a factory function.
    xs: np.ndarray  # x coordinates
    ys: np.ndarray  # y coordinates, presumably the same length as xs
    label: str  # legend entry for this series
def plot_algos(tile, x_label, y_label, scatter_data_list):
    """Show a scatter plot with one labelled series per ScatterData entry.

    Args:
        tile: plot title. The name (presumably meant to be ``title``) is kept
            so existing keyword callers keep working.
        x_label: x-axis label.
        y_label: y-axis label.
        scatter_data_list: iterable of ScatterData; each becomes one legend entry.
    """
    # Draw on a fresh figure so points left on whatever figure happened to be
    # current are not overlaid on this comparison.
    fig, ax = plt.subplots()
    for series in scatter_data_list:
        ax.scatter(series.xs, series.ys, label=series.label)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(tile)
    ax.legend()
    ax.grid(True)
    plt.show()
    # Release the figure; with non-interactive backends show() does not close it.
    plt.close(fig)
def test_greedy_accuracy(greedy_calc, optimal_calc, approx_calc, dense_graph_gen):
    """Compare greedy, optimal and approximate set sizes on dense random graphs.

    All three are plotted. The test asserts that, at the largest graph size,
    greedy finds on average more than a quarter of the optimal set size.
    """
    greedy_mis_calc = MeasureIndependentSetCalc(greedy_calc, dense_graph_gen)
    optimal_mis_calc = MeasureIndependentSetCalc(optimal_calc, dense_graph_gen)
    # approx_calc was injected (and named in the plot title) but never
    # measured before; include it so the plot matches its title.
    approx_mis_calc = MeasureIndependentSetCalc(approx_calc, dense_graph_gen)
    xs = [0] + list(range(10, 41, 15))  # graph sizes 0, 10, 25, 40
    # NOTE(review): each calculator is averaged over its own random graphs, so
    # this compares averages, not the algorithms on identical graphs.
    greedy_measurements = greedy_mis_calc.generate_measurements(xs, 10)
    # The exact solver is expensive, so it (and approx) get fewer repetitions.
    optimal_measurements = optimal_mis_calc.generate_measurements(xs, 2)
    approx_measurements = approx_mis_calc.generate_measurements(xs, 2)
    optimal_ys = np.array([m.result_size for m in optimal_measurements])
    greedy_ys = np.array([m.result_size for m in greedy_measurements])
    approx_ys = np.array([m.result_size for m in approx_measurements])
    x_values = np.array(xs)
    plot_algos(
        "Comparison of greedy, optimal and approximate algorithms",
        "Number of texts",
        "Independent Set Size",
        [
            ScatterData(x_values, greedy_ys, "Greedy"),
            ScatterData(x_values, optimal_ys, "Optimal"),
            ScatterData(x_values, approx_ys, "Approximate"),
        ],
    )
    assert greedy_ys[-1] > optimal_ys[-1] * 0.25, (
        f"greedy size {greedy_ys[-1]} vs optimal size {optimal_ys[-1]} at n={xs[-1]}"
    )