from enum import Enum
from typing import Generic, TypeVar, Dict, List
from src.random_collections.random_collection import RandomCollection, RandomCollectionBuilder
from src.random_collections.random_collection_interface import IRandom
K = TypeVar('K', bound=Enum)
V = TypeVar('V', bound=Enum)
class RandomTable(Generic[K, V], IRandom[V]):
    """
    Keyed lookup table of RandomCollections for weighted random selection.

    Every key owns its own RandomCollection; sampling for a key delegates
    to that key's collection.

    Attributes:
        value_weight_dict (Dict[K, RandomCollection[V]]):
            The collection to draw from for each key.
    """

    def __init__(self, value_weight_dict: Dict[K, RandomCollection[V]]):
        """
        Store the per-key collections that back this table.

        Args:
            value_weight_dict (Dict[K, RandomCollection[V]]):
                Already-built collection for each key. The mapping is kept
                by reference, not copied.
        """
        self.value_weight_dict: Dict[K, RandomCollection[V]] = value_weight_dict

    def get_random_value(self, key: K) -> V:
        """
        Draw a random value from the collection registered under `key`.

        Args:
            key (K): Key whose collection should be sampled.

        Returns:
            V: The result of calling ``get_random_value()`` on that
            key's collection.

        Raises:
            KeyError: If `key` is not in the table.
        """
        # Plain indexing (not .get) so an unknown key surfaces as KeyError.
        return self.value_weight_dict[key].get_random_value()
class RandomTableBuilder:
    """
    Factory for constructing RandomTable instances from weight definitions.

    All methods are static; the class only serves as a namespace.
    """

    @staticmethod
    def build_from_weight_table(
        key_enum: type[K],
        value_enum: type[V],
        weights: List[List[float]],
    ) -> RandomTable[K, V]:
        """
        Build a RandomTable given parallel enums and a weight matrix.

        Row ``i`` of `weights` belongs to the ``i``-th member of `key_enum`
        and column ``j`` to the ``j``-th member of `value_enum`, both in
        enum definition order (enum aliases are not iterated).

        Args:
            key_enum (type[K]): Enum class for table keys.
            value_enum (type[V]): Enum class for table values.
            weights (List[List[float]]): One weight row per key, each row
                holding one weight per value.

        Returns:
            RandomTable[K, V]: Table sampling values by key.

        Raises:
            ValueError: If the number of rows differs from the number of
                keys, or any row's length differs from the number of values.
        """
        keys = list(key_enum)
        # Hoisted: the value members are the same for every row.
        values = list(value_enum)

        # zip() stops at the shorter input, so a wrongly shaped matrix would
        # silently produce a table with missing keys/values. Fail loudly.
        if len(weights) != len(keys):
            raise ValueError(
                f"weights has {len(weights)} rows but "
                f"{key_enum.__name__} has {len(keys)} members"
            )

        collections = {}
        for key, weight_row in zip(keys, weights):
            if len(weight_row) != len(values):
                raise ValueError(
                    f"weight row for {key!r} has {len(weight_row)} entries but "
                    f"{value_enum.__name__} has {len(values)} members"
                )
            # Use the imported RandomCollectionBuilder, the same entry point
            # build_from_dict relies on.
            collections[key] = RandomCollectionBuilder.build_from_value_weight_dict(
                dict(zip(values, weight_row))
            )
        return RandomTable(collections)

    @staticmethod
    def validate_value_weight_dict(
        key_enum: type[K],
        value_enum: type[V],
        value_weight_dict: Dict[K, Dict[V, float]],
    ) -> None:
        """
        Ensure provided dict covers all enum members exactly.

        Args:
            key_enum (type[K]): Enum of expected keys.
            value_enum (type[V]): Enum of expected values.
            value_weight_dict (Dict[K, Dict[V, float]]):
                Mapping from keys to value-weight dicts.

        Raises:
            AssertionError: If keys or value sets don't match enums.
        """
        # Raised explicitly instead of via `assert` so the validation still
        # runs when Python is started with -O (which strips assert statements).
        if set(key_enum) != set(value_weight_dict.keys()):
            raise AssertionError("Keys in value_weight_dict must match key_enum")

        expected_values = set(value_enum)
        for key in key_enum:
            # Safe to index: the key sets were just verified to be equal.
            if expected_values != set(value_weight_dict[key].keys()):
                raise AssertionError(
                    "Values in value_weight_dict must match value_enum for each key"
                )

    @staticmethod
    def build_from_dict(
        key_enum: type[K],
        value_enum: type[V],
        value_weight_dict: Dict[K, Dict[V, float]],
    ) -> RandomTable[K, V]:
        """
        Construct a RandomTable from nested dict of weights.

        The input is validated first, so every member of `key_enum` must be
        present and every inner dict must cover `value_enum` exactly.

        Args:
            key_enum (type[K]): Enum class for keys.
            value_enum (type[V]): Enum class for values.
            value_weight_dict (Dict[K, Dict[V, float]]):
                Outer key -> inner value -> weight dict.

        Returns:
            RandomTable[K, V]: Table ready for random sampling.

        Raises:
            AssertionError: If keys or value sets don't match the enums.
        """
        RandomTableBuilder.validate_value_weight_dict(
            key_enum, value_enum, value_weight_dict
        )
        return RandomTable(
            {
                key: RandomCollectionBuilder.build_from_value_weight_dict(
                    value_weight_dict[key]
                )
                for key in key_enum
            }
        )