import dataclasses
import datetime
import math
import random
import time
import numpy as np
import pytest
from matplotlib import pyplot as plt
from util.lin_reg_plot_helper import LinRegPredictor, plot_regression
from util.text_similarity.max_independent_set_calc import (
ApproximateIndependentSetCalc,
GreedyIndependentSetCalc,
MaxIndependentSetCalc,
OptimalIndependentSetCalc,
)
def PAIRS_LIN_LOG_FUNC(x):
    """Return ``x * log2(x)``, the n-log-n growth curve (``x`` must be > 0)."""
    return x * math.log(x, 2)
class CalcNumPairs:
    """Strategy interface deciding how many pairs a random graph receives.

    Concrete subclasses override :meth:`get_num_pairs`.
    """

    def get_num_pairs(self, num_texts) -> int:
        """Return how many distinct pairs to generate for ``num_texts`` nodes.

        Raises:
            NotImplementedError: always; concrete strategies must override it.
        """
        # Previously this silently returned None, which only surfaced later as
        # an obscure TypeError when compared against an int pair count.
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_num_pairs()"
        )
class StaticCalcNumPairs(CalcNumPairs):
    """Pair-count strategy that yields the same count for every graph size."""

    def __init__(self, num_pairs):
        """Remember ``num_pairs``; it is handed back unchanged (no validation)."""
        self.num_pairs = num_pairs

    def get_num_pairs(self, num_texts):
        """Return the configured count; ``num_texts`` is intentionally ignored."""
        configured = self.num_pairs
        return configured
class LogCalcNumPairs(CalcNumPairs):
    """Pair-count strategy growing as ``n * log2(n)`` (sparse graphs)."""

    def get_num_pairs(self, num_texts):
        """Return ``ceil(num_texts * log2(num_texts))`` as an ``int``.

        The base contract is ``-> int``; the raw product is a float. Callers
        compare an integer count against this value with ``<``, so rounding up
        yields exactly the same number of generated pairs as the float did.

        Raises:
            ValueError: if ``num_texts`` is negative.
        """
        if num_texts < 0:
            raise ValueError(f"num_texts must be non-negative, got {num_texts}")
        # log2(1) == 0, and log(0) is undefined: graphs this small get no pairs.
        if num_texts <= 1:
            return 0
        return math.ceil(num_texts * math.log(num_texts, 2))
class LinSquareRootCalcNumPairs(CalcNumPairs):
    """Pair-count strategy growing as ``n * sqrt(n)`` (denser graphs)."""

    def get_num_pairs(self, num_texts):
        """Return ``num_texts * sqrt(num_texts)`` truncated toward zero."""
        root = math.sqrt(num_texts)
        return int(root * num_texts)
@dataclasses.dataclass
class RandomGraphGenerator:
    """Generate random lists of "similar" node pairs for benchmarking.

    How many pairs a graph gets is delegated to a ``CalcNumPairs`` strategy
    (``LogCalcNumPairs`` unless another one is supplied).
    """

    # default_factory gives each generator its own strategy object instead of
    # one instance shared by every generator.
    _calc_num_pairs: CalcNumPairs = dataclasses.field(default_factory=LogCalcNumPairs)

    def num_pairs(self, num_texts) -> int:
        """Return how many distinct pairs to generate for ``num_texts`` nodes."""
        return self._calc_num_pairs.get_num_pairs(num_texts)

    def _random_node(self, num_texts):
        """Return a uniformly random node index in ``[0, num_texts - 1]``."""
        return random.randint(0, num_texts - 1)

    # noinspection PyTypeChecker
    def _create_random_pair(self, num_texts) -> tuple[int, int]:
        """Return a random ``(low, high)`` pair; both ends may be the same node."""
        return tuple(
            sorted([self._random_node(num_texts), self._random_node(num_texts)])
        )

    def create_random_pairs(self, num_texts: int):
        """Return a list of ``num_pairs(num_texts)`` distinct random pairs.

        Raises:
            ValueError: if the strategy requests more pairs than can exist,
                which previously made the sampling loop spin forever.
        """
        # Computed once rather than on every loop iteration.
        target = self.num_pairs(num_texts)
        # Pairs are sorted and may repeat a node, so n*(n+1)/2 distinct exist.
        max_pairs = num_texts * (num_texts + 1) // 2
        if target > max_pairs:
            raise ValueError(
                f"{type(self._calc_num_pairs).__name__} requested {target} pairs "
                f"for {num_texts} texts, but only {max_pairs} distinct pairs exist"
            )
        random_pairs = set()
        while len(random_pairs) < target:
            random_pairs.add(self._create_random_pair(num_texts))
        return list(random_pairs)
@dataclasses.dataclass
class MeasurementResult:
    """One benchmark sample, also usable as a running total or an average.

    Supports ``+`` (field-wise sum, including ``sum()`` with its default
    start value ``0``) and ``/`` by a scalar (field-wise division, which
    turns the integer fields into floats).
    """

    num_texts: int  # graph size (number of texts / nodes)
    time: float  # elapsed seconds, as measured by the caller
    result_size: int  # size of the independent set that was found

    @staticmethod
    def zero():
        """Return the additive identity ``MeasurementResult(0, 0, 0)``."""
        return MeasurementResult(0, 0, 0)

    def __add__(self, other):
        """Return the field-wise sum; defer to Python for foreign types."""
        if not isinstance(other, MeasurementResult):
            return NotImplemented
        return MeasurementResult(
            self.num_texts + other.num_texts,
            self.time + other.time,
            self.result_size + other.result_size,
        )

    def __radd__(self, other):
        """Support ``sum(results)``, whose implicit start value is ``0``."""
        # Previously ``0 + result`` forwarded the int to __add__ and crashed
        # with AttributeError on ``other.num_texts``.
        if isinstance(other, (int, float)) and other == 0:
            return self + MeasurementResult.zero()
        return self.__add__(other)

    def __truediv__(self, other):
        """Divide every field by the scalar ``other`` (e.g. to average a total)."""
        return MeasurementResult(
            self.num_texts / other, self.time / other, self.result_size / other
        )
@dataclasses.dataclass
class MeasureIndependentSetCalc:
    """Benchmark a ``MaxIndependentSetCalc`` on randomly generated graphs."""

    calc: MaxIndependentSetCalc
    graph_generator: RandomGraphGenerator

    def analyze_calc_run(self, num_texts) -> MeasurementResult:
        """Time one ``find_max_set`` call on a freshly generated random graph.

        Graph generation happens before the timer starts and is not counted.
        """
        pairs = self.graph_generator.create_random_pairs(num_texts)
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # with system clock changes and is coarse on some platforms.
        start = time.perf_counter()
        set_of_nodes = self.calc.find_max_set(num_texts, pairs)
        elapsed = time.perf_counter() - start
        return MeasurementResult(num_texts, elapsed, len(set_of_nodes))

    def analyze_calc_runs(self, num_texts, iterations) -> MeasurementResult:
        """Return the field-wise mean of ``iterations`` independent runs.

        Raises:
            ValueError: if ``iterations`` is less than 1 (the mean is undefined).
        """
        if iterations < 1:
            raise ValueError(f"iterations must be >= 1, got {iterations}")
        total = sum(
            (self.analyze_calc_run(num_texts) for _ in range(iterations)),
            MeasurementResult.zero(),
        )
        return total / iterations

    def generate_measurements(self, list_num_texts, iterations):
        """Return one averaged ``MeasurementResult`` per entry of ``list_num_texts``."""
        return [
            self.analyze_calc_runs(num_texts, iterations)
            for num_texts in list_num_texts
        ]
@pytest.fixture
def greedy_calc():
    """Provide a fresh greedy independent-set calculator for each test."""
    calculator = GreedyIndependentSetCalc()
    return calculator
@pytest.fixture
def approx_calc():
    """Provide a fresh approximate independent-set calculator for each test."""
    calculator = ApproximateIndependentSetCalc()
    return calculator
@pytest.fixture
def optimal_calc():
    """Provide a fresh optimal independent-set calculator for each test."""
    calculator = OptimalIndependentSetCalc()
    return calculator
@pytest.fixture
def random_graph_gen():
    """Provide a graph generator using its default pair-count strategy."""
    generator = RandomGraphGenerator()
    return generator
@pytest.fixture
def dense_graph_gen():
    """Provide a generator producing denser graphs (``n * sqrt(n)`` pairs)."""
    strategy = LinSquareRootCalcNumPairs()
    return RandomGraphGenerator(_calc_num_pairs=strategy)
@pytest.mark.parametrize(
    "num_texts, similar_pairs, acceptable_solutions",
    [
        # Case 1: No edges (all nodes are independent)
        (5, [], [{0, 1, 2, 3, 4}]),
        # Case 2: Fully connected graph (any one node can be in the set)
        (4, [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)], [{0}, {1}, {2}, {3}]),
        # Case 3: Single vertex
        (1, [], [{0}]),
        # Case 4: Pair of vertices with one edge
        (2, [(0, 1)], [{0}, {1}]),
        # Case 5: Star centred on node 0; its four leaves are the only maximum set
        (5, [(0, 1), (0, 2), (0, 3), (0, 4)], [{1, 2, 3, 4}]),
    ],
)
def test_find_max_set(num_texts, similar_pairs, acceptable_solutions):
    """The greedy solver must return one of the listed maximum independent sets."""
    calc = GreedyIndependentSetCalc()
    result = calc.find_max_set(num_texts, similar_pairs)
    # ``in`` compares with ``==``: the result must EQUAL one acceptable
    # solution; a smaller subset (a non-maximum set) fails the test.
    assert result in acceptable_solutions
def _format_duration(duration: datetime.timedelta) -> str:
    """Return ``duration`` rounded to whole seconds, e.g. ``'1 day, 2:03:04'``.

    Slicing ``str(timedelta)`` to 8 characters left a trailing ``.`` on
    fractional values and cut off durations of a day or more.
    """
    return str(datetime.timedelta(seconds=round(duration.total_seconds())))


def test_time_complexity(greedy_calc, dense_graph_gen):
    """Extrapolate greedy runtime with a quadratic fit and bound the forecasts.

    Times one run per size on dense random graphs of 10-250 texts, fits a
    degree-2 polynomial, and asserts that the predicted runtimes for 10k and
    100k texts stay within fixed budgets.
    """
    mis_calc = MeasureIndependentSetCalc(greedy_calc, dense_graph_gen)
    measurements = mis_calc.generate_measurements([10, 50, 100, 250], 1)
    # Prepend a (0 texts, 0 s) sample to pull the fit toward the origin.
    xs = np.array([0] + [m.num_texts for m in measurements])
    ys = np.array([0] + [m.time for m in measurements])
    coeffs = np.polyfit(xs, ys, 2)
    plot_regression(xs, ys, coeffs)
    predictor = LinRegPredictor(coeffs)
    prediction_10k = datetime.timedelta(seconds=predictor.predict(10_000))
    prediction_100k = datetime.timedelta(seconds=predictor.predict(100_000))
    print(
        f"The greedy algorithm will take {_format_duration(prediction_10k)} "
        "for 10k texts"
    )
    print(
        f"The greedy algorithm will take {_format_duration(prediction_100k)} "
        "for 100k texts"
    )
    assert prediction_10k < datetime.timedelta(minutes=5)
    assert prediction_100k < datetime.timedelta(hours=6)
    # 10x the texts under quadratic growth is ~100x the time; allow 2x slack.
    assert prediction_100k < prediction_10k * 100 * 2
@dataclasses.dataclass
class ScatterData:
    """One labelled data series to draw in a scatter plot.

    ``xs`` and ``ys`` accept any array-like (lists, tuples, arrays) and are
    normalised to ``numpy.ndarray`` on construction.
    """

    xs: np.ndarray
    ys: np.ndarray
    label: str

    def __post_init__(self):
        # The fields were annotated ``np.array`` (a factory function, not a
        # type) while callers pass plain lists; converting here makes the
        # ``np.ndarray`` annotation truthful.
        self.xs = np.asarray(self.xs)
        self.ys = np.asarray(self.ys)
def plot_algos(tile, x_label, y_label, scatter_data_list):
    """Scatter-plot several labelled series on a new figure and show it.

    Args:
        tile: plot title (parameter name kept for keyword-argument callers).
        x_label: x-axis label.
        y_label: y-axis label.
        scatter_data_list: iterable of objects exposing ``xs``, ``ys`` and
            ``label``, such as ``ScatterData``.
    """
    # Start a fresh figure: with a non-interactive backend plt.show() leaves
    # the current figure open, so earlier plots' data would otherwise leak in.
    plt.figure()
    for scatter_data in scatter_data_list:
        plt.scatter(scatter_data.xs, scatter_data.ys, label=scatter_data.label)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(tile)
    plt.legend()
    plt.grid(True)
    plt.show()
def test_greedy_accuracy(greedy_calc, optimal_calc, approx_calc, dense_graph_gen):
    """The greedy set size must exceed a quarter of the optimum on large graphs.

    Averages independent-set sizes on dense random graphs for several sizes,
    plots greedy vs. optimal, and asserts on the largest size only.
    """
    greedy_mis_calc = MeasureIndependentSetCalc(greedy_calc, dense_graph_gen)
    optimal_mis_calc = MeasureIndependentSetCalc(optimal_calc, dense_graph_gen)
    xs = [0] + list(range(10, 41, 15))  # [0, 10, 25, 40]
    # Fewer runs for the optimal solver, presumably because it is the slow one.
    # NOTE(review): every run draws its own random graph, so the solvers are
    # compared on averages over different graphs, not graph-by-graph.
    greedy_measurements = greedy_mis_calc.generate_measurements(xs, 10)
    optimal_measurements = optimal_mis_calc.generate_measurements(xs, 2)
    optimal_ys = np.array([m.result_size for m in optimal_measurements])
    greedy_ys = np.array([m.result_size for m in greedy_measurements])
    # TODO(review): approx_calc is requested but never measured; add it to the
    # plot (and the title) once its runtime on these sizes is confirmed.
    plot_algos(
        "Comparison of greedy and optimal algorithms",
        "Number of texts",
        "Independent Set Size",
        [
            ScatterData(xs, greedy_ys, "Greedy"),
            ScatterData(xs, optimal_ys, "Optimal"),
        ],
    )
    assert greedy_ys[-1] > optimal_ys[-1] * 0.25