"""Multi-component wrapper for heterogeneous model switching.
This module provides the ``MultiComponent`` class for wrapping multiple
interchangeable simulation models (e.g., FEM, OpenSim, FMU pendulum) under a
unified interface. It enables dynamic mode switching during simulation
with automatic state synchronization between models.
Key Features:
    - **Dynamic Mode Switching**: Switch between different simulation
      models at runtime based on custom criteria (time, cached outputs,
      events)
    - **State Synchronization**: Automatic state transfer and adaptation
      when switching between models with different interfaces
    - **Hysteresis Protection**: Configurable dwell time to prevent
      rapid chattering between modes
    - **Port Unification**: Validates that all sub-models have compatible
      port interfaces
    - **Event Delegation**: Transparently delegates hybrid event detection
      to the currently active sub-component

Typical Use Cases:
    - Multi-fidelity simulation: Switch between high-fidelity FEM and
      reduced-order models based on accuracy requirements
    - Contact dynamics: Use a detailed contact model only when contact
      is imminent; otherwise use simpler dynamics
    - Adaptive resolution: Increase model complexity in regions of
      interest and decrease it elsewhere
Example:
    Creating a multi-model pendulum::

        class MasterPendulum(MultiComponent):
            def __init__(self, fem, opensim, fmu):
                super().__init__(
                    name="Pendulum",
                    models={"FEM": fem, "OpenSim": opensim, "FMU": fmu},
                    initial_mode="FEM",
                )

            def _adapt_state(self, state, target_mode):
                if target_mode == "FMU":
                    return {'q0': state['q'], 'omega0': state['omega']}
                return state

        # Use with mode selector
        pendulum = MasterPendulum(fem, opensim, fmu)
        pendulum.mode_selector = lambda t: "FEM" if t < 1.0 else "FMU"
        pendulum.hysteresis = Hysteresis(dwell_time=0.05)

See Also:
    :class:`CoSimComponent`: Base class for all components
    :class:`Hysteresis`: Mode switching debounce utility
"""
from __future__ import annotations
import logging
from collections.abc import Callable
from typing import Any, Protocol
from .base import CoSimComponent
from .events import InternalEventInfo
from .port import PortSpec, PortType
logger = logging.getLogger(__name__)
# -------------------------------------------------------------------
# Type Aliases
# -------------------------------------------------------------------
# Identifier of one interchangeable sub-model; used as the key type of
# ``MultiComponent.models`` and for ``MultiComponent.active_mode``.
ModeKey = str  # e.g., "FEM", "OpenSim", "FMU"
# -------------------------------------------------------------------
# State Adapter Protocol (for incompatible component interfaces)
# -------------------------------------------------------------------
class StateAdapter(Protocol):
    """Structural interface for translating state between heterogeneous models.

    Any object exposing a matching ``adapt_state`` method satisfies this
    protocol; no inheritance is required. Use it when two sub-models
    disagree on state variable names, units, or representation and the
    translation deserves its own object.

    Example:
        >>> class FMUAdapter:
        ...     def adapt_state(self, source_state, target_component):
        ...         # The FMU expects 'q0'/'omega0' instead of 'q'/'omega'.
        ...         return {
        ...             'q0': source_state['q'],
        ...             'omega0': source_state['omega'],
        ...         }
    """

    def adapt_state(
        self,
        source_state: dict[str, Any],
        target_component: CoSimComponent,
    ) -> dict[str, Any]:
        """Translate ``source_state`` into the format ``target_component`` expects.

        Args:
            source_state: State mapping exported by the source component,
                usually what its ``get_state()`` returned.
            target_component: Component that will receive the translated
                state through its ``set_state()`` method.

        Returns:
            A state mapping accepted by ``target_component.set_state()``.
        """
        ...
# -------------------------------------------------------------------
# Hysteresis for Mode Switching
# -------------------------------------------------------------------
class Hysteresis:
    """Enforce a minimum dwell time between consecutive mode switches.

    This object answers one timing question only: "has enough time passed
    since the last recorded switch?". Whether the proposed mode actually
    differs from the current one is for the caller to decide.

    Attributes:
        dwell_time (float): Minimum number of seconds that must separate
            two consecutive mode switches.
        last_switch_time (float): Time of the most recently recorded
            switch. Starts at ``-inf`` so the very first switch is never
            blocked.

    Example:
        >>> hyst = Hysteresis(dwell_time=0.05)
        >>> hyst.in_dwell_window(t=0.02)  # no switch recorded yet
        False
        >>> hyst.record_switch(t=0.10)
        >>> hyst.in_dwell_window(t=0.12)  # only 20 ms since the switch
        True
        >>> hyst.in_dwell_window(t=0.20)  # 100 ms elapsed, window closed
        False
    """

    def __init__(self, dwell_time: float = 0.01):
        """Create a dwell-time guard.

        Args:
            dwell_time: Minimum number of seconds between mode switches.
                Defaults to 0.01 (10 ms).
        """
        self.dwell_time = dwell_time
        # With -inf, ``t - last_switch_time`` is +inf before the first
        # switch, so the window starts out closed.
        self.last_switch_time: float = float("-inf")

    def in_dwell_window(self, t: float) -> bool:
        """Return ``True`` while less than ``dwell_time`` has elapsed since the last switch."""
        elapsed = t - self.last_switch_time
        return elapsed < self.dwell_time

    def record_switch(self, t: float) -> None:
        """Remember ``t`` as the time of the latest mode switch."""
        self.last_switch_time = t
# -------------------------------------------------------------------
# Abstract MultiComponent Base Class
# -------------------------------------------------------------------
[docs]
class MultiComponent(CoSimComponent):
"""Abstract base class for components wrapping multiple interchangeable models.
``MultiComponent`` enables dynamic switching between different simulation
models during runtime while presenting a unified interface to the rest
of the co-simulation system. Each sub-model ("mode") can use a different
solver, fidelity level, or physics representation.
Subclass Responsibilities:
1. Construct the sub-components and pass them to ``super().__init__``
through the ``models`` argument together with ``initial_mode``.
2. Override ``_adapt_state()`` for component-specific state translation.
3. (Optional) Set ``self.mode_selector`` for custom switching logic.
4. (Optional) Set ``self.hysteresis`` for chattering prevention.
Base Class Handles:
- Port unification (validates all models have compatible ports)
- Mode switching with hysteresis protection
- State synchronization during mode transitions
- Input/output delegation to the active component
- Event indicator delegation for hybrid simulation
Attributes:
models (dict[ModeKey, CoSimComponent]): Registry mapping mode keys
(e.g., "FEM", "OpenSim") to component instances. Populated in
``__init__`` and fixed for the lifetime of the wrapper.
active_mode (ModeKey): Key of the currently active model.
active_comp (CoSimComponent): Reference to the currently active
component instance. Always set after ``__init__``.
mode_selector (Callable | None): Function ``(t) -> ModeKey`` that
determines which mode should be active. Selectors that need state
information must read cached output ports rather than calling
``get_state()``, which can be expensive for high-fidelity models.
If ``None``, no automatic switching occurs.
hysteresis (Hysteresis | None): Optional hysteresis controller
to prevent rapid mode switching.
state_adapters (dict[ModeKey, StateAdapter]): Optional per-mode
state adapters for complex translation logic.
sync_events (list): Log of mode switch events for debugging.
Example:
Minimal subclass implementation::

    class DualPendulum(MultiComponent):
        def __init__(self, detailed, simplified):
            super().__init__(
                "Pendulum",
                models={"detailed": detailed, "simplified": simplified},
                initial_mode="detailed",
            )

        def _adapt_state(self, state, target_mode):
            # Both models use the same state format
            return state
See Also:
:class:`CoSimComponent`: Parent class with full interface docs
:class:`Hysteresis`: Mode switching debounce utility
:class:`StateAdapter`: Protocol for state translation
"""
def __init__(
    self,
    name: str,
    models: dict[ModeKey, CoSimComponent],
    initial_mode: ModeKey,
    group: str | None = None,
):
    """Initialize a multi-component wrapper.

    Args:
        name: Unique identifier for this component in the system.
        models: Mapping of mode keys to component instances. Must not be
            empty and must contain ``initial_mode``, whose entry must not
            be ``None``. Other entries may be ``None`` placeholders (they
            are skipped during initialization, port validation and reset,
            and cannot be switched to). The mapping is shallow-copied, so
            later changes to the caller's dict do not alter the registry.
        initial_mode: Key of the model to activate initially.
        group: Optional category for component organization.

    Raises:
        ValueError: If ``models`` is empty, ``initial_mode`` is not a key
            in ``models``, or ``models[initial_mode]`` is ``None``.

    Example:
        >>> super().__init__(
        ...     name="Pendulum",
        ...     models={"FEM": fem, "FMU": fmu},
        ...     initial_mode="FEM",
        ...     group="Plant",
        ... )
    """
    if not models:
        raise ValueError(f"{name}: 'models' must not be empty")
    if initial_mode not in models:
        raise ValueError(
            f"{name}: initial_mode '{initial_mode}' not in models {list(models.keys())}"
        )
    if models[initial_mode] is None:
        # The initial entry becomes ``active_comp`` and is dereferenced by
        # port unification, stepping and state access; reject a ``None``
        # placeholder here instead of failing later with an AttributeError.
        raise ValueError(f"{name}: initial_mode '{initial_mode}' maps to a None model")
    super().__init__(name, label=name, group=group)
    # Model registry and active references. The registry is copied so it
    # really is fixed at construction time.
    self.models: dict[ModeKey, CoSimComponent] = dict(models)
    self.active_mode: ModeKey = initial_mode
    self.active_comp: CoSimComponent = self.models[initial_mode]
    # Mode selection strategy (default: never switch)
    self.mode_selector: Callable[[float], ModeKey] | None = None
    # Hysteresis for switching (default: no hysteresis)
    self.hysteresis: Hysteresis | None = None
    # State adapters (optional): {mode_key: adapter}
    self.state_adapters: dict[ModeKey, StateAdapter] = {}
    # List of synchronization events (for logging/debugging)
    self.sync_events: list = []
    # Flag to prevent mode switching during event detection
    self._allow_mode_switching: bool = True
    # Latest input dict and timestamp seen by set_inputs. Used to
    # bring a newly activated model up to date during a mode switch
    # without forwarding inputs to inactive models on every step.
    self._latest_inputs: tuple[dict[str, Any], float | None] | None = None
    # When True, switch records in ``sync_events`` include the
    # pre-adaptation source state and the synchronized target state.
    # Default False because reading the target state calls
    # ``active_comp.get_state()`` once per switch, which can be
    # expensive for high-fidelity models.
    self.record_switch_state: bool = False
    # Previous and current state for synchronization
    self._prev_state: dict[str, Any] | None = None
    self._curr_state: dict[str, Any] | None = None
# -------------------------------------------------------------------
# State Adaptation Hook
# -------------------------------------------------------------------
def _adapt_state(self, state: dict[str, Any], target_mode: ModeKey) -> dict[str, Any]:
    """Adapt a state dictionary to the interface of the model ``target_mode``.

    Called whenever state must be handed to the model registered under
    ``target_mode`` (e.g. during a mode switch, and by ``set_state``) and
    that model may use different variable names, units, or
    representations.

    The default implementation looks up ``self.state_adapters`` for
    ``target_mode``. If a ``StateAdapter`` is registered there, its
    ``adapt_state`` is called with ``state`` and the component
    ``self.models[target_mode]``, and that result is returned. Otherwise
    ``NotImplementedError`` is raised, so a subclass must either override
    this method or register adapters.

    Args:
        state: State dictionary from the current active component,
            as returned by ``get_state()``.
        target_mode: Key of the model the state is destined for.

    Returns:
        Adapted state dictionary compatible with the target model's
        ``set_state()`` method.

    Raises:
        NotImplementedError: If not overridden by a subclass and no
            adapter is registered for ``target_mode``.

    Example:
        >>> def _adapt_state(self, state, target_mode):
        ...     if target_mode == "FMU":
        ...         # FMU uses initial condition naming
        ...         return {
        ...             'q0': state['q'],
        ...             'omega0': state['omega'],
        ...             'torque': state['torque']
        ...         }
        ...     return state  # Other models use standard naming

    Note:
        This is the primary extension point for handling heterogeneous
        model interfaces. If models share identical state formats,
        simply return ``state`` unchanged.
    """
    adapter = self.state_adapters.get(target_mode)
    if adapter is not None:
        # Registered adapters receive the target component itself, as the
        # StateAdapter protocol specifies.
        return adapter.adapt_state(state, self.models[target_mode])
    raise NotImplementedError(f"{self.name}: Subclass must implement _adapt_state()")
# -------------------------------------------------------------------
# Initialization Logic
# -------------------------------------------------------------------
def _initialize_component(self, t0: float) -> None:
    """Bring every registered sub-model to its initial state at ``t0``.

    The registry and the active component are already fixed by
    ``__init__``; this hook just calls ``initialize(t0)`` on each
    non-``None`` entry of ``self.models`` so that any of them can be
    activated by a later mode switch.

    Args:
        t0: Initial simulation time in seconds.
    """
    # Inactive models are initialized too: a switch may activate any of them.
    registered = [sub for sub in self.models.values() if sub is not None]
    for sub_model in registered:
        sub_model.initialize(t0)
# -------------------------------------------------------------------
# Port Unification and Validation
# -------------------------------------------------------------------
[docs]
@staticmethod
def _validate_port_compatibility(
ref_spec: PortSpec, spec: PortSpec, model_name: str, port_name: str
) -> None:
"""Validate that two PortSpecs are compatible for MultiComponent use."""
if ref_spec.name != port_name or spec.name != port_name:
raise ValueError(
f"Port name mismatch for '{port_name}' in model '{model_name}': "
f"got '{spec.name}', expected '{port_name}'."
)
if ref_spec.direction != spec.direction:
raise ValueError(
f"Port direction mismatch for '{port_name}' in model '{model_name}': "
f"{ref_spec.direction} vs {spec.direction}."
)
if ref_spec.type != spec.type:
raise ValueError(
f"Port type mismatch for '{port_name}' in model '{model_name}': "
f"{ref_spec.type} vs {spec.type}."
)
if not PortSpec.compatible(ref_spec, spec):
raise ValueError(
f"Port unit/type incompatibility for '{port_name}' in model '{model_name}': "
f"{ref_spec} vs {spec}."
)
def _unify_ports(self) -> None:
    """Expose the active model's ports on the wrapper and check every model provides them.

    The wrapper's ``input_specs`` and ``output_specs`` become shallow
    copies of the active component's specs. Every non-``None`` model in
    ``self.models`` is then required to declare each of those ports
    (inputs first, then outputs) with a spec accepted by
    ``_validate_port_compatibility``. Extra ports on a model are allowed.

    Raises:
        ValueError: If a model lacks one of the adopted input or output
            ports, or declares it with an incompatible spec.
    """
    source = self.active_comp
    self.input_specs = source.input_specs.copy()
    self.output_specs = source.output_specs.copy()
    # (label used in the error message, wrapper specs, attribute on the model)
    port_groups = (
        ("input", self.input_specs, "input_specs"),
        ("output", self.output_specs, "output_specs"),
    )
    for mode_key, comp in self.models.items():
        if comp is None:
            continue
        for kind, wrapper_specs, attr in port_groups:
            model_specs = getattr(comp, attr)
            for port_name, ref_spec in wrapper_specs.items():
                if port_name not in model_specs:
                    raise ValueError(
                        f"{self.name}: Model '{mode_key}' missing {kind} port '{port_name}'"
                    )
                self._validate_port_compatibility(
                    ref_spec, model_specs[port_name], mode_key, port_name
                )
# -------------------------------------------------------------------
# Time Stepping with Mode Switching
# -------------------------------------------------------------------
def _do_step_internal(self, t: float, dt: float) -> None:
    """Advance the active model by one macro step, switching first if needed.

    Unless ``dt <= 0.0``, the desired mode is resolved through
    ``_select_target_mode`` and, if it differs from ``active_mode``,
    ``_switch_mode`` runs before the step. A non-positive ``dt`` skips
    mode selection and steps the current model directly.

    Args:
        t: Current simulation time in seconds.
        dt: Macro step size in seconds.

    Note:
        Mode switching can be temporarily disabled by setting
        ``_allow_mode_switching = False``. This is used by the hybrid
        algorithm during trial steps so that event detection does not
        change the active model while a rollback snapshot is valid.
    """
    zero_length_step = dt <= 0.0
    if not zero_length_step:
        desired_mode = self._select_target_mode(t)
        if desired_mode != self.active_mode:
            self._switch_mode(desired_mode, t)
    self.active_comp.do_step(t, dt)
def _select_target_mode(self, t: float) -> ModeKey:
    """Return the mode that should be active for the step starting at ``t``.

    The selector's proposal is returned only when switching is allowed,
    a ``mode_selector`` is configured, and no hysteresis dwell window is
    open at ``t``. In every other case the current ``active_mode`` is
    returned unchanged.

    Args:
        t: Current simulation time in seconds.

    Returns:
        The mode key to use next; equal to ``self.active_mode`` when no
        switch is requested or allowed.
    """
    selector = self.mode_selector
    switching_blocked = not self._allow_mode_switching or selector is None
    if not switching_blocked:
        hyst = self.hysteresis
        if hyst is None or not hyst.in_dwell_window(t):
            return selector(t)
    return self.active_mode
# -------------------------------------------------------------------
# Mode Switching with State Synchronization
# -------------------------------------------------------------------
def _switch_mode(self, new_mode: ModeKey, t: float) -> None:
    """Activate ``new_mode`` at time ``t``, carrying the current state across.

    In order, this method:

    1. checks that ``new_mode`` is registered and not ``None``;
    2. logs the transition;
    3. hands the state over via ``_perform_state_transfer``;
    4. appends a record to ``sync_events`` via ``_capture_switch_event``;
    5. records ``t`` on the hysteresis controller, if one is configured.

    Args:
        new_mode: Mode key to activate.
        t: Simulation time at which the switch occurs.

    Raises:
        ValueError: If ``new_mode`` is not a key of ``self.models``.
        RuntimeError: If ``self.models[new_mode]`` is ``None``.
    """
    registry = self.models
    if new_mode not in registry:
        raise ValueError(f"{self.name}: Unknown mode '{new_mode}'")
    target = registry[new_mode]
    if target is None:
        raise RuntimeError(f"{self.name}: Model '{new_mode}' is not initialized")

    previous_mode = self.active_mode
    logger.info("[%s] Switching: %s to %s @ t=%.4fs", self.name, previous_mode, new_mode, t)
    # NOTE(review): ``active_mode``/``active_comp`` are not reassigned here;
    # presumably ``_perform_state_transfer`` does it — confirm.
    exported_state = self._perform_state_transfer(target, new_mode, t)
    self._capture_switch_event(t, previous_mode, new_mode, exported_state)

    hyst = self.hysteresis
    if hyst is not None:
        hyst.record_switch(t)
def _capture_switch_event(
    self, t: float, from_mode: ModeKey, to_mode: ModeKey, retrieved: dict[str, Any]
) -> None:
    """Append a record describing a completed switch to ``sync_events``.

    Every record carries ``time``, ``from_mode`` and ``to_mode``. If
    ``self.record_switch_state`` is ``True``, it also stores the
    pre-adaptation source state under ``retrieved`` and the result of
    ``self.active_comp.get_state()`` under ``now``. That extra call can be
    expensive for high-fidelity models, so the flag is meant for
    debugging synchronization issues.

    Args:
        t: Time at which the switch occurred.
        from_mode: Mode key active before the switch.
        to_mode: Mode key active after the switch.
        retrieved: State exported from the source component before
            adaptation.
    """
    entry: dict[str, Any] = dict(time=t, from_mode=from_mode, to_mode=to_mode)
    if self.record_switch_state:
        # Opt-in only, because get_state() on the active model may be costly.
        entry.update(retrieved=retrieved, now=self.active_comp.get_state())
    self.sync_events.append(entry)
# -------------------------------------------------------------------
# Input/Output Delegation
# -------------------------------------------------------------------
def _update_output_states(
    self, t: float | None = None, event_names: list[str] | None = None
) -> None:
    """Mirror the active component's outputs onto this wrapper.

    First copies every non-``None`` output value of the active
    sub-component to the wrapper's port of the same name. Then refreshes
    event ports: every port named in ``event_names`` (and present in
    ``output_specs``) is set to ``True``, and every other port of type
    ``PortType.EVENT`` is set to ``False``. An event flag therefore never
    stays ``True`` after a call in which its event did not fire, even
    when other events did fire.

    Args:
        t: Current simulation time for timestamping port values.
        event_names: Names of the events that just occurred, or ``None``
            / empty when no event fired.
    """
    active_comp = self.active_comp
    for name in self.output_specs:
        value = active_comp.outputs[name].get()
        if value is not None:
            self.outputs[name].set(value, t=t)

    fired = set(event_names) if event_names else set()
    for name, out_port in self.outputs.items():
        if name in fired and name in self.output_specs:
            out_port.set(value=True, t=t)
        elif out_port.spec.type == PortType.EVENT:
            # Clear event flags that did not fire in this call, including
            # ones still True from an earlier call.
            out_port.set(value=False, t=t)
def evaluate_outputs(self, inputs: dict[str, Any], t: float | None = None) -> dict[str, Any]:
    """Evaluate outputs on the active component and mirror them onto this wrapper.

    ``_allow_mode_switching`` is forced to ``False`` for the duration of
    the delegated call and restored to its previous value afterwards,
    even if the call raises.

    Args:
        inputs: Input values forwarded unchanged to the active component.
        t: Optional evaluation time; forwarded to the active component and
            used to timestamp the wrapper's output ports.

    Returns:
        The dictionary returned by the active component's
        ``evaluate_outputs``. Entries that name one of this wrapper's
        outputs and hold a non-``None`` value are also written to
        ``self.outputs``.
    """
    previous_flag = self._allow_mode_switching
    self._allow_mode_switching = False
    try:
        result = self.active_comp.evaluate_outputs(inputs, t=t)
        for port_name, port_value in result.items():
            if port_value is None or port_name not in self.outputs:
                continue
            self.outputs[port_name].set(port_value, t=t)
        return result
    finally:
        self._allow_mode_switching = previous_flag
# -------------------------------------------------------------------
# State Management Delegation
# -------------------------------------------------------------------
def set_state(self, state: dict[str, Any], t: float) -> None:
    """Load ``state`` into the active component after translating it.

    ``state`` is first passed through ``_adapt_state`` with the current
    ``active_mode`` as target; the translated mapping is then handed to
    the active component's ``set_state``.

    Args:
        state: State dictionary to load, translated for the active
            model's expected format before use.
        t: Time at which to set the state.

    See Also:
        :meth:`_adapt_state`: State translation hook
    """
    translated = self._adapt_state(state, self.active_mode)
    receiver = self.active_comp
    receiver.set_state(translated, t)
def get_state(self) -> dict[str, Any]:
    """Return the active sub-component's state without any translation.

    Returns:
        Whatever the active sub-component's ``get_state()`` returns, in
        that component's native format.

    Note:
        The format therefore depends on which model is active; use
        ``_adapt_state()`` to translate it for another model.
    """
    current = self.active_comp
    return current.get_state()
# -------------------------------------------------------------------
# Hybrid Capabilities Delegation
# -------------------------------------------------------------------
def add_event_indicator(self, name: str, func: Callable, direction: int = 0) -> None:
    """Register an event indicator on the sub-models and on this wrapper.

    The indicator is added to every non-``None`` sub-component whose
    ``supports_rollback`` is true; other entries are skipped. It is then
    also registered on this wrapper through the base-class method.

    Args:
        name: Unique name for the event indicator.
        func: Callable ``(component) -> float`` returning the indicator
            value. It should rely only on API shared by every sub-model,
            since any of them may be the one evaluating it.
        direction: Zero-crossing direction: -1 (falling), 0 (both),
            +1 (rising).
    """
    rollback_capable = (
        sub for sub in self.models.values()
        if sub is not None and sub.supports_rollback
    )
    for sub_model in rollback_capable:
        sub_model.add_event_indicator(name, func, direction)
    # Register on the wrapper as well, for its own port management.
    super().add_event_indicator(name, func, direction)
def evaluate_event_indicators(self) -> dict[str, float]:
    """Evaluate the active sub-component's event indicators.

    Returns:
        The active component's indicator values (name to value) when its
        ``has_state_events`` is true; otherwise an empty dict.
    """
    active = self.active_comp
    if not active.has_state_events:
        return {}
    return active.evaluate_event_indicators()
def detect_event_crossings(
    self, previous: dict[str, float], current: dict[str, float], sign_tolerance: float = 1e-10
) -> list[str]:
    """Detect indicator zero-crossings using the active sub-component.

    Args:
        previous: Indicator values before the step.
        current: Indicator values after the step.
        sign_tolerance: Threshold for zero detection, forwarded unchanged.

    Returns:
        The names reported by the active component's
        ``detect_event_crossings`` when its ``has_state_events`` is true;
        otherwise an empty list.
    """
    active = self.active_comp
    if not active.has_state_events:
        return []
    return active.detect_event_crossings(previous, current, sign_tolerance)
def snapshot_state(self):
    """Take a rollback snapshot of the active sub-component.

    Used for time rollback during event localization.

    Returns:
        The opaque object produced by the active component's
        ``snapshot_state()``.

    Warning:
        The snapshot belongs to the component that was active when it
        was taken; a mode switch before restoring invalidates it.
    """
    current = self.active_comp
    return current.snapshot_state()
def restore_state(self, snapshot, t) -> None:
    """Roll the active sub-component back to ``snapshot``.

    Used to roll back time during event-localization bisection.

    Args:
        snapshot: Opaque object previously returned by ``snapshot_state()``.
        t: Time at which the snapshot was taken.

    Warning:
        The snapshot must come from the component that is active now; do
        not switch modes between taking and restoring it.
    """
    current = self.active_comp
    current.restore_state(snapshot, t)
@property
def has_state_events(self) -> bool:
    """Whether the sub-component that is active right now has event indicators."""
    active = self.active_comp
    return active.has_state_events
@property
def supports_rollback(self) -> bool:
    """Whether the sub-component that is active right now supports state rollback."""
    active = self.active_comp
    return active.supports_rollback
def _handle_events_internal(self, event_names: list[str], t: float) -> None:
    """Forward the events that fired at ``t`` to the active component's ``handle_event``.

    Args:
        event_names: Names of the events that occurred at time ``t``.
        t: Precise time at which the events occurred.
    """
    handler = self.active_comp
    handler.handle_event(event_names, t)
def get_internal_event_hints(self) -> list[InternalEventInfo]:
    """Return the internal event hints reported by the active component.

    The hints are forwarded without any condition so that hints reported
    by the active model during a trial step reach the hybrid algorithm,
    which can use them to short-circuit bisection.

    Returns:
        The list produced by the active component's
        ``get_internal_event_hints()``.
    """
    source = self.active_comp
    return source.get_internal_event_hints()
# -------------------------------------------------------------------
# Detect Direct Feedthrough
# -------------------------------------------------------------------
def _detect_direct_feedthrough(self):
    """Set ``self.direct_feedthrough`` from the sub-models, requiring agreement.

    Reads ``direct_feedthrough`` from every non-``None`` entry of
    ``self.models``. If all of them report the same value, that value is
    stored on this wrapper. If the registry holds no non-``None`` model,
    the attribute is set to ``None``.

    Raises:
        ValueError: If two models report different values. The check does
            not depend on model order, even when a model reports ``None``.
            ``self.direct_feedthrough`` is left unchanged when this is
            raised.
    """
    # A private sentinel keeps "no model seen yet" distinct from a model
    # that genuinely reports ``None``.
    unset = object()
    agreed = unset
    for mode_key, comp in self.models.items():
        if comp is None:
            continue
        value = comp.direct_feedthrough
        if agreed is unset:
            agreed = value
        elif agreed != value:
            raise ValueError(
                f"{self.name}: Inconsistent direct feedthrough across models. "
                f"Model '{mode_key}' has direct_feedthrough={value}, "
                f"expected {agreed}."
            )
    # Assign only after every model has been checked.
    self.direct_feedthrough = None if agreed is unset else agreed
# -------------------------------------------------------------------
# Reset Logic
# -------------------------------------------------------------------
def reset(self) -> None:
    """Reset this wrapper and every registered sub-model.

    Runs the base-class ``reset()``, then calls ``reset()`` on each
    non-``None`` entry of ``self.models``, active or not, so a later
    re-initialization starts from clean sub-models. Finally it clears the
    cached ``_latest_inputs``, so inputs seen before the reset are not
    replayed into a newly activated model.
    """
    super().reset()
    live_models = [sub for sub in self.models.values() if sub is not None]
    for sub_model in live_models:
        sub_model.reset()
    self._latest_inputs = None