"""
State functions for modeling discrete states over time.
A state function represents a resource that can be in different states
over time, with optional transition constraints between states.
Use cases:
- Machine modes (setup, processing, idle, maintenance)
- Room configurations (lecture, exam, meeting)
- Worker skills/roles
- Any resource with discrete, mutually exclusive states
Example:
>>> machine_state = StateFunction(name="machine")
>>> # Task requires machine in state 1
>>> satisfy(always_equal(machine_state, task, 1))
>>> # Define valid transitions with durations
>>> transitions = TransitionMatrix([
... [0, 5, 10], # From state 0: 0->0=0, 0->1=5, 0->2=10
... [5, 0, 3], # From state 1: 1->0=5, 1->1=0, 1->2=3
... [10, 3, 0], # From state 2: 2->0=10, 2->1=3, 2->2=0
... ])
>>> machine_state = StateFunction(name="machine", transitions=transitions)
"""
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum, auto
from typing import TYPE_CHECKING, Sequence

if TYPE_CHECKING:
    from pycsp3_scheduling.variables.interval import IntervalVar
# =============================================================================
# Transition Matrix
# =============================================================================
@dataclass
class TransitionMatrix:
    """
    Transition matrix defining valid state transitions and durations.

    ``matrix[i][j]`` is the time required to go from state ``i`` to state
    ``j``.  An entry equal to ``FORBIDDEN`` (-1 by default) marks the
    transition as not allowed.

    Attributes:
        matrix: Square 2D list of integer transition times.
        name: Optional name for the matrix.
        FORBIDDEN: Sentinel entry value meaning "transition not allowed".

    Raises:
        ValueError: If ``matrix`` is empty or not square.
        TypeError: If an entry is not an ``int``.

    Example:
        >>> # 3 states with symmetric transition times
        >>> tm = TransitionMatrix([
        ...     [0, 5, 10],
        ...     [5, 0, 3],
        ...     [10, 3, 0],
        ... ])
        >>> tm[0, 1]  # Time from state 0 to state 1
        5
    """

    matrix: list[list[int]]
    name: str | None = None
    # Unique identifier; -1 means "not assigned yet" and is replaced in
    # __post_init__ by the next value of a class-level counter.
    _id: int = field(default=-1, repr=False)
    # Special value indicating forbidden transition.  It is declared as an
    # ordinary annotated field, so the dataclass also accepts it in __init__.
    FORBIDDEN: int = -1

    def __post_init__(self) -> None:
        """Validate the matrix, then assign a unique ID if none was given."""
        self._validate()
        if self._id == -1:
            self._id = TransitionMatrix._get_next_id()

    @staticmethod
    def _check_value(value: object, i: int, j: int) -> None:
        """Raise TypeError unless ``value`` is an int; (i, j) is its position."""
        if not isinstance(value, int):
            raise TypeError(
                f"Transition matrix values must be integers, "
                f"got {type(value).__name__} at [{i}][{j}]"
            )

    def _validate(self) -> None:
        """
        Check that the matrix is non-empty, square and contains only ints.

        Raises:
            ValueError: If the matrix is empty or a row has the wrong length.
            TypeError: If an entry is not an ``int``.
        """
        if not self.matrix:
            raise ValueError("Transition matrix cannot be empty")
        n = len(self.matrix)
        for i, row in enumerate(self.matrix):
            if len(row) != n:
                raise ValueError(
                    f"Transition matrix must be square. "
                    f"Row {i} has {len(row)} elements, expected {n}"
                )
            for j, val in enumerate(row):
                self._check_value(val, i, j)

    @staticmethod
    def _get_next_id() -> int:
        """Return the current class-level counter value and increment it."""
        # The counter attribute is created lazily on first use.
        current = getattr(TransitionMatrix, "_id_counter", 0)
        TransitionMatrix._id_counter = current + 1
        return current

    @property
    def size(self) -> int:
        """Number of states (dimension of the matrix)."""
        return len(self.matrix)

    def __getitem__(self, key: tuple[int, int]) -> int:
        """Get transition time from state i to state j, given ``key == (i, j)``."""
        i, j = key
        return self.matrix[i][j]

    def __setitem__(self, key: tuple[int, int], value: int) -> None:
        """
        Set transition time from state i to state j, given ``key == (i, j)``.

        Raises:
            TypeError: If ``value`` is not an ``int``.  The same rule is
                enforced at construction time, so item assignment cannot
                put a value into the matrix that the constructor rejects.
        """
        i, j = key
        self._check_value(value, i, j)
        self.matrix[i][j] = value

    def is_forbidden(self, from_state: int, to_state: int) -> bool:
        """Check if transition from from_state to to_state is forbidden."""
        return self.matrix[from_state][to_state] == self.FORBIDDEN

    def get_row(self, state: int) -> list[int]:
        """
        Get all transition times from a given state.

        The returned list is the matrix's own row object, not a copy.
        """
        return self.matrix[state]

    def get_column(self, state: int) -> list[int]:
        """Get all transition times to a given state (as a new list)."""
        return [row[state] for row in self.matrix]

    def __repr__(self) -> str:
        """Short representation: optional name followed by the dimensions."""
        if self.name:
            return f"TransitionMatrix({self.name}, {self.size}x{self.size})"
        return f"TransitionMatrix({self.size}x{self.size})"
# =============================================================================
# State Function
# =============================================================================
@dataclass
class StateFunction:
    """
    State function representing a discrete state over time.

    A state function can be in different integer states at different times.
    Tasks can require specific states during their execution, and transitions
    between states can have associated times defined by a transition matrix.

    Attributes:
        name: Name of the state function.
        transitions: Optional transition matrix defining transition times.
        initial_state: Initial state at time 0 (default: no specific state).
        states: Set of valid state values.  When omitted and ``transitions``
            is given, it is set to ``{0, ..., transitions.size - 1}``.

    Example:
        >>> machine = StateFunction(name="machine_mode")
        >>> # Machine must be in state 2 during task execution
        >>> satisfy(always_equal(machine, task, 2))
    """

    name: str
    transitions: TransitionMatrix | None = None
    initial_state: int | None = None
    states: set[int] | None = None
    # Unique identifier; -1 means "not assigned yet".
    _id: int = field(default=-1, repr=False)

    def __post_init__(self) -> None:
        """Infer ``states``, assign a unique ID and register the instance."""
        # Derive fields first, so that a failure here neither consumes an ID
        # nor leaves a half-initialized instance in the module registry.
        if self.states is None and self.transitions is not None:
            self.states = set(range(self.transitions.size))
        if self._id == -1:
            self._id = StateFunction._get_next_id()
        # NOTE(review): the text this block was recovered from had lost its
        # indentation, so it is ambiguous whether registration was nested
        # under the `_id == -1` check.  It is done unconditionally here
        # (the registry helper skips entries already present) -- confirm.
        _register_state_function(self)

    @staticmethod
    def _get_next_id() -> int:
        """Return the current class-level counter value and increment it."""
        # The counter attribute is created lazily on first use.
        current = getattr(StateFunction, "_id_counter", 0)
        StateFunction._id_counter = current + 1
        return current

    @property
    def num_states(self) -> int | None:
        """Number of valid states, or None when it cannot be determined."""
        if self.states is not None:
            return len(self.states)
        if self.transitions is not None:
            return self.transitions.size
        return None

    def __hash__(self) -> int:
        """Hash based on the unique ID (keeps instances usable in sets/dicts)."""
        return hash(self._id)

    def __repr__(self) -> str:
        """Compact representation listing only the fields that are set."""
        parts = [f"StateFunction({self.name!r}"]
        # Explicit None test: presence of a matrix, not its truthiness,
        # decides whether it is shown.
        if self.transitions is not None:
            parts.append(f", transitions={self.transitions.size}x{self.transitions.size}")
        if self.initial_state is not None:
            parts.append(f", initial={self.initial_state}")
        parts.append(")")
        return "".join(parts)
# =============================================================================
# State Constraint Types
# =============================================================================
class StateConstraintType(Enum):
    """Kinds of constraint that can be placed on a state function."""

    # State in range [min, max]
    ALWAYS_IN = auto()
    # State equals specific value
    ALWAYS_EQUAL = auto()
    # State doesn't change during interval
    ALWAYS_CONSTANT = auto()
    # No state defined during interval
    ALWAYS_NO_STATE = auto()
@dataclass
class StateConstraint:
    """
    Constraint on a state function during an interval.

    Attributes:
        state_func: The state function being constrained.
        interval: The interval during which the constraint applies.
        constraint_type: Type of constraint.
        value: State value for ALWAYS_EQUAL.
        min_value: Minimum state for ALWAYS_IN.
        max_value: Maximum state for ALWAYS_IN.
        is_start_aligned: Whether constraint starts exactly at interval start.
        is_end_aligned: Whether constraint ends exactly at interval end.
    """

    state_func: StateFunction
    interval: IntervalVar
    constraint_type: StateConstraintType
    value: int | None = None
    min_value: int | None = None
    max_value: int | None = None
    is_start_aligned: bool = True
    is_end_aligned: bool = True

    def __repr__(self) -> str:
        """Render the constraint the way the matching helper call is written."""
        # Only a missing interval falls back to "?"; the interval's own
        # truthiness must not decide whether its name is printed.
        interval_name = self.interval.name if self.interval is not None else "?"
        if self.constraint_type == StateConstraintType.ALWAYS_EQUAL:
            return f"always_equal({self.state_func.name}, {interval_name}, {self.value})"
        if self.constraint_type == StateConstraintType.ALWAYS_IN:
            return f"always_in({self.state_func.name}, {interval_name}, {self.min_value}, {self.max_value})"
        if self.constraint_type == StateConstraintType.ALWAYS_CONSTANT:
            return f"always_constant({self.state_func.name}, {interval_name})"
        if self.constraint_type == StateConstraintType.ALWAYS_NO_STATE:
            return f"always_no_state({self.state_func.name}, {interval_name})"
        # Unknown constraint type: generic fallback.
        return f"StateConstraint({self.constraint_type})"
# =============================================================================
# State Constraint Functions
# =============================================================================
def always_equal(
    state_func: StateFunction,
    interval: IntervalVar,
    value: int,
    is_start_aligned: bool = True,
    is_end_aligned: bool = True,
) -> StateConstraint:
    """
    Build an ALWAYS_EQUAL constraint for a state function over an interval.

    The resulting constraint states that ``state_func`` must hold ``value``
    throughout the execution of ``interval``.

    Args:
        state_func: The state function.
        interval: The interval during which the constraint applies.
        value: The required state value.
        is_start_aligned: If True, state must equal value exactly at start.
        is_end_aligned: If True, state must equal value exactly at end.

    Returns:
        A StateConstraint of type ALWAYS_EQUAL carrying ``value`` and the
        two alignment flags.

    Raises:
        TypeError: If ``state_func`` is not a StateFunction, ``interval`` is
            not an IntervalVar, or ``value`` is not an int.

    Example:
        >>> machine = StateFunction(name="machine")
        >>> # Machine must be in state 2 during task
        >>> satisfy(always_equal(machine, task, 2))
    """
    # At module level IntervalVar is imported only under TYPE_CHECKING, so it
    # is imported here for the runtime isinstance check.
    from pycsp3_scheduling.variables.interval import IntervalVar

    if not isinstance(state_func, StateFunction):
        raise TypeError(f"state_func must be a StateFunction, got {type(state_func).__name__}")
    if not isinstance(interval, IntervalVar):
        raise TypeError(f"interval must be an IntervalVar, got {type(interval).__name__}")
    if not isinstance(value, int):
        raise TypeError(f"value must be an int, got {type(value).__name__}")

    constraint = StateConstraint(
        state_func=state_func,
        interval=interval,
        constraint_type=StateConstraintType.ALWAYS_EQUAL,
        value=value,
        is_start_aligned=is_start_aligned,
        is_end_aligned=is_end_aligned,
    )
    return constraint
def always_in(
    state_func: StateFunction,
    interval: IntervalVar,
    min_value: int,
    max_value: int,
    is_start_aligned: bool = True,
    is_end_aligned: bool = True,
) -> StateConstraint:
    """
    Build an ALWAYS_IN constraint for a state function over an interval.

    The resulting constraint states that ``state_func`` must stay within
    ``[min_value, max_value]`` throughout the execution of ``interval``.

    Args:
        state_func: The state function.
        interval: The interval during which the constraint applies.
        min_value: Minimum allowed state value.
        max_value: Maximum allowed state value.
        is_start_aligned: If True, constraint applies exactly at start.
        is_end_aligned: If True, constraint applies exactly at end.

    Returns:
        A StateConstraint of type ALWAYS_IN carrying both bounds and the
        two alignment flags.

    Raises:
        TypeError: If ``state_func`` is not a StateFunction, ``interval`` is
            not an IntervalVar, or either bound is not an int.
        ValueError: If ``min_value`` is greater than ``max_value``.

    Example:
        >>> machine = StateFunction(name="machine")
        >>> # Machine must be in state 1, 2, or 3 during task
        >>> satisfy(always_in(machine, task, 1, 3))
    """
    # At module level IntervalVar is imported only under TYPE_CHECKING, so it
    # is imported here for the runtime isinstance check.
    from pycsp3_scheduling.variables.interval import IntervalVar

    if not isinstance(state_func, StateFunction):
        raise TypeError(f"state_func must be a StateFunction, got {type(state_func).__name__}")
    if not isinstance(interval, IntervalVar):
        raise TypeError(f"interval must be an IntervalVar, got {type(interval).__name__}")

    bounds_are_ints = isinstance(min_value, int) and isinstance(max_value, int)
    if not bounds_are_ints:
        raise TypeError("min_value and max_value must be integers")
    if max_value < min_value:
        raise ValueError(f"min_value ({min_value}) cannot exceed max_value ({max_value})")

    constraint = StateConstraint(
        state_func=state_func,
        interval=interval,
        constraint_type=StateConstraintType.ALWAYS_IN,
        min_value=min_value,
        max_value=max_value,
        is_start_aligned=is_start_aligned,
        is_end_aligned=is_end_aligned,
    )
    return constraint
def always_constant(
    state_func: StateFunction,
    interval: IntervalVar,
    is_start_aligned: bool = True,
    is_end_aligned: bool = True,
) -> StateConstraint:
    """
    Build an ALWAYS_CONSTANT constraint for a state function over an interval.

    The resulting constraint states that ``state_func`` must not change its
    value throughout the execution of ``interval``.

    Args:
        state_func: The state function.
        interval: The interval during which the constraint applies.
        is_start_aligned: If True, constant region starts exactly at start.
        is_end_aligned: If True, constant region ends exactly at end.

    Returns:
        A StateConstraint of type ALWAYS_CONSTANT carrying the two
        alignment flags.

    Raises:
        TypeError: If ``state_func`` is not a StateFunction or ``interval``
            is not an IntervalVar.

    Example:
        >>> machine = StateFunction(name="machine")
        >>> # Machine state cannot change during task
        >>> satisfy(always_constant(machine, task))
    """
    # At module level IntervalVar is imported only under TYPE_CHECKING, so it
    # is imported here for the runtime isinstance check.
    from pycsp3_scheduling.variables.interval import IntervalVar

    if not isinstance(state_func, StateFunction):
        raise TypeError(f"state_func must be a StateFunction, got {type(state_func).__name__}")
    if not isinstance(interval, IntervalVar):
        raise TypeError(f"interval must be an IntervalVar, got {type(interval).__name__}")

    constraint = StateConstraint(
        state_func=state_func,
        interval=interval,
        constraint_type=StateConstraintType.ALWAYS_CONSTANT,
        is_start_aligned=is_start_aligned,
        is_end_aligned=is_end_aligned,
    )
    return constraint
def always_no_state(
    state_func: StateFunction,
    interval: IntervalVar,
    is_start_aligned: bool = True,
    is_end_aligned: bool = True,
) -> StateConstraint:
    """
    Build an ALWAYS_NO_STATE constraint for a state function over an interval.

    The resulting constraint states that ``state_func`` must not be in any
    state throughout the execution of ``interval`` (the resource is
    "unused").

    Args:
        state_func: The state function.
        interval: The interval during which the constraint applies.
        is_start_aligned: If True, no-state region starts exactly at start.
        is_end_aligned: If True, no-state region ends exactly at end.

    Returns:
        A StateConstraint of type ALWAYS_NO_STATE carrying the two
        alignment flags.

    Raises:
        TypeError: If ``state_func`` is not a StateFunction or ``interval``
            is not an IntervalVar.

    Example:
        >>> machine = StateFunction(name="machine")
        >>> # Machine must be unused during maintenance
        >>> satisfy(always_no_state(machine, maintenance_interval))
    """
    # At module level IntervalVar is imported only under TYPE_CHECKING, so it
    # is imported here for the runtime isinstance check.
    from pycsp3_scheduling.variables.interval import IntervalVar

    if not isinstance(state_func, StateFunction):
        raise TypeError(f"state_func must be a StateFunction, got {type(state_func).__name__}")
    if not isinstance(interval, IntervalVar):
        raise TypeError(f"interval must be an IntervalVar, got {type(interval).__name__}")

    constraint = StateConstraint(
        state_func=state_func,
        interval=interval,
        constraint_type=StateConstraintType.ALWAYS_NO_STATE,
        is_start_aligned=is_start_aligned,
        is_end_aligned=is_end_aligned,
    )
    return constraint
# =============================================================================
# Registry for State Functions
# =============================================================================
# Module-level list of every state function registered so far, in
# registration order.
_state_function_registry: list[StateFunction] = []


def _register_state_function(sf: StateFunction) -> None:
    """Append ``sf`` to the registry unless an equal entry is already there."""
    if sf in _state_function_registry:
        return
    _state_function_registry.append(sf)


def get_registered_state_functions() -> list[StateFunction]:
    """Return a new list holding the currently registered state functions."""
    # A copy is returned so callers cannot mutate the registry itself.
    return _state_function_registry.copy()


def clear_state_function_registry() -> None:
    """Empty the registry and reset both class-level ID counters to 0."""
    # Emptied in place so the module-level list object stays the same.
    del _state_function_registry[:]
    StateFunction._id_counter = 0
    TransitionMatrix._id_counter = 0
# =============================================================================
# Convenience State Helpers
# =============================================================================
def requires_state(
    interval: IntervalVar,
    state_func: StateFunction,
    required_state: int,
) -> StateConstraint:
    """
    Require ``state_func`` to be in ``required_state`` during ``interval``.

    Convenience wrapper around ``always_equal`` that takes the interval
    first; the alignment flags are left at ``always_equal``'s defaults.

    Args:
        interval: The interval requiring the state.
        state_func: The state function (resource).
        required_state: The required state value.

    Returns:
        The StateConstraint produced by ``always_equal``.

    Raises:
        TypeError: Propagated from ``always_equal`` when an argument has the
            wrong type.

    Example:
        >>> oven = StateFunction(name="oven_temp")
        >>> bake_task = IntervalVar(size=30, name="bake")
        >>> # Baking requires oven at temperature state 2 (e.g., 350F)
        >>> satisfy(requires_state(bake_task, oven, 2))
    """
    return always_equal(
        state_func=state_func,
        interval=interval,
        value=required_state,
    )
def sets_state(
    interval: IntervalVar,
    state_func: StateFunction,
    before_state: int | None,
    after_state: int,
) -> list[StateConstraint]:
    """
    Describe an interval that moves ``state_func`` from one state to another.

    The result always contains one ALWAYS_EQUAL constraint on ``after_state``
    that is end-aligned only.  When ``before_state`` is not None it is
    preceded by one ALWAYS_EQUAL constraint on ``before_state`` that is
    start-aligned only.

    Args:
        interval: The interval performing the state change.
        state_func: The state function.
        before_state: Required state before interval (None = any state).
        after_state: State after interval completes.

    Returns:
        List of one or two StateConstraints, the ``before_state`` one first.

    Raises:
        TypeError: If ``state_func`` is not a StateFunction, ``interval`` is
            not an IntervalVar, ``after_state`` is not an int, or
            ``before_state`` is neither an int nor None.

    Example:
        >>> machine_mode = StateFunction(name="machine_mode")
        >>> changeover = IntervalVar(size=15, name="changeover_A_to_B")
        >>> # This changeover task transitions machine from mode A (0) to mode B (1)
        >>> satisfy(sets_state(changeover, machine_mode, before_state=0, after_state=1))
    """
    # At module level IntervalVar is imported only under TYPE_CHECKING, so it
    # is imported here for the runtime isinstance check.
    from pycsp3_scheduling.variables.interval import IntervalVar

    if not isinstance(state_func, StateFunction):
        raise TypeError(f"state_func must be a StateFunction, got {type(state_func).__name__}")
    if not isinstance(interval, IntervalVar):
        raise TypeError(f"interval must be an IntervalVar, got {type(interval).__name__}")
    if not isinstance(after_state, int):
        raise TypeError(f"after_state must be an int, got {type(after_state).__name__}")
    if not (before_state is None or isinstance(before_state, int)):
        raise TypeError(f"before_state must be an int or None, got {type(before_state).__name__}")

    def _edge_constraint(state: int, *, at_start: bool) -> StateConstraint:
        # Equality constraint aligned with exactly one end of the interval.
        return StateConstraint(
            state_func=state_func,
            interval=interval,
            constraint_type=StateConstraintType.ALWAYS_EQUAL,
            value=state,
            is_start_aligned=at_start,
            is_end_aligned=not at_start,
        )

    # Entry requirement is optional: None means any starting state is fine.
    if before_state is None:
        entry = []
    else:
        entry = [_edge_constraint(before_state, at_start=True)]
    # After interval, state becomes after_state.
    return entry + [_edge_constraint(after_state, at_start=False)]