Source code for random_collections.random_collection_table

from enum import Enum
from typing import Generic, TypeVar, Dict, List

from src.random_collections.random_collection import RandomCollection, RandomCollectionBuilder
from src.random_collections.random_collection_interface import IRandom

K = TypeVar('K', bound=Enum)
V = TypeVar('V', bound=Enum)


class RandomTable(Generic[K, V], IRandom[V]):
    """
    Weighted random lookup table: every key owns a RandomCollection of values.

    Attributes:
        value_weight_dict (Dict[K, RandomCollection[V]]): The collection that
            values are drawn from, one per key.
    """

    def __init__(self, value_weight_dict: Dict[K, RandomCollection[V]]):
        """
        Store the key -> collection mapping.

        Args:
            value_weight_dict (Dict[K, RandomCollection[V]]): Prebuilt
                collection for each key. The mapping is stored as given
                (by reference, not copied).
        """
        self.value_weight_dict: Dict[K, RandomCollection[V]] = value_weight_dict

    def get_random_value(self, key: K) -> V:
        """
        Draw a random value from the collection registered under ``key``.

        Args:
            key (K): Table key whose collection should be sampled.

        Returns:
            V: The value returned by that collection's ``get_random_value()``.

        Raises:
            KeyError: If ``key`` has no collection in the table.
        """
        return self.value_weight_dict[key].get_random_value()
class RandomTableBuilder:
    """
    Factory for constructing RandomTable instances from weight definitions.
    """

    @staticmethod
    def build_from_weight_table(
        key_enum: type[Enum], value_enum: type[Enum], weights: List[List[float]]
    ) -> RandomTable[Enum, Enum]:
        """
        Build a RandomTable given parallel enums and a weight matrix.

        Row ``i`` of ``weights`` holds the weights for the ``i``-th member of
        ``key_enum``; entry ``j`` of each row is the weight of the ``j``-th
        member of ``value_enum`` (both in enum iteration order).

        Args:
            key_enum (type[Enum]): Enum class for table keys.
            value_enum (type[Enum]): Enum class for table values.
            weights (List[List[float]]): One weight row per key, each row
                holding one weight per value.

        Returns:
            RandomTable[Enum, Enum]: Table sampling values by key.

        Raises:
            ValueError: If the number of rows differs from the number of
                members of ``key_enum``, or any row's length differs from the
                number of members of ``value_enum``.
        """
        keys = list(key_enum)
        values = list(value_enum)

        # zip() silently truncates mismatched inputs, which would yield a table
        # with missing keys (or collections missing values). Fail loudly instead.
        if len(weights) != len(keys):
            raise ValueError(
                f"weights has {len(weights)} rows but {key_enum.__name__} "
                f"has {len(keys)} members"
            )
        for key, weight_row in zip(keys, weights):
            if len(weight_row) != len(values):
                raise ValueError(
                    f"weights row for {key!r} has {len(weight_row)} entries but "
                    f"{value_enum.__name__} has {len(values)} members"
                )

        # Use the directly imported RandomCollectionBuilder, consistent with
        # build_from_dict below.
        return RandomTable(
            {
                key: RandomCollectionBuilder.build_from_value_weight_dict(
                    dict(zip(values, weight_row))
                )
                for key, weight_row in zip(keys, weights)
            }
        )

    @staticmethod
    def validate_value_weight_dict(
        key_enum: type[K],
        value_enum: type[V],
        value_weight_dict: Dict[K, Dict[V, float]],
    ) -> None:
        """
        Ensure the provided dict covers all enum members exactly.

        Args:
            key_enum (type[K]): Enum of expected keys.
            value_enum (type[V]): Enum of expected values.
            value_weight_dict (Dict[K, Dict[V, float]]): Mapping from keys to
                value-weight dicts.

        Raises:
            AssertionError: If the outer keys are not exactly the members of
                ``key_enum``, or any inner dict's keys are not exactly the
                members of ``value_enum``.
        """
        # Raise explicitly rather than using ``assert``: assert statements are
        # removed under ``python -O``, which would silently disable validation.
        # AssertionError is kept so existing callers that catch it still work.
        if set(key_enum) != set(value_weight_dict.keys()):
            raise AssertionError("Keys in value_weight_dict must match key_enum")
        expected_values = set(value_enum)
        for key in key_enum:
            if expected_values != set(value_weight_dict[key].keys()):
                raise AssertionError(
                    "Values in value_weight_dict must match value_enum for each key"
                )

    @staticmethod
    def build_from_dict(
        key_enum: type[K],
        value_enum: type[V],
        value_weight_dict: Dict[K, Dict[V, float]],
    ) -> RandomTable[K, V]:
        """
        Construct a RandomTable from a nested dict of weights.

        Args:
            key_enum (type[K]): Enum class for keys.
            value_enum (type[V]): Enum class for values.
            value_weight_dict (Dict[K, Dict[V, float]]): Outer key -> inner
                value -> weight dict.

        Returns:
            RandomTable[K, V]: Table ready for random sampling.

        Raises:
            AssertionError: If ``value_weight_dict`` does not cover exactly the
                members of ``key_enum`` / ``value_enum``
                (see ``validate_value_weight_dict``).
        """
        RandomTableBuilder.validate_value_weight_dict(
            key_enum, value_enum, value_weight_dict
        )
        # Iterate key_enum (not the dict) so the table's keys follow enum
        # iteration order; validation above guarantees the key sets coincide.
        return RandomTable(
            {
                key: RandomCollectionBuilder.build_from_value_weight_dict(
                    value_weight_dict[key]
                )
                for key in key_enum
            }
        )