# Source code for pycsp3_scheduling.expressions.interval_expr

"""
Interval expression functions for scheduling models.

These functions return expression objects that can be used in constraints
and objectives. They extract properties from interval variables.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, Union

if TYPE_CHECKING:
    from pycsp3_scheduling.variables.interval import IntervalVar


class ExprType(Enum):
    """
    Node kinds of an interval expression tree.

    Members are declared in three groups, in this order:

    * accessors that read a property of interval variables,
    * arithmetic combinators over child expressions,
    * comparisons, which turn two expressions into a constraint.

    Values are assigned by ``auto()`` in declaration order (1, 2, ...).
    """

    # --- Accessors on interval variables -----------------------------------
    START_OF = auto()  # start time of an interval
    END_OF = auto()  # end time of an interval
    SIZE_OF = auto()  # size (duration) of an interval
    LENGTH_OF = auto()  # length; may differ from size with intensity functions
    PRESENCE_OF = auto()  # 0/1 presence flag of an interval
    OVERLAP_LENGTH = auto()  # overlap length between two intervals

    # --- Arithmetic combinations -------------------------------------------
    ADD = auto()  # a + b
    SUB = auto()  # a - b
    MUL = auto()  # a * b
    DIV = auto()  # a / b
    NEG = auto()  # -a
    ABS = auto()  # abs(a)
    MIN = auto()  # min(a, b, ...)
    MAX = auto()  # max(a, b, ...)

    # --- Comparisons (used to state constraints) ---------------------------
    EQ = auto()  # a == b
    NE = auto()  # a != b
    LT = auto()  # a < b
    LE = auto()  # a <= b
    GT = auto()  # a > b
    GE = auto()  # a >= b


@dataclass
class IntervalExpr:
    """
    Base class for interval-related expressions.

    These expressions represent values derived from interval variables that
    can be used in constraints and objectives. They form a tree:

    * a leaf either reads a property of ``interval`` or holds an integer
      literal in ``value``;
    * an inner node combines its ``operands`` according to ``expr_type``.

    Python operators build new nodes instead of evaluating anything. This
    covers arithmetic, ordering comparisons, and also ``==`` and ``!=``.
    For example, ``start_of(a) + 5 <= end_of(b)`` yields an ``LE`` node.
    Integer operands are accepted on either side (``5 - expr``,
    ``10 / expr``).

    ``__eq__``, ``__hash__`` and ``__repr__`` are defined explicitly below.
    ``@dataclass`` does not overwrite methods already defined in the class
    body, so these definitions take precedence over generated ones.

    Attributes:
        expr_type: The type of expression.
        interval: The interval variable (if applicable).
        absent_value: Value to use when interval is absent.
        operands: Child expressions for compound expressions.
        value: Constant value (for literals).
    """

    expr_type: ExprType
    interval: IntervalVar | None = None
    absent_value: int = 0
    operands: list[IntervalExpr] = field(default_factory=list)
    value: int | None = None
    # Unique per-instance identifier; -1 means "assign one in __post_init__".
    _id: int = field(default=-1, repr=False)

    def __post_init__(self) -> None:
        """Assign a unique ID unless one was passed explicitly."""
        if self._id == -1:
            self._id = IntervalExpr._get_next_id()

    @staticmethod
    def _get_next_id() -> int:
        """Return the next value of the class-wide ID counter (starts at 0)."""
        current = getattr(IntervalExpr, "_id_counter", 0)
        IntervalExpr._id_counter = current + 1
        return current

    def _combine(
        self,
        expr_type: ExprType,
        other: Union[IntervalExpr, int],
        *,
        reflected: bool = False,
    ) -> IntervalExpr:
        """
        Build the binary node ``self <op> other``.

        Args:
            expr_type: Operator of the new node.
            other: The other operand. Ints are wrapped as literals; any other
                type raises ``TypeError`` (via ``_to_expr``).
            reflected: If True, build ``other <op> self`` instead. The
                ``__r*__`` methods of non-commutative operators use this.

        Returns:
            The new expression node.
        """
        # The operand literal is created before the parent node, matching
        # the ID allocation order of the hand-written operators.
        other_expr = _to_expr(other)
        operands = [other_expr, self] if reflected else [self, other_expr]
        return IntervalExpr(expr_type=expr_type, operands=operands)

    # -- Arithmetic operators ------------------------------------------------

    def __add__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Add two expressions or expression and constant."""
        return self._combine(ExprType.ADD, other)

    def __radd__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Right addition (``other + self``), stored as ``self + other``."""
        return self._combine(ExprType.ADD, other)

    def __sub__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Subtract two expressions or expression and constant."""
        return self._combine(ExprType.SUB, other)

    def __rsub__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Right subtraction (``other - self``)."""
        return self._combine(ExprType.SUB, other, reflected=True)

    def __mul__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Multiply two expressions or expression and constant."""
        return self._combine(ExprType.MUL, other)

    def __rmul__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Right multiplication (``other * self``), stored as ``self * other``."""
        return self._combine(ExprType.MUL, other)

    def __truediv__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Divide two expressions or expression and constant."""
        return self._combine(ExprType.DIV, other)

    def __rtruediv__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Right division (``other / self``), e.g. ``100 / size_of(task)``."""
        # Without this, ``int / expr`` raised TypeError even though
        # ``int - expr`` and ``int * expr`` were supported.
        return self._combine(ExprType.DIV, other, reflected=True)

    def __neg__(self) -> IntervalExpr:
        """Negate expression."""
        return IntervalExpr(expr_type=ExprType.NEG, operands=[self])

    def __abs__(self) -> IntervalExpr:
        """Absolute value of expression."""
        return IntervalExpr(expr_type=ExprType.ABS, operands=[self])

    # -- Comparison operators (return constraint expressions) ----------------

    def __eq__(self, other: object) -> IntervalExpr:  # type: ignore[override]
        """Equality constraint; returns NotImplemented for unsupported types."""
        if not isinstance(other, (IntervalExpr, int)):
            return NotImplemented
        return self._combine(ExprType.EQ, other)

    def __ne__(self, other: object) -> IntervalExpr:  # type: ignore[override]
        """Inequality constraint; returns NotImplemented for unsupported types."""
        if not isinstance(other, (IntervalExpr, int)):
            return NotImplemented
        return self._combine(ExprType.NE, other)

    def __lt__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Less than comparison."""
        return self._combine(ExprType.LT, other)

    def __le__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Less than or equal comparison."""
        return self._combine(ExprType.LE, other)

    def __gt__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Greater than comparison."""
        return self._combine(ExprType.GT, other)

    def __ge__(self, other: Union[IntervalExpr, int]) -> IntervalExpr:
        """Greater than or equal comparison."""
        return self._combine(ExprType.GE, other)

    def __hash__(self) -> int:
        """Hash on the unique per-instance ID (field values are not used)."""
        return hash(self._id)

    def __repr__(self) -> str:
        """Render the expression tree in a fully parenthesized form."""
        # Literals made by _to_expr carry a dummy expr_type (ADD), so the
        # constant check must come before any dispatch on expr_type.
        if self.value is not None:
            return str(self.value)

        # Tables are built per call on purpose: defining them at class level
        # would evaluate ExprType at class-creation time.
        accessors = {
            ExprType.START_OF: "start_of",
            ExprType.END_OF: "end_of",
            ExprType.SIZE_OF: "size_of",
            ExprType.LENGTH_OF: "length_of",
            ExprType.PRESENCE_OF: "presence_of",
        }
        infix = {
            ExprType.ADD: "+",
            ExprType.SUB: "-",
            ExprType.MUL: "*",
            ExprType.DIV: "/",
            ExprType.EQ: "==",
            ExprType.NE: "!=",
            ExprType.LT: "<",
            ExprType.LE: "<=",
            ExprType.GT: ">",
            ExprType.GE: ">=",
        }
        variadic = {ExprType.MIN: "min", ExprType.MAX: "max"}

        kind = self.expr_type
        if kind in accessors:
            name = self.interval.name if self.interval else "?"
            return f"{accessors[kind]}({name})"
        if kind == ExprType.OVERLAP_LENGTH:
            names = [op.interval.name if op.interval else "?" for op in self.operands]
            return f"overlap_length({names[0]}, {names[1]})"
        if kind in infix:
            return f"({self.operands[0]} {infix[kind]} {self.operands[1]})"
        if kind == ExprType.NEG:
            return f"(-{self.operands[0]})"
        if kind == ExprType.ABS:
            return f"abs({self.operands[0]})"
        if kind in variadic:
            return f"{variadic[kind]}({', '.join(str(op) for op in self.operands)})"
        return f"IntervalExpr({self.expr_type})"

    def get_intervals(self) -> list[IntervalVar]:
        """
        Collect every interval variable referenced by this expression.

        Uses a depth-first, pre-order walk. Duplicates are kept when an
        interval appears several times in the tree.
        """
        intervals = [] if self.interval is None else [self.interval]
        for operand in self.operands:
            intervals.extend(operand.get_intervals())
        return intervals

    def is_comparison(self) -> bool:
        """Check if this is a comparison expression (constraint)."""
        return self.expr_type in (
            ExprType.EQ,
            ExprType.NE,
            ExprType.LT,
            ExprType.LE,
            ExprType.GT,
            ExprType.GE,
        )
def _to_expr(value: Union[IntervalExpr, int]) -> IntervalExpr:
    """
    Coerce an operand into an ``IntervalExpr``.

    Expressions pass through unchanged. Integers (``bool`` included, since it
    subclasses ``int``) are wrapped in a literal node whose ``value`` holds
    the integer.

    Raises:
        TypeError: If ``value`` is neither an expression nor an int.
    """
    if isinstance(value, IntervalExpr):
        return value
    if not isinstance(value, int):
        raise TypeError(f"Cannot convert {type(value)} to IntervalExpr")
    # ExprType.ADD is only a placeholder for literals; what marks the node
    # as a constant is ``value`` being set.
    return IntervalExpr(expr_type=ExprType.ADD, value=value)


# ============================================================================
# Public API Functions
# ============================================================================
def start_of(interval: IntervalVar, absent_value: int = 0) -> IntervalExpr:
    """
    Build an expression for the start time of an interval.

    If the interval is absent (optional and not selected), the expression
    takes ``absent_value`` instead.

    Args:
        interval: The interval variable whose start is read.
        absent_value: Value taken when the interval is absent (default: 0).

    Returns:
        A ``START_OF`` expression bound to ``interval``.

    Example:
        >>> task = IntervalVar(size=10, name="task")
        >>> expr = start_of(task)
        >>> # Can be used in constraints: start_of(task) >= 5
    """
    expr = IntervalExpr(
        expr_type=ExprType.START_OF,
        absent_value=absent_value,
        interval=interval,
    )
    return expr
def end_of(interval: IntervalVar, absent_value: int = 0) -> IntervalExpr:
    """
    Build an expression for the end time of an interval.

    If the interval is absent (optional and not selected), the expression
    takes ``absent_value`` instead.

    FIXME: end_of() still returns an internal IntervalExpr; for pycsp3
    objectives use end_time() for now.

    Args:
        interval: The interval variable whose end is read.
        absent_value: Value taken when the interval is absent (default: 0).

    Returns:
        An ``END_OF`` expression bound to ``interval``.

    Example:
        >>> task = IntervalVar(size=10, name="task")
        >>> expr = end_of(task)
        >>> # Can be used in constraints: end_of(task) <= 100
    """
    expr = IntervalExpr(
        expr_type=ExprType.END_OF,
        absent_value=absent_value,
        interval=interval,
    )
    return expr
def size_of(interval: IntervalVar, absent_value: int = 0) -> IntervalExpr:
    """
    Build an expression for the size (duration) of an interval.

    If the interval is absent (optional and not selected), the expression
    takes ``absent_value`` instead.

    Args:
        interval: The interval variable whose size is read.
        absent_value: Value taken when the interval is absent (default: 0).

    Returns:
        A ``SIZE_OF`` expression bound to ``interval``.

    Example:
        >>> task = IntervalVar(size=(5, 20), name="task")
        >>> expr = size_of(task)
        >>> # Can be used in constraints: size_of(task) >= 10
    """
    expr = IntervalExpr(
        expr_type=ExprType.SIZE_OF,
        absent_value=absent_value,
        interval=interval,
    )
    return expr
def length_of(interval: IntervalVar, absent_value: int = 0) -> IntervalExpr:
    """
    Build an expression for the length of an interval.

    Length can differ from size when intensity functions are used. If the
    interval is absent, the expression takes ``absent_value`` instead.

    Args:
        interval: The interval variable whose length is read.
        absent_value: Value taken when the interval is absent (default: 0).

    Returns:
        A ``LENGTH_OF`` expression bound to ``interval``.

    Example:
        >>> task = IntervalVar(size=10, length=(8, 12), name="task")
        >>> expr = length_of(task)
    """
    expr = IntervalExpr(
        expr_type=ExprType.LENGTH_OF,
        absent_value=absent_value,
        interval=interval,
    )
    return expr
def presence_of(interval: IntervalVar) -> IntervalExpr:
    """
    Build a 0/1 expression telling whether an interval is present.

    For mandatory intervals this is always true. For optional intervals it
    is a decision variable.

    Args:
        interval: The interval variable.

    Returns:
        A ``PRESENCE_OF`` expression bound to ``interval``.

    Example:
        >>> task = IntervalVar(size=10, optional=True, name="task")
        >>> expr = presence_of(task)
        >>> # presence_of(task) == 1 means task is selected
    """
    # absent_value has no meaning for a presence flag; 0 is stored explicitly.
    expr = IntervalExpr(
        expr_type=ExprType.PRESENCE_OF,
        absent_value=0,
        interval=interval,
    )
    return expr
def overlap_length(
    interval1: IntervalVar,
    interval2: IntervalVar,
    absent_value: int = 0,
) -> IntervalExpr:
    """
    Build an expression for the overlap length between two intervals.

    The overlap is ``max(0, min(end1, end2) - max(start1, start2))``. If
    either interval is absent, the expression takes ``absent_value``.

    The two intervals are stored as operands wrapped in placeholder
    ``START_OF`` nodes (with the default ``absent_value``). The placeholders
    only carry the interval references.

    Args:
        interval1: First interval variable.
        interval2: Second interval variable.
        absent_value: Value taken if either interval is absent.

    Returns:
        An ``OVERLAP_LENGTH`` expression over both intervals.

    Example:
        >>> task1 = IntervalVar(size=10, name="task1")
        >>> task2 = IntervalVar(size=15, name="task2")
        >>> expr = overlap_length(task1, task2)
        >>> # expr == 0 means no overlap
    """
    # Placeholders are created in argument order, before the parent node.
    placeholders = [
        IntervalExpr(expr_type=ExprType.START_OF, interval=iv)
        for iv in (interval1, interval2)
    ]
    return IntervalExpr(
        expr_type=ExprType.OVERLAP_LENGTH,
        operands=placeholders,
        absent_value=absent_value,
    )
# ============================================================================
# Utility Functions
# ============================================================================
def expr_min(*args: Union[IntervalExpr, int]) -> IntervalExpr:
    """
    Return the minimum of multiple expressions.

    Like the built-in ``min``, the operands may be passed as separate
    arguments (``expr_min(a, b, c)``) or as one iterable (``expr_min([a, b])``
    or ``expr_min(end_of(t) for t in tasks)``).

    Args:
        *args: Expressions or integers to take the minimum of, or a single
            iterable of them.

    Returns:
        An expression representing the minimum.

    Raises:
        ValueError: If fewer than 2 operands are supplied.
        TypeError: If an operand is neither an expression nor an int.
    """
    from collections.abc import Iterable

    # A lone iterable argument is unpacked. Neither IntervalExpr nor int is
    # iterable, so the existing ``expr_min(a, b, ...)`` form is unaffected.
    if len(args) == 1 and isinstance(args[0], Iterable):
        args = tuple(args[0])
    if len(args) < 2:
        raise ValueError("expr_min requires at least 2 arguments")
    exprs = [_to_expr(a) for a in args]
    return IntervalExpr(
        expr_type=ExprType.MIN,
        operands=exprs,
    )
def expr_max(*args: Union[IntervalExpr, int]) -> IntervalExpr:
    """
    Return the maximum of multiple expressions.

    Like the built-in ``max``, the operands may be passed as separate
    arguments (``expr_max(a, b, c)``) or as one iterable (``expr_max([a, b])``
    or ``expr_max(end_of(t) for t in tasks)``, e.g. for a makespan).

    Args:
        *args: Expressions or integers to take the maximum of, or a single
            iterable of them.

    Returns:
        An expression representing the maximum.

    Raises:
        ValueError: If fewer than 2 operands are supplied.
        TypeError: If an operand is neither an expression nor an int.
    """
    from collections.abc import Iterable

    # A lone iterable argument is unpacked. Neither IntervalExpr nor int is
    # iterable, so the existing ``expr_max(a, b, ...)`` form is unaffected.
    if len(args) == 1 and isinstance(args[0], Iterable):
        args = tuple(args[0])
    if len(args) < 2:
        raise ValueError("expr_max requires at least 2 arguments")
    exprs = [_to_expr(a) for a in args]
    return IntervalExpr(
        expr_type=ExprType.MAX,
        operands=exprs,
    )