Skip to content

nnUNet API utilities

Utilities for nnU-Net model serving. The core functions handling the prediction workflows are:

  1. single_model_inference - runs the inference for a single model.
  2. multi_model_inference - runs the inference using multiple models (i.e. a model cascade).
  3. predict - wrapper around multi_model_inference which also handles file saving.

LabelManagerAdapter

Adapter around nnU-Net LabelManager private state mutations.

Centralizes private attribute access (_all_labels, _regions) so calling code does not directly mutate them.

Source code in src/nnunet_serve/nnunet_api_utils.py
class LabelManagerAdapter:
    """
    Wraps an nnU-Net ``LabelManager`` to centralize private-state access.

    All reads and writes of the manager's private attributes (``_all_labels``,
    ``_regions``) happen through this adapter, so calling code never mutates
    them directly.
    """

    def __init__(self, label_manager: LabelManager):
        """Stores the ``LabelManager`` instance that this adapter wraps.

        Args:
            label_manager (LabelManager): The nnU-Net LabelManager to adapt.
        """
        self.label_manager = label_manager

    def build_subset_state(
        self, class_idx: int | list[int]
    ) -> LabelSubsetState:
        """Builds the temporary label-mapping state for a class subset.

        Args:
            class_idx (int | list[int]): Index or indices of the classes to
                include (label 0 / background is always kept).

        Returns:
            LabelSubsetState: The constructed subset state for inference.
        """
        indices = [class_idx] if isinstance(class_idx, int) else class_idx
        active = [0, *indices]
        # Mapping is built from the pre-region label list on purpose: it
        # relates subset positions back to the originally requested labels.
        mapping = dict(enumerate(active))

        if self.label_manager._has_regions:
            # Restrict regions to those overlapping the requested labels; the
            # active label set then becomes the union of surviving regions.
            kept = [
                r
                for r in self.label_manager._regions
                if any(lbl in active for lbl in r)
            ]
            if kept:
                active = sorted(np.unique(np.concatenate(kept)).tolist())

        return LabelSubsetState(
            prediction_indices=[0, *indices],
            used_labels=active,
            correspondence_dict=mapping,
        )

    @contextmanager
    def apply_subset_state(self, subset_state: LabelSubsetState):
        """
        Temporarily applies subset labels, restoring the originals on exit.

        Args:
            subset_state (LabelSubsetState): State to apply inside the context.

        Yields:
            None: Execution proceeds with the modified label-manager state.
        """
        saved = (self.label_manager._all_labels, self.label_manager._regions)
        self.label_manager._all_labels = subset_state.used_labels
        try:
            yield
        finally:
            self.label_manager._all_labels, self.label_manager._regions = saved

__init__(label_manager)

Initializes the adapter with a given LabelManager.

Parameters:

Name Type Description Default
label_manager LabelManager

The nnU-Net LabelManager to adapt.

required
Source code in src/nnunet_serve/nnunet_api_utils.py
def __init__(self, label_manager: LabelManager):
    """Stores the ``LabelManager`` instance that this adapter wraps.

    Args:
        label_manager (LabelManager): The nnU-Net LabelManager to adapt.
    """
    self.label_manager = label_manager

apply_subset_state(subset_state)

Temporarily apply subset labels and always restore original state.

Parameters:

Name Type Description Default
subset_state LabelSubsetState

The state to apply during the context.

required

Yields:

Name Type Description
None

Allows execution within the modified label manager state.

Source code in src/nnunet_serve/nnunet_api_utils.py
@contextmanager
def apply_subset_state(self, subset_state: LabelSubsetState):
    """
    Temporarily applies subset labels, restoring the originals on exit.

    Args:
        subset_state (LabelSubsetState): The state to apply during the context.

    Yields:
        None: Execution proceeds with the modified label-manager state.
    """
    saved = (self.label_manager._all_labels, self.label_manager._regions)
    self.label_manager._all_labels = subset_state.used_labels
    try:
        yield
    finally:
        self.label_manager._all_labels, self.label_manager._regions = saved

build_subset_state(class_idx)

Build the temporary label mapping state for a class subset.

Parameters:

Name Type Description Default
class_idx int | list[int]

The index or indices of the classes to include.

required

Returns:

Name Type Description
LabelSubsetState LabelSubsetState

The constructed subset state for inference.

Source code in src/nnunet_serve/nnunet_api_utils.py
def build_subset_state(
    self, class_idx: int | list[int]
) -> LabelSubsetState:
    """Builds the temporary label-mapping state for a class subset.

    Args:
        class_idx (int | list[int]): Index or indices of the classes to
            include (label 0 / background is always kept).

    Returns:
        LabelSubsetState: The constructed subset state for inference.
    """
    indices = [class_idx] if isinstance(class_idx, int) else class_idx
    active = [0, *indices]
    # Mapping is built from the pre-region label list on purpose: it relates
    # subset positions back to the originally requested labels.
    mapping = dict(enumerate(active))

    if self.label_manager._has_regions:
        # Restrict regions to those overlapping the requested labels; the
        # active label set then becomes the union of the surviving regions.
        kept = [
            r
            for r in self.label_manager._regions
            if any(lbl in active for lbl in r)
        ]
        if kept:
            active = sorted(np.unique(np.concatenate(kept)).tolist())

    return LabelSubsetState(
        prediction_indices=[0, *indices],
        used_labels=active,
        correspondence_dict=mapping,
    )

LabelSubsetState dataclass

Holds temporary label-manager state used for class-subset inference.

Attributes:

Name Type Description
prediction_indices list[int]

Indices of the classes to be kept in the prediction.

used_labels list[int]

List of unique labels that are active in the current subset.

correspondence_dict dict[int, int]

Mapping from subset label indices to original labels.

Source code in src/nnunet_serve/nnunet_api_utils.py
@dataclass
class LabelSubsetState:
    """Temporary label-manager state used for class-subset inference.

    Attributes:
        prediction_indices (list[int]): Indices of the classes kept in the
            prediction.
        used_labels (list[int]): Unique labels active in the current subset.
        correspondence_dict (dict[int, int]): Mapping from subset label
            indices back to the original labels.
    """

    prediction_indices: list[int]
    used_labels: list[int]
    correspondence_dict: dict[int, int]

SeriesLoader

Load and cache medical image series as SimpleITK volumes.

This helper provides: - Caching: each unique series path is read at most once. - Optional DICOM loading via read_dicom_as_sitk. - Optional on-access post-processing controlled by a suffix in the requested path.

The path string passed to __getitem__ (and used inside series_paths) can include simple modifiers: - "/path/to/seg.nii.gz=3": returns a binary mask (volume == 3) cast to sitkInt32. - "/path/to/4d.nii.gz[0]": returns the slice/volume at the given index.

Parameters:

Name Type Description Default
series_paths list[list[str]]

Nested list of series identifiers grouped by "stage". Each inner list is the set of series to be used at that stage.

required
is_dicom bool

If True, each series path is treated as a DICOM directory and read with read_dicom_as_sitk. If False, each series path is read using sitk.ReadImage.

False
bvalue_for_filtering int | None

If provided when is_dicom=True, filters DICOM series by the specified b-value using filter_by_bvalue. Defaults to None.

None
Source code in src/nnunet_serve/nnunet_api_utils.py
class SeriesLoader:
    """
    Load and cache medical image series as SimpleITK volumes.

    This helper provides:
    - Caching: each unique series path is read at most once.
    - Optional DICOM loading via `read_dicom_as_sitk`.
    - Optional on-access post-processing controlled by a suffix in the requested path.

    The path string passed to `__getitem__` (and used inside `series_paths`) can include
    simple modifiers:
    - `"/path/to/seg.nii.gz=3"`: returns the element-wise comparison `(volume == 3)`.
    - `"/path/to/4d.nii.gz[0]"`: returns the vector component at the given index.

    Without a modifier, accessed volumes are cast to `sitkFloat32`
    (see `post_process`).

    Args:
        series_paths (list[list[str]]): Nested list of series identifiers grouped by
            "stage". Each inner list is the set of series to be used at that stage.
        is_dicom (bool, optional): If `True`, each series path that is a directory
            is read as DICOM with `read_dicom_as_sitk`; non-directory paths (and
            all paths when `False`) are read using `sitk.ReadImage`.
        bvalue_for_filtering (int | None, optional): If provided when `is_dicom=True`,
            filters DICOM series by the specified b-value using `filter_by_bvalue`.
            Defaults to None.
    """

    def __init__(
        self,
        series_paths: list[list[str]],
        is_dicom: bool = False,
        bvalue_for_filtering: int | None = None,
    ):
        """
        Create a loader and pre-compute the unique series paths.

        Note:
            Only the base path (without modifiers) is cached. For example,
            requesting `"image.nii.gz=1"` and `"image.nii.gz=2"` will load
            `"image.nii.gz"` once and apply post-processing per request.
        """
        self.series_paths = series_paths
        self.is_dicom = is_dicom
        self.bvalue_for_filtering = bvalue_for_filtering

        self.n_stages = len(self.series_paths)
        # Order-preserving de-duplication across all stages (dict.fromkeys
        # keeps first-seen order and avoids the O(n^2) membership scan).
        self.unique_series_paths = list(
            dict.fromkeys(s for stage_paths in series_paths for s in stage_paths)
        )

        self.loaded_volumes = {}
        self.good_file_paths = {}
        self.loaded_resampled_volumes = {}
        logger.info(f"Identified {len(self.unique_series_paths)} unique series")

    def __getitem__(self, path: str) -> tuple[sitk.Image, list[str] | None]:
        """
        Load a series (if needed) and return the (optionally processed) volume.

        Args:
            path (str): Series path, optionally suffixed with `=<label>` or `[<index>]`.

        Returns:
            A tuple `(volume, good_file_paths)` where:
                - `volume` is the loaded SimpleITK image, potentially post-processed.
                - `good_file_paths` is only populated for DICOM inputs and corresponds to
                the list of files considered valid by `read_dicom_as_sitk`.
        """
        path, equal, index = self.get_info(path)
        if path not in self.loaded_volumes:
            # DICOM reading only applies to directories; anything else is a
            # regular image file regardless of `is_dicom`.
            if self.is_dicom and os.path.isdir(path):
                dcm_img = read_dicom_as_sitk(
                    path, bvalue_for_filtering=self.bvalue_for_filtering
                )
                self.loaded_volumes[path] = dcm_img[0]
                self.good_file_paths[path] = dcm_img[1]
            else:
                self.loaded_volumes[path] = sitk.ReadImage(path)
                self.good_file_paths[path] = None
        volume = self.post_process(
            self.loaded_volumes[path], equal=equal, index=index
        )
        return volume, self.good_file_paths[path]

    def __setitem__(self, key: str, value: sitk.Image):
        """Manually set/cache a volume for a given path.

        Args:
            key (str): The series identifier or path.
            value (sitk.Image): The SimpleITK image to cache.
        """
        self.loaded_volumes[key] = value
        self.good_file_paths[key] = None

    def register(
        self,
        image: sitk.Image,
        stage: int,
        image_file_name: str = "prediction.nii.gz",
    ):
        """Registers an image for a specific stage under a synthetic key.

        Args:
            image (sitk.Image): The image to register.
            stage (int): The stage index to associate the image with.
            image_file_name (str, optional): image file name presumed to be inside
                the stage directory. Defaults to "prediction.nii.gz".
        """
        key = os.path.join(f"stage_{stage}", image_file_name)
        self[key] = image

    def get_info(self, path: str) -> tuple[str, int | None, int | None]:
        """
        Parse a series path and extract any post-processing modifier.

        Supported syntaxes:
        - `"<path>=<int>"`: equality comparison against the given integer label.
        - `"<path>[<int>]"`: SimpleITK vector-index selection.

        Args:
            path (str): Raw path string.

        Returns:
            A tuple `(base_path, equal, index)` where:
            - `base_path` is the path without the modifier.
            - `equal` is the parsed integer after `=` (or `None`).
            - `index` is the parsed integer inside `[...]` (or `None`).
        """
        equal, index = None, None
        if "=" in path:
            path, equal = path.split("=")
            equal = int(equal)
        elif "[" in path:
            path, index = path.split("[")
            index = int(index.split("]")[0])
        return path, equal, index

    def post_process(
        self,
        volume: sitk.Image,
        equal: int | None = None,
        index: int | None = None,
    ) -> sitk.Image:
        """
        Apply an optional post-processing operation to a loaded volume.

        Args:
            volume (sitk.Image): Input SimpleITK image.
            equal (int | None): If provided, returns the element-wise
                comparison `(volume == equal)`; no additional cast is applied.
            index (int | None): If provided (and `equal` is `None`), selects
                the vector component at `index`.

        Returns:
            The comparison image, the selected component, or — when neither
            modifier is given — the input cast to `sitkFloat32`.
        """
        if equal is not None:
            return volume == equal
        elif index is not None:
            select_f = sitk.VectorIndexSelectionCastImageFilter()
            select_f.SetIndex(index)
            return select_f.Execute(volume)
        return sitk.Cast(volume, sitk.sitkFloat32)

    def get_volumes(self, stage: int) -> list[sitk.Image]:
        """
        Return the loaded (and post-processed) volumes for a given stage.

        All volumes are resampled to the stage's first series.

        Args:
            stage (int): Index into `series_paths`.

        Returns:
            List of `sitk.Image` volumes for the stage.
        """
        return [self.get_resampled(s, stage) for s in self.series_paths[stage]]

    def get_file_paths(self, stage: int) -> list[list[str]]:
        """Return the DICOM file-path lists for a given stage.

        Args:
            stage (int): Index into `series_paths`.

        Returns:
            List of file-path lists returned by `read_dicom_as_sitk` for the stage.
            For non-DICOM inputs, entries will be `None`.
        """
        return [self[s][1] for s in self.series_paths[stage]]

    def might_be_mask(self, image: sitk.Image) -> bool:
        """
        Makes an educated guess on whether ``image`` is a mask.

        Args:
            image (sitk.Image): image.

        Returns:
            Boolean which is True if image has an integer pixel type _and_
                fewer than 50 labels.
        """
        is_mask = False
        if "int" in image.GetPixelIDTypeAsString().lower():
            stats_filter = sitk.LabelShapeStatisticsImageFilter()
            stats_filter.Execute(image)
            if stats_filter.GetNumberOfLabels() < 50:
                is_mask = True
        return is_mask

    def get_resampled(self, path: str, stage: int) -> sitk.Image:
        """
        Similar to __getitem__ but resamples to first series in stage.

        Args:
            path (str): Series path, optionally suffixed with `=<label>` or `[<index>]`.
            stage (int): Index into `series_paths`.

        Returns:
            sitk.Image: The loaded image, resampled to match the first series in the stage.
        """
        first_path = self.series_paths[stage][0]
        identifier = (path, first_path)
        if identifier not in self.loaded_resampled_volumes:
            if path == first_path:
                # The reference series needs no resampling.
                resampled_image = self[path][0]
            else:
                resampled_image = resample_image_to_target(
                    self[path][0],
                    self[first_path][0],
                    is_mask=self.might_be_mask(self[path][0]),
                )

            self.loaded_resampled_volumes[identifier] = resampled_image

        return self.loaded_resampled_volumes[identifier]

__getitem__(path)

Load a series (if needed) and return the (optionally processed) volume.

Parameters:

Name Type Description Default
path str

Series path, optionally suffixed with =<label> or [<index>].

required

Returns:

Type Description
tuple[Image, list[str] | None]

A tuple (volume, good_file_paths) where: - volume is the loaded SimpleITK image, potentially post-processed. - good_file_paths is only populated for DICOM inputs and corresponds to the list of files considered valid by read_dicom_as_sitk.

Source code in src/nnunet_serve/nnunet_api_utils.py
def __getitem__(self, path: str) -> tuple[sitk.Image, list[str] | None]:
    """
    Load a series on first access and return the (optionally processed) volume.

    Args:
        path (str): Series path, optionally suffixed with `=<label>` or `[<index>]`.

    Returns:
        A tuple `(volume, good_file_paths)` where:
            - `volume` is the loaded SimpleITK image, potentially post-processed.
            - `good_file_paths` is only populated for DICOM inputs and corresponds to
            the list of files considered valid by `read_dicom_as_sitk`.
    """
    base_path, equal, index = self.get_info(path)
    if base_path not in self.loaded_volumes:
        # DICOM reading only applies to directories; anything else is a
        # regular image file regardless of `is_dicom`.
        if self.is_dicom and os.path.isdir(base_path):
            dcm_img = read_dicom_as_sitk(
                base_path, bvalue_for_filtering=self.bvalue_for_filtering
            )
            self.loaded_volumes[base_path] = dcm_img[0]
            self.good_file_paths[base_path] = dcm_img[1]
        else:
            self.loaded_volumes[base_path] = sitk.ReadImage(base_path)
            self.good_file_paths[base_path] = None
    processed = self.post_process(
        self.loaded_volumes[base_path], equal=equal, index=index
    )
    return processed, self.good_file_paths[base_path]

__init__(series_paths, is_dicom=False, bvalue_for_filtering=None)

Create a loader and pre-compute the unique series paths.

Note

Only the base path (without modifiers) is cached. For example, requesting "image.nii.gz=1" and "image.nii.gz=2" will load "image.nii.gz" once and apply post-processing per request.

Source code in src/nnunet_serve/nnunet_api_utils.py
def __init__(
    self,
    series_paths: list[list[str]],
    is_dicom: bool = False,
    bvalue_for_filtering: int | None = None,
):
    """
    Create a loader and pre-compute the unique series paths.

    Note:
        Only the base path (without modifiers) is cached. For example,
        requesting `"image.nii.gz=1"` and `"image.nii.gz=2"` will load
        `"image.nii.gz"` once and apply post-processing per request.
    """
    self.series_paths = series_paths
    self.is_dicom = is_dicom
    self.bvalue_for_filtering = bvalue_for_filtering

    self.n_stages = len(self.series_paths)
    # Order-preserving de-duplication across all stages.
    self.unique_series_paths = list(
        dict.fromkeys(s for stage_paths in series_paths for s in stage_paths)
    )

    self.loaded_volumes = {}
    self.good_file_paths = {}
    self.loaded_resampled_volumes = {}
    logger.info(f"Identified {len(self.unique_series_paths)} unique series")

__setitem__(key, value)

Manually set/cache a volume for a given path.

Parameters:

Name Type Description Default
key str

The series identifier or path.

required
value Image

The SimpleITK image to cache.

required
Source code in src/nnunet_serve/nnunet_api_utils.py
def __setitem__(self, key: str, value: sitk.Image):
    """Caches ``value`` under ``key``, with no associated DICOM file list.

    Args:
        key (str): The series identifier or path.
        value (sitk.Image): The SimpleITK image to cache.
    """
    self.good_file_paths[key] = None
    self.loaded_volumes[key] = value

get_file_paths(stage)

Return the DICOM file-path lists for a given stage.

Parameters:

Name Type Description Default
stage int

Index into series_paths.

required

Returns:

Type Description
list[list[str]]

List of file-path lists returned by read_dicom_as_sitk for the stage.

list[list[str]]

For non-DICOM inputs, entries will be None.

Source code in src/nnunet_serve/nnunet_api_utils.py
def get_file_paths(self, stage: int) -> list[list[str]]:
    """Returns the DICOM file-path lists for a given stage.

    Args:
        stage (int): Index into `series_paths`.

    Returns:
        List of file-path lists as produced by `read_dicom_as_sitk` for the
        stage. Entries are `None` for non-DICOM inputs.
    """
    file_paths = []
    for series in self.series_paths[stage]:
        file_paths.append(self[series][1])
    return file_paths

get_info(path)

Parse a series path and extract any post-processing modifier.

Supported syntaxes: - "<path>=<int>": equality comparison against the given integer label. - "<path>[<int>]": SimpleITK index selection.

Parameters:

Name Type Description Default
path str

Raw path string.

required

Returns:

Type Description
tuple[str, int | None, int | None]

A tuple (base_path, equal, index) where:

  • base_path is the path without the modifier.
  • equal is the parsed integer after = (or None).
  • index is the parsed integer inside [...] (or None).
Source code in src/nnunet_serve/nnunet_api_utils.py
def get_info(self, path: str) -> tuple[str, int | None, int | None]:
    """
    Parse a series path and extract any post-processing modifier.

    Supported syntaxes:
    - `"<path>=<int>"`: equality comparison against the given integer label.
    - `"<path>[<int>]"`: SimpleITK index selection.

    Args:
        path (str): Raw path string.

    Returns:
        A tuple `(base_path, equal, index)` where:
        - `base_path` is the path without the modifier.
        - `equal` is the parsed integer after `=` (or `None`).
        - `index` is the parsed integer inside `[...]` (or `None`).
    """
    equal, index = None, None
    if "=" in path:
        path, equal = path.split("=")
        equal = int(equal)
    elif "[" in path:
        path, index = path.split("[")
        index = int(index.split("]")[0])
    return path, equal, index

get_resampled(path, stage)

Similar to __getitem__ but resamples to first series in stage.

Parameters:

Name Type Description Default
path str

Series path, optionally suffixed with =<label> or [<index>].

required
stage int

Index into series_paths.

required

Returns:

Type Description
Image

sitk.Image: The loaded image, resampled to match the first series in the stage.

Source code in src/nnunet_serve/nnunet_api_utils.py
def get_resampled(self, path: str, stage: int) -> sitk.Image:
    """
    Like __getitem__, but resampled to the first series in the stage.

    Args:
        path (str): Series path, optionally suffixed with `=<label>` or `[<index>]`.
        stage (int): Index into `series_paths`.

    Returns:
        sitk.Image: The loaded image, resampled to match the first series in the stage.
    """
    reference_path = self.series_paths[stage][0]
    cache_key = (path, reference_path)
    if cache_key not in self.loaded_resampled_volumes:
        image = self[path][0]
        if path != reference_path:
            # Only non-reference series need resampling onto the reference.
            image = resample_image_to_target(
                image,
                self[reference_path][0],
                is_mask=self.might_be_mask(image),
            )
        self.loaded_resampled_volumes[cache_key] = image
    return self.loaded_resampled_volumes[cache_key]

get_volumes(stage)

Return the loaded (and post-processed) volumes for a given stage.

Parameters:

Name Type Description Default
stage int

Index into series_paths.

required

Returns:

Type Description
list[Image]

List of sitk.Image volumes for the stage.

Source code in src/nnunet_serve/nnunet_api_utils.py
def get_volumes(self, stage: int) -> list[sitk.Image]:
    """
    Returns the loaded (and post-processed) volumes for a given stage.

    Args:
        stage (int): Index into `series_paths`.

    Returns:
        List of `sitk.Image` volumes for the stage, each resampled to the
        stage's first series.
    """
    volumes = []
    for series in self.series_paths[stage]:
        volumes.append(self.get_resampled(series, stage))
    return volumes

might_be_mask(image)

Makes an educated guess on whether image is a mask.

Parameters:

Name Type Description Default
image Image

image.

required

Returns:

Type Description
bool

Boolean which is True if image is an integer and has fewer than 50 unique values.

Source code in src/nnunet_serve/nnunet_api_utils.py
def might_be_mask(self, image: sitk.Image) -> bool:
    """
    Makes an educated guess on whether ``image`` is a mask.

    Args:
        image (sitk.Image): image.

    Returns:
        True if the image has an integer pixel type and fewer than 50
            labels, False otherwise.
    """
    if "int" not in image.GetPixelIDTypeAsString().lower():
        return False
    stats_filter = sitk.LabelShapeStatisticsImageFilter()
    stats_filter.Execute(image)
    return stats_filter.GetNumberOfLabels() < 50

post_process(volume, equal=None, index=None)

Apply an optional post-processing operation to a loaded volume.

Parameters:

Name Type Description Default
volume Image

Input SimpleITK image.

required
equal int | None

If provided, returns (volume == equal) cast to sitkInt32.

None
index int | None

If provided (and equal is None), returns volume[index].

None

Returns:

Type Description
Image

The processed (or original) image.

Source code in src/nnunet_serve/nnunet_api_utils.py
def post_process(
    self,
    volume: sitk.Image,
    equal: int | None = None,
    index: int | None = None,
) -> sitk.Image:
    """
    Apply an optional post-processing operation to a loaded volume.

    Args:
        volume (sitk.Image): Input SimpleITK image.
        equal (int | None): If provided, returns the element-wise comparison
            `(volume == equal)`; no additional cast is applied here.
        index (int | None): If provided (and `equal` is `None`), selects the
            vector component at `index` via
            `VectorIndexSelectionCastImageFilter`.

    Returns:
        The comparison image, the selected component, or — when neither
        modifier is given — the input cast to `sitkFloat32`.
    """
    if equal is not None:
        return volume == equal
    elif index is not None:
        select_f = sitk.VectorIndexSelectionCastImageFilter()
        select_f.SetIndex(index)
        return select_f.Execute(volume)
    return sitk.Cast(volume, sitk.sitkFloat32)

register(image, stage, image_file_name='prediction.nii.gz')

Registers an image for a specific stage.

Parameters:

Name Type Description Default
image Image

The image to register.

required
stage int

The stage index to associate the image with.

required
image_file_name str

image file name presumed to be inside the stage directory. Defaults to "prediction.nii.gz".

'prediction.nii.gz'
Source code in src/nnunet_serve/nnunet_api_utils.py
def register(
    self,
    image: sitk.Image,
    stage: int,
    image_file_name: str = "prediction.nii.gz",
):
    """Caches ``image`` under a synthetic per-stage key.

    Args:
        image (sitk.Image): The image to register.
        stage (int): The stage index to associate the image with.
        image_file_name (str, optional): image file name presumed to be inside
            the stage directory. Defaults to "prediction.nii.gz".
    """
    self[os.path.join(f"stage_{stage}", image_file_name)] = image

filter_labels(image, class_idx, binarize=False)

Filters labels in an image.

Parameters:

Name Type Description Default
image Image

Image to filter.

required
class_idx int | list[int] | None

List of class indices to keep.

required
binarize bool

Whether to binarize the image. Defaults to False.

False

Returns:

Type Description
Image

sitk.Image: Filtered image.

Source code in src/nnunet_serve/nnunet_api_utils.py
def filter_labels(
    image: sitk.Image, class_idx: int | list[int] | None, binarize: bool = False
) -> sitk.Image:
    """
    Filters labels in an image, zeroing out every label not in ``class_idx``.

    Args:
        image (sitk.Image): Image to filter.
        class_idx (int | list[int] | None): Class indices to keep. ``None``
            keeps all labels present in the image.
        binarize (bool, optional): Whether to binarize the output.
            Defaults to False.

    Returns:
        sitk.Image: Filtered image (``sitkUInt16``, or ``sitkInt16`` when
            ``binarize`` is True).
    """
    stats = sitk.LabelShapeStatisticsImageFilter()
    stats.Execute(image)
    image = sitk.Cast(image, sitk.sitkLabelUInt16)
    all_labels = list(stats.GetLabels())
    if class_idx is None:
        class_idx = all_labels
    elif isinstance(class_idx, int):
        class_idx = [class_idx]
    # Keep requested labels as-is; remap everything else to background (0).
    mapping = {int(i): int(i) if i in class_idx else 0 for i in all_labels}
    output = sitk.ChangeLabelLabelMap(image, mapping)
    output = sitk.Cast(output, sitk.sitkUInt16)
    if binarize:
        output = sitk.Cast(output > 0, sitk.sitkInt16)
    return output

get_default_params(default_args)

Returns a dict with default parameters. If default_args is a list of dicts, the output will be a dictionary of lists whenever the key is in CASCADE_ARGUMENTS and whose value will be that of the last dictionary otherwise. If default_args is a dict the output will be default_args.

Parameters:

Name Type Description Default
default_args dict | list[dict]

default arguments.

required

Returns:

Name Type Description
dict dict

correctly formatted default arguments.

Raises:

Type Description
ValueError

if default_args is not a dict or list of dicts.

Source code in src/nnunet_serve/nnunet_api_utils.py
def get_default_params(default_args: dict | list[dict]) -> dict:
    """
    Returns a dict with default parameters. If ``default_args`` is a list of
    dicts, the output will be a dictionary of lists whenever the key is in
    ``CASCADE_ARGUMENTS`` and whose value will be that of the last dictionary
    otherwise. If ``default_args`` is a dict the output will be
    ``default_args``.

    Args:
        default_args (dict | list[dict]): default arguments.

    Returns:
        dict: correctly formatted default arguments.

    Raises:
        ValueError: if ``default_args`` is not a dict or list of dicts.
    """
    if isinstance(default_args, dict):
        return default_args
    if not isinstance(default_args, list):
        raise ValueError("default_args should either be dict or list")

    # Fallback values for cascade keys missing from individual stage dicts.
    default_params_request = InferenceRequestBase(
        nnunet_id="", output_dir=""
    ).model_dump()

    default_params = {}
    for curr_default_args in default_args:
        for k, v in curr_default_args.items():
            # Cascade-capable keys are collected per stage below; all other
            # keys keep the value from the last dict that defines them.
            default_params[k] = [] if k in CASCADE_ARGUMENTS else v
    for k in default_params:
        if k in CASCADE_ARGUMENTS:
            default_params[k] = [
                d.get(k, default_params_request.get(k, None))
                for d in default_args
            ]
    return default_params

get_info(dataset_json_path)

Loads an nnUNet dataset JSON path.

Parameters:

Name Type Description Default
dataset_json_path str

path to dataset JSON.

required

Returns:

Name Type Description
dict dict

the dataset JSON.

Source code in src/nnunet_serve/nnunet_api_utils.py
def get_info(dataset_json_path: str) -> dict:
    """
    Loads an nnUNet dataset JSON path.

    Args:
        dataset_json_path (str): path to dataset JSON.

    Returns:
        dict: the dataset JSON.
    """
    with open(dataset_json_path) as handle:
        dataset_info = json.load(handle)
    return dataset_info

get_series_paths(study_path, series_folders, n)

Gets the complete paths for series given a study_path and the names of series_folders. Given n, which is the number of nnUNet models which will be running, this returns different values:

  • When n is None: returns a list of paths, a status message, and a possible error message.
  • When n is not None and n > 0: returns a list of list of paths, a status message and a possible error message.

Parameters:

Name Type Description Default
study_path str

path to study.

required
series_folders list[str] | list[list[str]] | None

series folder names relative to study_path.

required
n int | None

number of nnUNet models to run. If None assumes a single model is run.

required

Returns:

Type Description
tuple[list[str] | list[list[str]], str, str]

tuple[list[str] | list[list[str]], str, str]: A tuple containing: - series_paths (list[str] | list[list[str]]): The resolved complete paths. - status (str): SUCCESS_STATUS or FAILURE_STATUS. - error (str | None): Error message if status is FAILURE_STATUS.

Source code in src/nnunet_serve/nnunet_api_utils.py
def get_series_paths(
    study_path: str,
    series_folders: list[str | Path] | list[list[str | Path]] | None,
    n: int | None,
) -> tuple[list[str] | list[list[str]] | None, str, str | None]:
    """
    Gets the complete paths for series given a ``study_path`` and the names of
    ``series_folders``. Given ``n``, which is the number of nnUNet models which
    will be running, this returns different values:

    * When ``n is None``: returns a list of paths, a status message, and a
        possible error message.
    * When ``n is not None and n > 0``: returns a list of list of paths, a
        status message and a possible error message.

    Args:
        study_path (str): path to study.
        series_folders (list[str] | list[list[str]] | None): series folder names
            relative to ``study_path``.
        n (int | None): number of nnUNet models to run. If None assumes a single
            model is run.

    Returns:
        tuple[list[str] | list[list[str]] | None, str, str | None]: A tuple containing:
            - series_paths (list[str] | list[list[str]] | None): The resolved
              complete paths, or None on failure.
            - status (str): SUCCESS_STATUS or FAILURE_STATUS.
            - error (str | None): Error message if status is FAILURE_STATUS.
    """

    def _get_path(base: str, x: str | Path) -> str:
        # Path instances are taken as already-complete paths; plain strings
        # are resolved relative to the study directory
        if isinstance(x, Path):
            return str(x)
        return os.path.join(base, x)

    if series_folders is None:
        return None, FAILURE_STATUS, "series_folders must be defined"

    if n is None:
        return (
            [_get_path(study_path, x) for x in series_folders],
            SUCCESS_STATUS,
            None,
        )

    if n != len(series_folders):
        if len(series_folders) != 1:
            return (
                None,
                FAILURE_STATUS,
                "series_folders and nnunet_id must be the same length",
            )
        # a single folder list is broadcast to all n models
        series_folders = [series_folders[0] for _ in range(n)]
    series_paths = [
        [_get_path(study_path, x) for x in folders]
        for folders in series_folders
    ]
    return series_paths, SUCCESS_STATUS, None

load_predictor(nnunet_path, checkpoint_name, mirroring, device_id, use_folds, min_mem=None)

Loads a nnUNetPredictor instance from a trained model folder. Keeps everything in cache using a time-to-live cache, enabling batch-based operations.

Parameters:

Name Type Description Default
nnunet_path str

Path to the nnUNet model folder.

required
checkpoint_name str

Name of the checkpoint to use.

required
mirroring bool

Whether to use mirroring during inference.

required
device_id int

GPU identifier.

required
use_folds bool

Whether to use folds during inference.

required
min_mem int | None

Minimum amount of free memory required to use the GPU. Defaults to None.

None

Returns:

Name Type Description
nnUNetPredictor nnUNetPredictor

A loaded nnUNetPredictor instance.

Source code in src/nnunet_serve/nnunet_api_utils.py
def load_predictor(
    nnunet_path: str,
    checkpoint_name: str,
    mirroring: bool,
    device_id: int | None,
    use_folds: bool,
    min_mem: int | None = None,
) -> nnUNetPredictor:
    """
    Loads a nnUNetPredictor instance from a trained model folder.
    Keeps everything in cache using a time-to-live cache, enabling batch-based
    operations.

    Args:
        nnunet_path (str): Path to the nnUNet model folder.
        checkpoint_name (str): Name of the checkpoint to use.
        mirroring (bool): Whether to use mirroring during inference.
        device_id (int | None): GPU identifier. If None, a GPU with at least
            ``min_mem`` free memory is awaited.
        use_folds (bool): Whether to use folds during inference.
        min_mem (int | None): Minimum amount of free memory required to use the
            GPU. Defaults to None.

    Returns:
        nnUNetPredictor: A loaded nnUNetPredictor instance.

    Raises:
        ValueError: if ``device_id`` is None and ``min_mem`` is None.
    """
    if isinstance(use_folds, int):
        use_folds = [use_folds]
    if device_id is None:
        if min_mem is None:
            raise ValueError("min_mem must be defined when device_id is None")
        device_id = wait_for_gpu(min_mem)
    # use the stringified argument tuple directly as the cache key: unlike
    # hash(), distinct argument sets can never collide on a tuple key
    cache_key = tuple(
        str(x)
        for x in (
            mirroring,
            device_id,
            nnunet_path,
            use_folds,
            checkpoint_name,
        )
    )
    if cache_key in CACHE:
        return CACHE[cache_key]
    predictor = nnUNetPredictor(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=mirroring,
        device=torch.device("cuda", device_id),
        perform_everything_on_device=True,
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=True,
    )
    if isinstance(checkpoint_name, CheckpointName):
        checkpoint_name = checkpoint_name.value
    # restrict unpickling to a known-safe set of globals when loading weights
    with torch.serialization.safe_globals(SAFE_GLOBALS):
        predictor.initialize_from_trained_model_folder(
            nnunet_path,
            use_folds=use_folds,
            checkpoint_name=checkpoint_name,
        )
    CACHE[cache_key] = predictor
    return predictor

multi_model_inference(nnunet_path, series_paths, class_idx=None, mirroring=False, device_id=None, checkpoint_name='checkpoint_best.pth', is_dicom=False, use_folds=(0,), proba_threshold=0.1, min_confidence=None, intersect_with=None, min_intersection=0.1, crop_from=None, crop_padding=None, cascade_mode='intersect', remove_objects_smaller_than=None, flip_xy=False, bvalue_for_filtering=None, min_mem=None)

Prediction wrapper for multiple models. Exports the outputs.

Parameters:

Name Type Description Default
nnunet_path str | list[str]

path or paths to nnUNet models.

required
series_paths list[str] | list[list[str]]

list of paths or list of list of paths corresponding to series.

required
class_idx int | list[int] | list[list[int]]

class index to export probability maps. Defaults to None.

None
checkpoint_name str | list[str]

name of nnUNet checkpoint. Defaults to "checkpoint_best.pth".

'checkpoint_best.pth'
mirroring bool

whether to use mirroring during inference. Defaults to False.

False
device_id int | None

GPU identifier. Defaults to None (gets automatically assigned to the GPU with the most free memory).

None
is_dicom bool

whether the input/output is DICOM. Defaults to False.

False
use_folds tuple[int]

which folds should be used. Defaults to (0,).

(0,)
proba_threshold float | tuple[float] | list[float]

probability threshold to consider a pixel positive. Defaults to 0.1.

0.1
min_confidence float | tuple[float] | list[float] | None

minimum confidence to keep an object. Defaults to None.

None
intersect_with str | Image | None

whether the prediction should intersect with a given object. Defaults to None.

None
min_intersection float

fraction of prediction which should intersect with intersect_with. Defaults to 0.1.

0.1
crop_from str | Image | None

whether the input should be cropped centered on a given mask object. Defaults to None.

None
crop_padding tuple[int, int, int] | None

padding to be added to the cropped region. Defaults to None.

None
cascade_mode str | list[str]

whether to crop inputs to consecutive bounding boxes or to intersect consecutive outputs. Defaults to "intersect".

'intersect'
remove_objects_smaller_than float | tuple[float] | list[float] | None

whether to remove objects smaller than this threshold. If a float is provided, it is considered as a percentage of the maximum object size. Defaults to None.

None
flip_xy bool

whether to flip the x and y axes of the input. TotalSegmentator does this for some reason. Defaults to False.

False
bvalue_for_filtering int | None

b-value to filter DICOM files by. Defaults to None.

None
min_mem int | None

minimum amount of free memory required to use the GPU. Only used when device_id is None. Defaults to None.

None
Source code in src/nnunet_serve/nnunet_api_utils.py
def multi_model_inference(
    nnunet_path: str | list[str],
    series_paths: list[str] | list[list[str]],
    class_idx: int | list[int] | list[list[int]] | None = None,
    mirroring: bool = False,
    device_id: int | None = None,
    checkpoint_name: str | list[str] = "checkpoint_best.pth",
    is_dicom: bool = False,
    use_folds: tuple[int] = (0,),
    proba_threshold: float | tuple[float] | list[float] = 0.1,
    min_confidence: float | tuple[float] | list[float] | None = None,
    intersect_with: str | sitk.Image | None = None,
    min_intersection: float = 0.1,
    crop_from: str | sitk.Image | None = None,
    crop_padding: tuple[int, int, int] | None = None,
    cascade_mode: str | list[str] = "intersect",
    remove_objects_smaller_than: (
        float | tuple[float] | list[float] | None
    ) = None,
    flip_xy: bool = False,
    bvalue_for_filtering: int | None = None,
    min_mem: int | None = None,
):
    """
    Prediction wrapper for multiple models. Exports the outputs.

    Args:
        nnunet_path (str | list[str]): path or paths to nnUNet models.
        series_paths (list[str] | list[list[str]]): list of paths or list of
            list of paths corresponding to series.
        class_idx (int | list[int] | list[list[int]], optional): class index to
            export probability maps. Defaults to None.
        checkpoint_name (str | list[str], optional): name of nnUNet checkpoint.
            Defaults to "checkpoint_best.pth".
        mirroring (bool, optional): whether to use mirroring during inference.
            Defaults to False.
        device_id (int | None, optional): GPU identifier. Defaults to None (gets
            automatically assigned to the GPU with the most free memory).
        is_dicom (bool, optional): whether the input/output is DICOM. Defaults
            to False.
        use_folds (tuple[int], optional): which folds should be used. Defaults
            to (0,).
        proba_threshold (float | tuple[float] | list[float], optional):
            probability threshold to consider a pixel positive. Defaults to 0.1.
        min_confidence (float | tuple[float] | list[float] | None, optional):
            minimum confidence to keep an object. Defaults to None.
        intersect_with (str | sitk.Image | None, optional): whether the
            prediction should intersect with a given object. Defaults to None.
        min_intersection (float, optional): fraction of prediction which should
            intersect with ``intersect_with``. Defaults to 0.1.
        crop_from (str | sitk.Image | None, optional): whether the
            input should be cropped centered on a given mask object. Defaults to None.
        crop_padding (tuple[int, int, int] | None, optional): padding to be
            added to the cropped region. Defaults to None.
        cascade_mode (str | list[str], optional): whether to crop inputs to consecutive
            bounding boxes or to intersect consecutive outputs. Defaults to "intersect".
        remove_objects_smaller_than (float | tuple[float] | list[float] | None, optional):
            whether to remove objects smaller than this threshold. If a float is provided,
            it is considered as a percentage of the maximum object size. Defaults to None.
        flip_xy (bool, optional): whether to flip the x and y axes of the input.
            TotalSegmentator does this for some reason. Defaults to False.
        bvalue_for_filtering (int | None, optional): b-value to filter DICOM files by.
            Defaults to None.
        min_mem (int | None, optional): minimum amount of free memory required to
            use the GPU. Only used when ``device_id`` is None. Defaults to None.

    Returns:
        tuple: (all_predictions, all_proba_maps, good_file_paths, all_volumes).
    """

    def _coerce_to_list(obj: Any, n: int) -> list[Any] | tuple[Any]:
        # broadcasts scalars and length-1 sequences to length n; any other
        # length mismatch is an error
        if isinstance(obj, (list, tuple)):
            if len(obj) != n:
                if len(obj) == 1:
                    obj = obj * n
                else:
                    raise ValueError(f"{obj} should have length {n}")
        else:
            obj = [obj for _ in range(n)]
        return obj

    if not isinstance(series_paths, (tuple, list)):
        raise ValueError(
            f"series_paths should be list of strings or list of list of strings (is {series_paths})"
        )
    if isinstance(nnunet_path, (list, tuple)):
        n_models = len(nnunet_path)
        # minimal input parsing
        series_paths_list = None
        class_idx_list = None
        if isinstance(series_paths[0], (list, tuple)):
            series_paths_list = series_paths
        elif isinstance(series_paths[0], str):
            series_paths_list = [series_paths for _ in nnunet_path]
        if isinstance(class_idx, int) or class_idx is None:
            class_idx_list = [class_idx for _ in nnunet_path]
        elif isinstance(class_idx, (tuple, list)):
            class_idx_list = class_idx
        mirroring = _coerce_to_list(mirroring, n_models)
        proba_threshold = _coerce_to_list(proba_threshold, n_models)
        min_confidence = _coerce_to_list(min_confidence, n_models)
        remove_objects_smaller_than = _coerce_to_list(
            remove_objects_smaller_than, n_models
        )
        checkpoint_name_list = _coerce_to_list(checkpoint_name, n_models)
        use_folds = _coerce_to_list(use_folds, n_models)
        # BUG FIX: flip_xy was indexed per model below but never broadcast,
        # which raised TypeError whenever a scalar bool was passed
        flip_xy = _coerce_to_list(flip_xy, n_models)

        if series_paths_list is None:
            raise ValueError(
                f"series_paths should be list of strings or list of list of strings (is {series_paths})"
            )

        logger.info("Using nnunet_path %s for inference", nnunet_path)
        logger.info("Using series_paths %s for inference", series_paths)
        logger.info("Using class_idx_list %s for inference", class_idx_list)
        logger.info("Using proba_threshold %s for inference", proba_threshold)
        logger.info("Using min_confidence %s for inference", min_confidence)
        logger.info("Using checkpoint_name %s for inference", checkpoint_name)

        series_loader = SeriesLoader(
            series_paths, is_dicom, bvalue_for_filtering=bvalue_for_filtering
        )
        all_predictions = []
        all_proba_maps = []
        intersect_with_class_idx = 1
        crop_class_idx = 1
        for i in range(n_models):
            mask, proba_map = single_model_inference(
                nnunet_path=nnunet_path[i],
                volumes=series_loader.get_volumes(i),
                class_idx=class_idx_list[i],
                checkpoint_name=checkpoint_name_list[i],
                # BUG FIX: previously mirroring[0] was used for every model
                mirroring=mirroring[i],
                device_id=device_id,
                use_folds=use_folds[i],
                proba_threshold=proba_threshold[i],
                min_confidence=min_confidence[i],
                intersect_with=intersect_with,
                intersect_with_class_idx=intersect_with_class_idx,
                min_intersection=min_intersection,
                crop_from=crop_from,
                crop_class_idx=crop_class_idx,
                crop_padding=crop_padding,
                remove_objects_smaller_than=remove_objects_smaller_than[i],
                flip_xy=flip_xy[i],
                min_mem=min_mem,
            )
            all_predictions.append(mask)
            all_proba_maps.append(proba_map)
            series_loader.register(mask, i)
            if proba_map:
                series_loader.register(
                    proba_map, i, image_file_name="probabilities.nii.gz"
                )
            # propagate this stage's mask to the next cascade stage
            if i < (n_models - 1):
                if "intersect" in cascade_mode:
                    logger.info("Using mask for intersection")
                    intersect_with = mask
                    intersect_with_class_idx = class_idx_list[i]
                if "crop" in cascade_mode:
                    logger.info("Using mask for cropping")
                    crop_from = mask
                    crop_class_idx = class_idx_list[i]
        # keep first from last predicted series to replicate previous behaviour
        if is_dicom:
            good_file_paths = [series_loader.get_file_paths(-1)[0]]
        else:
            good_file_paths = None
    else:
        series_loader = SeriesLoader(
            series_paths, is_dicom, bvalue_for_filtering=bvalue_for_filtering
        )
        mask, proba_map = single_model_inference(
            nnunet_path=nnunet_path,
            volumes=series_loader.get_volumes(0),
            # BUG FIX: class_idx was silently dropped on the single-model path
            class_idx=class_idx,
            checkpoint_name=checkpoint_name,
            mirroring=mirroring,
            device_id=device_id,
            use_folds=use_folds,
            proba_threshold=proba_threshold,
            min_confidence=min_confidence,
            intersect_with=intersect_with,
            min_intersection=min_intersection,
            crop_from=crop_from,
            crop_padding=crop_padding,
            remove_objects_smaller_than=remove_objects_smaller_than,
            flip_xy=flip_xy,
            min_mem=min_mem,
        )
        all_predictions = [mask]
        all_proba_maps = [proba_map]
        # BUG FIX: good_file_paths was previously left undefined on this
        # path, raising NameError at the return statement below
        if is_dicom:
            good_file_paths = [series_loader.get_file_paths(0)[0]]
        else:
            good_file_paths = None

    logger.info("Finished inference")

    all_volumes = [
        series_loader.get_volumes(i) for i in range(series_loader.n_stages)
    ]
    return all_predictions, all_proba_maps, good_file_paths, all_volumes

predict(series_paths, metadata, mirroring, device_id, params, nnunet_path, flip_xy=False, writing_process_pool=None)

Runs the prediction for a set of models and returns exported output paths.

Parameters:

Name Type Description Default
series_paths list

paths to series.

required
metadata str

DICOM seg metadata. Has to be a dict with either "path" (pointing towards a DCMQI metadata file) or a list of metadata key-value pairs (please see SegWriter for details).

required
mirroring bool

whether to use mirroring during inference.

required
device_id int | None

GPU identifier.

required
params dict

parameters which will be used in wrapper.

required
nnunet_path str | list[str]

path or paths to nnUNet model.

required
flip_xy bool | list[bool]

whether to flip the x and y axes during inference. Defaults to False.

False
writing_process_pool ProcessPool | None

process pool to use for parallel file saving operations. Defaults to None.

None

Returns:

Name Type Description
dict dict

mapping of output artifact keys to lists of paths.

Source code in src/nnunet_serve/nnunet_api_utils.py
def predict(
    series_paths: list,
    metadata: str | None,
    mirroring: bool,
    device_id: int | None,
    params: dict,
    nnunet_path: str | list[str],
    flip_xy: bool | list[bool] = False,
    writing_process_pool: ProcessPool | None = None,
) -> tuple[dict, list, list]:
    """
    Runs the prediction for a set of models and returns exported output paths.

    Args:
        series_paths (list): paths to series.
        metadata (str): DICOM seg metadata. Has to be a dict with either "path"
            (pointing towards a DCMQI metadata file) or a list of metadata
            key-value pairs (please see ``SegWriter`` for details).
        mirroring (bool): whether to use mirroring during inference.
        device_id (int | None): GPU identifier.
        params (dict): parameters which will be used in wrapper.
        nnunet_path (str | list[str]): path or paths to nnUNet model.
        flip_xy (bool | list[bool]): whether to flip the x and y axes during
            inference. Defaults to False.
        writing_process_pool (ProcessPool | None): process pool to use for parallel
            file saving operations. Defaults to None.

    Returns:
        tuple[dict, list, list]: A tuple containing:
            - output_paths (dict): mapping of output artifact keys to lists of
              paths (empty when a writing process pool is used).
            - identifiers (list): identifiers of queued writing jobs (empty
              when no writing process pool is used).
            - is_empty (list): whether each prediction mask contains no labels.

    Raises:
        RuntimeError: if CUDA is unavailable, inference fails, or exporting
            the predictions fails.
        ValueError: if ``metadata`` is None when ``is_dicom`` is True.
    """
    if not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA is not available but GPU inference was requested"
        )

    # parameter names routed to multi_model_inference
    inference_param_names = [
        "class_idx",
        "checkpoint_name",
        "is_dicom",
        "use_folds",
        "proba_threshold",
        "min_confidence",
        "intersect_with",
        "crop_from",
        "crop_padding",
        "min_intersection",
        "cascade_mode",
        "flip_xy",
        "bvalue_for_filtering",
    ]
    # parameter names routed to export_predictions
    export_param_names = [
        "output_dir",
        "suffix",
        "is_dicom",
        "save_proba_map",
        "save_nifti_inputs",
        "save_rt_struct_output",
        "class_idx",
    ]
    # parameters consumed upstream and never forwarded
    delete_params = [
        "nnunet_id",
        "tta",
        "min_mem",
        "aliases",
        "study_path",
        "series_folders",
    ]

    min_mem = params.get("min_mem", None)
    params = {k: params[k] for k in params if k not in delete_params}
    inference_params = {
        k: params[k] for k in params if k in inference_param_names
    }
    export_params = {k: params[k] for k in params if k in export_param_names}

    if export_params["is_dicom"] is True:
        if metadata is None:
            raise ValueError("metadata must be defined when is_dicom is True")
        if isinstance(metadata, list):
            seg_writers = [
                SegWriter.init_from_metadata_dict(m) for m in metadata
            ]
        else:
            seg_writers = [SegWriter.init_from_metadata_dict(metadata)]
    else:
        seg_writers = None

    try:
        (
            all_predictions,
            all_proba_maps,
            good_file_paths,
            all_volumes,
        ) = multi_model_inference(
            series_paths=series_paths,
            nnunet_path=nnunet_path,
            flip_xy=flip_xy,
            mirroring=mirroring,
            device_id=device_id,
            min_mem=min_mem,
            **inference_params,
        )
    except torch.cuda.OutOfMemoryError as e:
        raise RuntimeError(
            "CUDA out of memory during inference. Consider reducing input size or TTA."
        ) from e
    except FileNotFoundError as e:
        raise RuntimeError(f"Input file or model path not found: {e}") from e
    except Exception as e:
        raise RuntimeError(f"Inference failed: {e}") from e

    identifiers = []
    label_filter = sitk.LabelShapeStatisticsImageFilter()
    is_empty = []
    for pred in all_predictions:
        label_filter.Execute(pred)
        is_empty.append(label_filter.GetNumberOfLabels() == 0)

    try:
        if writing_process_pool:
            for i in range(len(all_predictions)):
                identifier = str(uuid.uuid4())
                identifiers.append(identifier)
                writing_process_pool.put(
                    identifier=identifier,
                    args=[],
                    kwargs={
                        "masks": [all_predictions[i]],
                        "proba_maps": [all_proba_maps[i]],
                        "good_file_paths": good_file_paths,
                        "volumes": [all_volumes[i]],
                        # BUG FIX: seg_writers is None in the non-DICOM case
                        # and indexing it raised TypeError; NOTE(review):
                        # assumes one writer per prediction when DICOM —
                        # confirm against SegWriter construction above
                        "seg_writers": (
                            [seg_writers[i]] if seg_writers is not None else None
                        ),
                        **export_params,
                    },
                )
            output_paths = {}
        else:
            output_paths = export_predictions(
                masks=all_predictions,
                proba_maps=all_proba_maps,
                good_file_paths=good_file_paths,
                volumes=all_volumes,
                seg_writers=seg_writers,
                **export_params,
            )
            logger.info("Finished exporting predictions")
    except Exception as e:
        raise RuntimeError(f"Failed to export predictions: {e}") from e

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return output_paths, identifiers, is_empty

predict_from_data_iterator_local(predictor, data_iterator, save_probabilities=False, class_idx=None)

Adapts the original predict_from_data_iterator to use no multiprocessing.

Source code in src/nnunet_serve/nnunet_api_utils.py
def predict_from_data_iterator_local(
    predictor: nnUNetPredictor,
    data_iterator: list[dict[str, Any]],
    save_probabilities: bool = False,
    class_idx: int | list[int] | None = None,
):
    """
    Adapts the original predict_from_data_iterator to use no multiprocessing.

    Args:
        predictor (nnUNetPredictor): initialized nnU-Net predictor.
        data_iterator (list[dict[str, Any]]): iterable of preprocessed cases;
            each item carries "data" (a tensor, or a path to a .npy file) and
            "data_properties".
        save_probabilities (bool): whether each output also includes the
            probability array. Defaults to False.
        class_idx (int | list[int] | None): restrict prediction to this class
            subset (background is always kept by the subset state). Defaults
            to None (all classes).

    Returns:
        list: one processed output per case — a segmentation array, or a
            (segmentation, probabilities) tuple when ``save_probabilities``
            is True.
    """
    ret = []
    label_adapter = LabelManagerAdapter(predictor.label_manager)
    for preprocessed in data_iterator:
        data = preprocessed["data"]
        if isinstance(data, str):
            # data was spilled to disk as .npy by preprocessing; load and
            # delete the temporary file
            delfile = data
            data = torch.from_numpy(np.load(data))
            os.remove(delfile)

        properties = preprocessed["data_properties"]

        prediction = predictor.predict_logits_from_preprocessed_data(data).cpu()
        # number of output channels before any class-subset selection
        # (assumes channel-first logits — TODO confirm against nnU-Net)
        n_classes = prediction.shape[0]
        logger.info("nnUNet: predicted logits")
        subset_state = None
        if class_idx is not None:
            # keep only the logit channels for background + requested classes
            subset_state = label_adapter.build_subset_state(class_idx)
            prediction = prediction[subset_state.prediction_indices]

        logger.info("nnUNet: resampling...")
        if subset_state is not None:
            # temporarily rewrite the label manager's internal label/region
            # state so nnU-Net's conversion sees only the subset
            with label_adapter.apply_subset_state(subset_state):
                processed_output = (
                    convert_predicted_logits_to_segmentation_with_correct_shape(
                        predicted_logits=prediction,
                        plans_manager=predictor.plans_manager,
                        configuration_manager=predictor.configuration_manager,
                        label_manager=predictor.label_manager,
                        properties_dict=properties,
                        return_probabilities=save_probabilities,
                    )
                )
        else:
            processed_output = (
                convert_predicted_logits_to_segmentation_with_correct_shape(
                    predicted_logits=prediction,
                    plans_manager=predictor.plans_manager,
                    configuration_manager=predictor.configuration_manager,
                    label_manager=predictor.label_manager,
                    properties_dict=properties,
                    return_probabilities=save_probabilities,
                )
            )
        if class_idx is not None:
            if save_probabilities is False:
                # map compact subset labels back to their original label ids
                processed_output = np.vectorize(
                    subset_state.correspondence_dict.get
                )(processed_output)
            else:
                # with probabilities the output is (mask, proba): remap the
                # mask, then scatter subset probabilities into a full-size
                # (n_classes, *spatial) array at their original label indices
                corrected_mask = np.vectorize(
                    subset_state.correspondence_dict.get
                )(processed_output[0])
                corrected_prob = np.zeros(
                    [n_classes, *processed_output[0].shape]
                )
                # used_labels[0] is background, hence the i + 1 source offset
                for i, label in enumerate(subset_state.used_labels[1:]):
                    corrected_prob[label] = processed_output[1][i + 1]
                processed_output = (corrected_mask, corrected_prob)

        logger.info("nnUNet: converted logits to segmentation")
        ret.append(processed_output)
        logger.info(f"Done with image of shape {data.shape}")

    if isinstance(data_iterator, MultiThreadedAugmenter):
        data_iterator._finish()

    # clear lru cache
    compute_gaussian.cache_clear()
    # clear device cache
    empty_cache(predictor.device)
    return ret

process_mask(mask_array, input_image, intersect_with=None, min_intersection=0.1, output_padding=None)

Processes a mask array and returns a SITK mask image.

Parameters:

Name Type Description Default
mask_array ndarray

an array corresponding to a mask.

required
intersect_with str | Image

calculates the intersection of each candidate with the image specified in intersect_with. If the intersection is larger than min_intersection, the candidate is kept; otherwise it is discarded. Defaults to None.

None
min_intersection float

minimum intersection over the union to keep candidate. Defaults to 0.1.

0.1
output_padding list[int] | None

padding to apply to the output mask. Defaults to None.

None

Returns:

Type Description
tuple[Image, bool]

tuple[sitk.Image, bool]: the processed mask after the candidate extraction protocol, and a flag indicating whether the mask is empty.

Source code in src/nnunet_serve/nnunet_api_utils.py
def process_mask(
    mask_array: np.ndarray,
    input_image: sitk.Image,
    intersect_with: str | sitk.Image | None = None,
    min_intersection: float = 0.1,
    output_padding: tuple[int, int, int, int, int, int] | None = None,
) -> tuple[sitk.Image, bool]:
    """
    Processes a mask array and returns a SITK mask image.

    Args:
        mask_array (np.ndarray): an array corresponding to a mask.
        input_image (sitk.Image): reference image whose spatial metadata
            (origin, spacing, direction) is copied onto the output mask.
        intersect_with (str | sitk.Image, optional): calculates the
            intersection of each candidate with the image specified in
            intersect_with. If the intersection is larger than
            min_intersection, the candidate is kept; otherwise it is discarded.
            Defaults to None.
        min_intersection (float, optional): minimum intersection over the union
            to keep candidate. Defaults to 0.1.
        output_padding (tuple[int, int, int, int, int, int] | None, optional):
            padding to apply to the output mask; the first three values pad the
            lower bound and the last three the upper bound. Defaults to None.

    Returns:
        tuple[sitk.Image, bool]: the processed mask and a flag which is True
            when the mask contains no labels.
    """
    logger.info("Processing mask")
    if intersect_with is not None:
        mask_array = intersect(mask_array, intersect_with, min_intersection)
    mask = sitk.GetImageFromArray(mask_array)
    mask.CopyInformation(input_image)
    mask_stats = sitk.LabelShapeStatisticsImageFilter()
    mask_stats.Execute(mask)
    logger.info("Labels: %s", mask_stats.GetLabels())
    empty = len(mask_stats.GetLabels()) == 0
    if empty:
        logger.warning("Mask is empty")
    if output_padding is not None:
        pad_filter = sitk.ConstantPadImageFilter()
        pad_filter.SetPadLowerBound(output_padding[:3])
        pad_filter.SetPadUpperBound(output_padding[3:])
        mask = pad_filter.Execute(mask)
    return mask, empty

process_proba_array(proba_array, input_image, proba_threshold=0.1, min_confidence=0.5, intersect_with=None, min_intersection=0.1, class_idx=None, output_padding=None)

Exports a SITK probability mask and the corresponding probability map. Applies a candidate extraction protocol (threshold, CC analysis, min_confidence).

Parameters:

Name Type Description Default
proba_array ndarray

an array corresponding to a probability map.

required
proba_threshold float

sets values below this value to 0.

0.1
min_confidence float

removes objects whose maximum probability is lower than this value.

0.5
intersect_with str | Image

calculates the intersection of each candidate with the image specified in intersect_with. If the intersection is larger than min_intersection, the candidate is kept; otherwise it is discarded. Defaults to None.

None
min_intersection float

minimum intersection over the union to keep candidate. Defaults to 0.1.

0.1
class_idx int | list[int] | None

class index for output probability. Defaults to None (no selection).

None
output_padding list[int] | None

padding to apply to the output mask. Defaults to None.

None

Returns:

Type Description
tuple[Image, Image, bool]

tuple[sitk.Image, sitk.Image, bool]: (mask, proba_map, empty_flag)

Source code in src/nnunet_serve/nnunet_api_utils.py
def process_proba_array(
    proba_array: np.ndarray,
    input_image: sitk.Image,
    proba_threshold: float = 0.1,
    min_confidence: float = 0.5,
    intersect_with: str | sitk.Image | None = None,
    min_intersection: float = 0.1,
    class_idx: int | list[int] | None = None,
    output_padding: list[int] | None = None,
) -> tuple[sitk.Image, sitk.Image, bool]:
    """
    Exports a SITK probability mask and the corresponding probability map.
    Applies a candidate extraction protocol (threshold, CC analysis,
    min_confidence).

    Args:
        proba_array (np.ndarray): an array corresponding to a probability map,
            with the class/channel axis first.
        input_image (sitk.Image): reference image whose spatial metadata
            (spacing, origin, direction) is copied onto the outputs.
        proba_threshold (float, optional): sets values below this value to 0.
            Defaults to 0.1.
        min_confidence (float, optional): removes objects whose maximum
            probability is lower than this value. Defaults to 0.5.
        intersect_with (str | sitk.Image | None, optional): calculates the
            intersection of each candidate with the image specified in
            intersect_with. If the intersection is larger than
            min_intersection, the candidate is kept; otherwise it is discarded.
            Defaults to None.
        min_intersection (float, optional): minimum intersection over the union
            to keep candidate. Defaults to 0.1.
        class_idx (int | list[int] | None, optional): class index for output
            probability. Defaults to None (no selection).
        output_padding (list[int] | None, optional): padding to apply to the
            output mask (first three values lower bound, last three upper
            bound). Defaults to None.

    Returns:
        tuple[sitk.Image, sitk.Image, bool]: (mask, proba_map, empty_flag)
    """
    logger.info("Exporting probability map and mask")
    empty = False
    if class_idx is None:
        # No class selection: argmax over the channel axis yields the label
        # map, and all per-class probabilities are kept as a vector image.
        mask = np.argmax(proba_array, 0)
        # SimpleITK expects the component (channel) axis last for vector
        # images, hence the moveaxis before GetImageFromArray.
        proba_array = np.moveaxis(proba_array, 0, -1)
        mask = sitk.GetImageFromArray(mask.astype(np.uint32))
        mask.CopyInformation(input_image)
        proba_map = sitk.GetImageFromArray(proba_array)
        proba_map = copy_information_nd(proba_map, input_image)
    else:
        # Select a single class's probability, or the summed probability of
        # several classes.
        if isinstance(class_idx, int):
            proba_array = proba_array[class_idx]
        elif isinstance(class_idx, (list, tuple)):
            proba_array = proba_array[class_idx].sum(0)
        proba_array = np.where(proba_array < proba_threshold, 0.0, proba_array)
        # Candidate extraction (thresholding, connected-component analysis and
        # confidence/intersection filtering); helper defined elsewhere in this
        # module.
        proba_array, _, _ = extract_lesion_candidates(
            proba_array,
            threshold=proba_threshold,
            min_confidence=min_confidence,
            intersect_with=intersect_with,
            min_intersection=min_intersection,
        )
        # All surviving voxels are labelled with the smallest selected class
        # index: when several classes were summed they collapse into a single
        # label.
        mask = np.where(
            proba_array > proba_threshold, int(np.min(class_idx)), 0
        )
        proba_map = sitk.GetImageFromArray(proba_array)
        proba_map = copy_information_nd(proba_map, input_image)
        mask = sitk.GetImageFromArray(mask.astype(np.uint32))
        mask.CopyInformation(input_image)
    # Detect fully-empty predictions so callers can short-circuit.
    mask_stats = sitk.LabelShapeStatisticsImageFilter()
    mask_stats.Execute(mask)
    if len(mask_stats.GetLabels()) == 0:
        logger.warning("Mask is empty")
        empty = True
    if output_padding is not None:
        # Pad back to the pre-crop extent: first three values pad the lower
        # bound, last three the upper bound.
        pad_filter = sitk.ConstantPadImageFilter()
        pad_filter.SetPadLowerBound(output_padding[:3])
        pad_filter.SetPadUpperBound(output_padding[3:])
        mask = pad_filter.Execute(mask)
        proba_map = pad_filter.Execute(proba_map)

    return mask, proba_map, empty

single_model_inference(nnunet_path, volumes, class_idx=None, checkpoint_name='checkpoint_best.pth', mirroring=False, device_id=None, use_folds=[0], proba_threshold=None, min_confidence=None, intersect_with=None, intersect_with_class_idx=1, crop_from=None, crop_class_idx=1, crop_padding=None, min_intersection=0.1, remove_objects_smaller_than=None, flip_xy=False, min_mem=None)

Runs the inference for a single model.

Parameters:

Name Type Description Default
nnunet_path str

path to nnUNet model.

required
volumes list[Image]

series volumes.

required
class_idx int | list[int]

class index for probability output. Defaults to None (probability output for all classes).

None
checkpoint_name str

name of checkpoint in nnUNet model. Defaults to "checkpoint_best.pth".

'checkpoint_best.pth'
mirroring bool

whether to use mirroring during inference. Defaults to False.

False
device_id int | None

GPU identifier. Defaults to None (gets automatically assigned to the GPU with the most free memory).

None
use_folds list[int]

which folds from the nnUNet model will be used. Defaults to [0].

[0]
proba_threshold float

probability threshold to consider a pixel positive. Defaults to None.

None
min_confidence float | None

minimum confidence level for each detected object. Defaults to None.

None
intersect_with str | Image | None

whether the prediction should intersect with a given object. Defaults to None.

None
intersect_with_class_idx int | None

class index for intersection. Defaults to 1.

1
crop_from str | Image | None

whether the input should be cropped centered on a given mask object. If specified as a string, it can be either the path or the path:class_idx. Defaults to None.

None
crop_class_idx int | None

class index for cropping. Defaults to 1.

1
crop_padding tuple[int, int, int] | None

padding to be added to the cropped region. Defaults to None.

None
min_intersection float

fraction of prediction which should intersect with intersect_with. Defaults to 0.1.

0.1
remove_objects_smaller_than float | None

whether to remove objects smaller than this threshold. If a float is provided, it is considered as a percentage of the maximum object size. Defaults to None.

None
flip_xy bool

whether to flip the x and y axes of the input. TotalSegmentator does this for some reason. Defaults to False.

False
min_mem int | None

minimum amount of free memory required to use the GPU. Only used when device_id is None. Defaults to None.

None

Raises:

Type Description
ValueError

if there is a mismatch between the number of series and the number of channels in the model.

Returns:

Type Description
tuple[Image, Image | None]

tuple[sitk.Image, sitk.Image | None]: the output mask and the corresponding probability map (None when proba_threshold is None).

Source code in src/nnunet_serve/nnunet_api_utils.py
def single_model_inference(
    nnunet_path: str,
    volumes: list[sitk.Image],
    class_idx: int | list[int] | None = None,
    checkpoint_name: str = "checkpoint_best.pth",
    mirroring: bool = False,
    device_id: int | None = None,
    use_folds: list[int] | None = None,
    proba_threshold: float | None = None,
    min_confidence: float | None = None,
    intersect_with: str | sitk.Image | None = None,
    intersect_with_class_idx: int = 1,
    crop_from: str | sitk.Image | None = None,
    crop_class_idx: int = 1,
    crop_padding: tuple[int, int, int] | None = None,
    min_intersection: float = 0.1,
    remove_objects_smaller_than: float | None = None,
    flip_xy: bool = False,
    min_mem: int | None = None,
) -> tuple[sitk.Image, sitk.Image | None]:
    """
    Runs the inference for a single model.

    Args:
        nnunet_path (str): path to nnUNet model.
        volumes (list[sitk.Image]): series volumes.
        class_idx (int | list[int] | None, optional): class index for
            probability output. Defaults to None (all classes).
        checkpoint_name (str, optional): name of checkpoint in nnUNet model.
            Defaults to "checkpoint_best.pth".
        mirroring (bool, optional): whether to use mirroring during inference.
            Defaults to False.
        device_id (int | None, optional): GPU identifier. Defaults to None (gets
            automatically assigned to the GPU with the most free memory).
        use_folds (list[int] | None, optional): which folds from the nnUNet
            model will be used. Defaults to None (uses fold 0).
        proba_threshold (float | None, optional): probability threshold to
            consider a pixel positive. Defaults to None (no probability-based
            candidate extraction).
        min_confidence (float | None, optional): minimum confidence level for
            each detected object. Defaults to None.
        intersect_with (str | sitk.Image | None, optional): whether the
            prediction should intersect with a given object. Defaults to None.
        intersect_with_class_idx (int, optional): class index for intersection.
            Defaults to 1.
        crop_from (str | sitk.Image | None, optional): whether the
            input should be cropped centered on a given mask object. If
            specified as a string, it can be either the path or the path:class_idx.
            Defaults to None.
        crop_class_idx (int, optional): class index for cropping. Defaults to 1.
        crop_padding (tuple[int, int, int] | None, optional): padding to be
            added to the cropped region. Defaults to None.
        min_intersection (float, optional): fraction of prediction which should
            intersect with ``intersect_with``. Defaults to 0.1.
        remove_objects_smaller_than (float | None, optional): whether to remove
            objects smaller than this threshold. If a float is provided, it is
            considered as a percentage of the maximum object size. Defaults to None.
        flip_xy (bool, optional): whether to flip the x and y axes of the input.
            TotalSegmentator does this for some reason. Defaults to False.
        min_mem (int | None, optional): minimum amount of free memory required to
            use the GPU. Only used when ``device_id`` is None. Defaults to None.

    Raises:
        ValueError: if there is a mismatch between the number of series and
            the number of channels in the model.

    Returns:
        tuple[sitk.Image, sitk.Image | None]: the output mask (resampled to
            the input geometry) and the corresponding probability map (None
            when ``proba_threshold`` is None).
    """

    # Avoid the shared-mutable-default pitfall; None stands in for [0].
    if use_folds is None:
        use_folds = [0]

    # Tolerate a list of thresholds by keeping only the first one.
    if isinstance(proba_threshold, list):
        proba_threshold = proba_threshold[0]

    # initializes the network architecture, loads the checkpoint
    predictor = load_predictor(
        nnunet_path=nnunet_path,
        checkpoint_name=checkpoint_name,
        mirroring=mirroring,
        device_id=device_id,
        use_folds=use_folds,
        min_mem=min_mem,
    )
    # Remember the input geometry so outputs can be resampled back onto it.
    input_spacing = volumes[0].GetSpacing()
    original_size = volumes[0].GetSize()
    original_direction = volumes[0].GetDirection()
    original_origin = volumes[0].GetOrigin()
    # nnUNet stores spacing in (z, y, x); SimpleITK uses (x, y, z).
    spacing = predictor.configuration_manager.spacing[::-1]

    predictor.dataset_json["file_ending"] = ".nii.gz"
    exp_chan = predictor.dataset_json["channel_names"]
    if len(exp_chan) != len(volumes):
        raise ValueError(
            f"series_paths should have length {len(exp_chan)} ({exp_chan}) but has length {len(volumes)}"
        )
    output_padding = None
    if crop_from is not None:
        # Crop all inputs to the bounding box of the requested class in
        # crop_from; output_padding lets the outputs be padded back later.
        if isinstance(crop_from, str):
            crop_from = sitk.ReadImage(crop_from)
        crop_from = filter_labels(crop_from, crop_class_idx, True)
        bb, output_padding = get_crop(crop_from, volumes[0], crop_padding)
        volumes = [
            v[bb[0] : bb[3], bb[1] : bb[4], bb[2] : bb[5]] for v in volumes
        ]

    logger.info("Resampling input images to nnUNet model spacing")
    logger.info("Input size (before resampling): %s", volumes[0].GetSize())

    volumes = [resample_image(volume, spacing) for volume in volumes]
    logger.info("Running inference using %s", nnunet_path)
    logger.info("Folds: %s", use_folds)
    logger.info("Mirroring: %s", mirroring)
    # Stack channels into a (C, z, y, x) array as expected by nnUNet.
    input_array = np.stack([sitk.GetArrayFromImage(v) for v in volumes])
    image_properties = {
        "spacing": volumes[0].GetSpacing()[::-1],
        "origin": volumes[0].GetOrigin(),
        "direction": volumes[0].GetDirection(),
        "size": volumes[0].GetSize(),
    }
    logger.info("Input size: %s", volumes[0].GetSize())
    logger.info("Input spacing: %s", volumes[0].GetSpacing())
    logger.info("Input origin: %s", volumes[0].GetOrigin())
    logger.info("Input direction: %s", volumes[0].GetDirection())
    logger.info("Input shape (array): %s", input_array.shape)
    logger.info("nnUNet: creating data iterator")
    if flip_xy:
        # Some models (e.g. TotalSegmentator) expect x/y-flipped input.
        input_array = input_array[:, :, ::-1, ::-1]
    iterator = predictor.get_data_iterator_from_raw_npy_data(
        [input_array], None, [image_properties], None, 1
    )
    logger.info("nnUNet: running inference")
    mask_array, proba_array = predict_from_data_iterator_local(
        predictor,
        iterator,
        save_probabilities=True,
        class_idx=class_idx,
    )[0]
    if flip_xy:
        # Undo the input flip on the outputs.
        mask_array = mask_array[..., ::-1, ::-1]
        proba_array = proba_array[..., ::-1, ::-1]

    if remove_objects_smaller_than is not None:
        logger.info("Removing small objects")
        mask_array = small_object_removal(
            mask_array, remove_objects_smaller_than
        )

    logger.info("nnUNet: inference done")

    # Resampling targets: restore the original (pre-resampling) geometry.
    resample_kwargs = {
        "out_spacing": input_spacing,
        "out_size": original_size,
        "out_direction": original_direction,
        "out_origin": original_origin,
    }
    if intersect_with is not None:
        # Bring the intersection mask into the same (cropped, resampled)
        # space as the prediction.
        intersect_with = filter_labels(
            intersect_with, intersect_with_class_idx, True
        )
        if crop_from is not None:
            intersect_with = intersect_with[
                bb[0] : bb[3], bb[1] : bb[4], bb[2] : bb[5]
            ]
        intersect_with = resample_image(intersect_with, spacing, is_mask=True)
    if proba_threshold is not None:
        # Candidate-extraction path: mask is derived from the probabilities.
        mask, probability_map, _ = process_proba_array(
            proba_array,
            volumes[0],
            min_confidence=min_confidence,
            proba_threshold=proba_threshold,
            intersect_with=intersect_with,
            min_intersection=min_intersection,
            class_idx=class_idx,
            output_padding=output_padding,
        )
        mask = resample_image(
            mask,
            is_mask=True,
            **resample_kwargs,
        )
        probability_map = resample_image(
            probability_map,
            is_mask=False,
            **resample_kwargs,
        )
    else:
        # Plain argmax mask; no probability map is produced.
        mask, _ = process_mask(
            mask_array,
            volumes[0],
            intersect_with=intersect_with,
            min_intersection=min_intersection,
            output_padding=output_padding,
        )
        probability_map = None
        mask = resample_image(
            mask,
            is_mask=True,
            **resample_kwargs,
        )

    if probability_map is not None:
        # Zero out probabilities outside the final mask so the map only
        # covers retained objects.
        if probability_map.GetNumberOfComponentsPerPixel() > 1:
            logger.info("Treating probability map as vector pixel type")
            f = sitk.LabelShapeStatisticsImageFilter()
            f.Execute(mask)
            select_f = sitk.VectorIndexSelectionCastImageFilter()
            out_proba_map = []
            # NOTE(review): assumes vector component i corresponds to label
            # value i — TODO confirm channel/label alignment.
            for i in f.GetLabels():
                select_f.SetIndex(i)
                pmap = select_f.Execute(probability_map)
                pmap = sitk.Cast(mask == i, pmap.GetPixelID()) * pmap
                out_proba_map.append(pmap)
            compose_filter = sitk.ComposeImageFilter()
            probability_map = compose_filter.Execute(*out_proba_map)
        else:
            probability_map = (
                sitk.Cast(mask, probability_map.GetPixelID()) * probability_map
            )

    logger.info("Finished processing masks")

    return mask, probability_map