"""STIR adaptor implemented on top of the connector-first NumPy pipeline."""
from __future__ import annotations
from typing import Any, Dict, Optional
import numpy as np
from simind_python_connector.connectors._spacing import extract_voxel_size_mm
from simind_python_connector.connectors.base import BaseConnector
from simind_python_connector.connectors.python_connector import (
ConfigSource,
RuntimeOperator,
SimindPythonConnector,
)
from simind_python_connector.core.config import SimulationConfig
from simind_python_connector.core.types import PenetrateOutputType, ScoringRoutine
from simind_python_connector.utils import get_array
# STIR is an optional dependency: keep this module importable without it.
# StirSimindAdaptor.__init__ checks for ``stir is None`` and raises ImportError.
try:
    import stir
except ImportError:  # pragma: no cover - optional dependency
    stir = None  # type: ignore[assignment]
class StirSimindAdaptor(BaseConnector):
    """Adaptor consuming/returning STIR-native objects.

    The source and mu-map objects are converted to ``float32`` NumPy arrays
    with ``get_array`` and handed to a wrapped ``SimindPythonConnector``.
    After a run, every connector result is re-read from its header file with
    ``stir.ProjData.read_from_file`` and cached on the adaptor.
    """

    def __init__(
        self,
        config_source: ConfigSource,
        output_dir: str,
        output_prefix: str = "output",
        photon_multiplier: int = 1,
        quantization_scale: float = 1.0,
        scoring_routine: ScoringRoutine | int = ScoringRoutine.SCATTWIN,
    ) -> None:
        """Create the adaptor and its underlying ``SimindPythonConnector``.

        Args:
            config_source: Forwarded to ``SimindPythonConnector``.
            output_dir: Forwarded to ``SimindPythonConnector``.
            output_prefix: Forwarded to ``SimindPythonConnector``.
            photon_multiplier: Value registered for the ``"NN"`` runtime
                switch.
            quantization_scale: Forwarded to ``SimindPythonConnector``.
            scoring_routine: Scoring routine used by ``run()``. An ``int`` is
                converted with ``ScoringRoutine(...)``; any other value is
                stored as given.

        Raises:
            ImportError: If the ``stir`` module could not be imported.
        """
        if stir is None:
            raise ImportError("StirSimindAdaptor requires the STIR Python package.")

        self.python_connector = SimindPythonConnector(
            config_source=config_source,
            output_dir=output_dir,
            output_prefix=output_prefix,
            quantization_scale=quantization_scale,
        )
        self._scoring_routine = (
            ScoringRoutine(scoring_routine)
            if isinstance(scoring_routine, int)
            else scoring_routine
        )
        self._source: Any = None
        self._mu_map: Any = None
        # Stays None until run() has completed; get_outputs() relies on that.
        self._outputs: Optional[dict[str, Any]] = None

        # Must come after python_connector is assigned: the switch is
        # forwarded to it.
        self.add_runtime_switch("NN", photon_multiplier)

    def set_source(self, source: Any) -> None:
        """Store the source object used by the next ``run()``."""
        self._source = source

    def set_mu_map(self, mu_map: Any) -> None:
        """Store the mu-map object used by the next ``run()``."""
        self._mu_map = mu_map

    def set_energy_windows(
        self,
        lower_bounds: float | list[float],
        upper_bounds: float | list[float],
        scatter_orders: int | list[int],
    ) -> None:
        """Forward the energy-window definition to the wrapped connector."""
        self.python_connector.set_energy_windows(
            lower_bounds, upper_bounds, scatter_orders
        )

    def add_config_value(self, index: int, value: Any) -> None:
        """Forward a config value to the wrapped connector."""
        self.python_connector.add_config_value(index, value)

    def add_runtime_switch(self, switch: str, value: Any) -> None:
        """Forward a runtime switch to the wrapped connector."""
        self.python_connector.add_runtime_switch(switch, value)

    def run(self, runtime_operator: Optional[RuntimeOperator] = None) -> dict[str, Any]:
        """Run the simulation and return the outputs as STIR projection data.

        Args:
            runtime_operator: Passed through to the wrapped connector's
                ``run()``.

        Returns:
            Mapping from the connector's output keys to the objects returned
            by ``stir.ProjData.read_from_file``. The same mapping is cached
            and later returned by ``get_outputs()``.

        Raises:
            ValueError: If source or mu-map is unset, or their array shapes
                differ (raised by ``_validate_inputs``).
        """
        # Guarantees both inputs are set and shape-compatible from here on.
        self._validate_inputs()

        source_arr = np.asarray(get_array(self._source), dtype=np.float32)
        mu_arr = np.asarray(get_array(self._mu_map), dtype=np.float32)
        # The voxel size is read from the source object only.
        voxel_size_mm = self._extract_voxel_size_mm(self._source)

        self.python_connector.configure_voxel_phantom(
            source=source_arr,
            mu_map=mu_arr,
            voxel_size_mm=voxel_size_mm,
            scoring_routine=self._scoring_routine,
        )
        raw_outputs = self.python_connector.run(runtime_operator=runtime_operator)

        # Re-read every result from its header file as a STIR object.
        self._outputs = {
            key: stir.ProjData.read_from_file(str(result.header_path))
            for key, result in raw_outputs.items()
        }
        return self._outputs

    def get_outputs(self) -> Dict[str, Any]:
        """Return the outputs cached by the last ``run()``.

        Raises:
            RuntimeError: If ``run()`` has not produced outputs yet.
        """
        if self._outputs is None:
            raise RuntimeError("Run the adaptor first to produce outputs.")
        return self._outputs

    def get_total_output(self, window: int = 1) -> Any:
        """Return the output stored under ``"tot_w{window}"``."""
        return self._get_component("tot", window)

    def get_scatter_output(self, window: int = 1) -> Any:
        """Return the output stored under ``"sca_w{window}"``."""
        return self._get_component("sca", window)

    def get_primary_output(self, window: int = 1) -> Any:
        """Return the output stored under ``"pri_w{window}"``."""
        return self._get_component("pri", window)

    def get_air_output(self, window: int = 1) -> Any:
        """Return the output stored under ``"air_w{window}"``."""
        return self._get_component("air", window)

    def get_penetrate_output(self, component: PenetrateOutputType | str) -> Any:
        """Return one output selected by component.

        Args:
            component: A ``PenetrateOutputType`` (its ``slug`` is used as the
                key) or the output key itself.

        Raises:
            RuntimeError: If ``run()`` has not produced outputs yet.
            KeyError: If no output is stored under the resulting key.
        """
        key = (
            component.slug if isinstance(component, PenetrateOutputType) else component
        )
        return self._lookup_output(key)

    def list_available_outputs(self) -> list[str]:
        """Return the sorted keys of the cached outputs."""
        return sorted(self.get_outputs().keys())

    def get_scoring_routine(self) -> ScoringRoutine:
        """Return the scoring routine stored at construction time."""
        return self._scoring_routine

    def get_config(self) -> SimulationConfig:
        """Return the configuration of the wrapped connector."""
        return self.python_connector.get_config()

    def _get_component(self, prefix: str, window: int) -> Any:
        """Return the output stored under ``f"{prefix}_w{window}"``."""
        return self._lookup_output(f"{prefix}_w{window}")

    def _lookup_output(self, key: str) -> Any:
        """Fetch *key* from the cached outputs.

        Shared by ``_get_component`` and ``get_penetrate_output`` so both
        report a missing key the same way.

        Raises:
            RuntimeError: If ``run()`` has not produced outputs yet.
            KeyError: If *key* is absent; the message lists the available
                keys.
        """
        outputs = self.get_outputs()
        if key not in outputs:
            available = ", ".join(sorted(outputs))
            raise KeyError(f"Output {key!r} not available. Available: {available}")
        return outputs[key]

    @staticmethod
    def _extract_voxel_size_mm(image: Any) -> float:
        """Delegate to ``extract_voxel_size_mm`` with ``backend_name="STIR"``."""
        return extract_voxel_size_mm(image=image, backend_name="STIR")

    def _validate_inputs(self) -> None:
        """Check that both inputs are set and have identical array shapes.

        Raises:
            ValueError: If source or mu-map is ``None``, or the shapes of
                their arrays (obtained through ``get_array``) differ.
        """
        if self._source is None or self._mu_map is None:
            raise ValueError("Both source and mu_map must be set before run().")

        source_shape = np.asarray(get_array(self._source)).shape
        mu_shape = np.asarray(get_array(self._mu_map)).shape
        if source_shape != mu_shape:
            raise ValueError(
                f"source and mu_map must have matching shapes, got "
                f"{source_shape} and {mu_shape}"
            )
__all__ = ["StirSimindAdaptor"]