
Segmentation output writers

Utilities for exporting nnU-Net segmentations to disk and DICOM SEG.

This module provides the SegWriter abstraction and supporting helpers to convert raw prediction arrays into persistent representations, including NIfTI-like formats and DICOM SEG objects with appropriate coding, colors, and metadata derived from the input images and model configuration.

SegWriter dataclass

Helper class to manage DICOM Segmentation (SEG) and RT Structure Set (RTSTRUCT) export.

This class converts prediction masks into DICOM SEG or RTSTRUCT objects. It also builds the algorithm and segment metadata attached to them: coded segment types and categories, labels, tracking identifiers, and display colors.

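Attributes:

algorithm_name (str): Name of the segmentation algorithm. Required.
segment_names (list[str | dict[str, str]] | None): Segment names, or dicts describing the segments. One SegmentDescription is generated per entry, and mask label k refers to entry k - 1. Defaults to None.
segment_descriptions (list[SegmentDescription] | None): Ready-made segment descriptions. At least one of segment_names and segment_descriptions must be given. Defaults to None.
algorithm_version (str): Version string of the algorithm. Defaults to "v1.0".
algorithm_family (Code): DICOM code for the algorithm family. Defaults to codes.CID7162.ArtificialIntelligence.
algorithm_type (SegmentAlgorithmTypeValues): DICOM algorithm type. Defaults to AUTOMATIC.
instance_number (int): DICOM instance number. Defaults to 1.
series_number (int): DICOM series number. Defaults to 999.
manufacturer (str): Manufacturer name for DICOM metadata. Defaults to "Algorithm".
manufacturer_model_name (str): Model name for DICOM metadata. Defaults to "AlgorithmModel".
series_description (str): Description for the DICOM series. Defaults to "Segmentation".
clinical_trial_series_id (str): Clinical trial series ID. Defaults to "1".
clinical_trial_time_point_id (str): Clinical trial time point ID. Defaults to "1".
body_part_examined (str): Body part examined, for DICOM metadata. Defaults to "BODY".
validate (bool): If True, logs each segment's number, label, and property type after initialization. This only happens when segment_names is given. Defaults to False.

write_dicom_seg always writes series number 999 and instance number 1. It also generates its own series description from the segment meanings. As a result, instance_number, series_number, and series_description do not affect DICOM SEG output.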

Source code in src/nnunet_serve/seg_writers.py
@dataclass
class SegWriter:
    """
    Helper class to manage DICOM Segmentation (SEG) and RT-Struct export.

    This class handles the conversion of prediction masks into standardized
    DICOM formats, including metadata management for algorithms and segments.

    Attributes:
        algorithm_name (str): Name of the segmentation algorithm.
        segment_names (list[str | dict[str, str]], optional): Names or dicts
            describing the segments.
        segment_descriptions (list[hd.seg.SegmentDescription], optional):
            High-level segment descriptions.
        algorithm_version (str): Version string of the algorithm.
        algorithm_family (Code): DICOM code for the algorithm family.
        algorithm_type (SegmentAlgorithmTypeValues): DICOM algorithm type.
        instance_number (int): DICOM instance number.
        series_number (int): DICOM series number.
        manufacturer (str): Manufacturer name for DICOM metadata.
        manufacturer_model_name (str): Model name for DICOM metadata.
        series_description (str): Description for the DICOM series.
        clinical_trial_series_id (str): Clinical trial series ID.
        clinical_trial_time_point_id (str): Clinical trial time point ID.
        body_part_examined (str): Body part examined for DICOM metadata.
        validate (bool): Whether to perform validation on init.
    """

    algorithm_name: str
    segment_names: list[str | dict[str, str]] | None = None
    segment_descriptions: list[hd.seg.SegmentDescription] | None = None
    algorithm_version: str = "v1.0"
    algorithm_family: pydicom.sr.coding.Code = (
        codes.CID7162.ArtificialIntelligence
    )
    algorithm_type: hd.seg.SegmentAlgorithmTypeValues = (
        hd.seg.SegmentAlgorithmTypeValues.AUTOMATIC
    )
    instance_number: int = 1
    series_number: int = 999
    manufacturer: str = "Algorithm"
    manufacturer_model_name: str = "AlgorithmModel"
    series_description: str = "Segmentation"
    clinical_trial_series_id: str = "1"
    clinical_trial_time_point_id: str = "1"
    body_part_examined: str = "BODY"
    validate: bool = False

    def __post_init__(self):
        if self.segment_descriptions is None and self.segment_names is None:
            raise ValueError(
                "Either segment_descriptions or segment_names must be provided"
            )
        category_concepts = CATEGORY_CONCEPTS[DEFAULT_SEGMENT_SCHEME]
        self.algorithm_identification = hd.AlgorithmIdentificationSequence(
            name=self.algorithm_name,
            version=self.algorithm_version,
            family=self.algorithm_family,
        )
        if self.segment_names is None:
            return
        if self.segment_descriptions is None:
            self.segment_descriptions = []
        for i, segment in enumerate(self.segment_names):
            type_code, segment_dict = get_segment_type_code(segment, i)
            category_code = [
                category_concepts[k]
                for k in category_concepts
                if category_concepts[k].value == segment_dict["category_number"]
            ][0]
            random_rgb_colour = random_color_generator()
            segment_description = hd.seg.SegmentDescription(
                segment_number=segment_dict["number"],
                segment_label=segment_dict["label"],
                segmented_property_category=category_code,
                segmented_property_type=type_code,
                algorithm_type=self.algorithm_type,
                algorithm_identification=self.algorithm_identification,
                tracking_uid=hd.UID(),
                tracking_id=segment_dict["tracking_id"],
            )
            segment_description.RecommendedDisplayCIELabValue = rgb_to_cielab(
                random_rgb_colour
            )
            if segment_dict["laterality"] is not None:
                lat_code = LATERALITY_CODING[segment_dict["scheme"]][
                    segment_dict["laterality"]
                ]
                segment_description.SegmentedPropertyTypeModifierCodeSequence = [
                    lat_code
                ]
            csi = CODING_SCHEME_INFORMATION[segment_dict["scheme"]]
            for (
                prop_type
            ) in segment_description.SegmentedPropertyTypeCodeSequence:
                for k in csi:
                    prop_type[k] = DataElement(
                        tag=k,
                        VR=CODING_SCHEME_INFORMATION_VR[k],
                        value=csi[k],
                    )
            for (
                prop_cat
            ) in segment_description.SegmentedPropertyCategoryCodeSequence:
                for k in csi:
                    prop_cat[k] = DataElement(
                        tag=k,
                        VR=CODING_SCHEME_INFORMATION_VR[k],
                        value=csi[k],
                    )
            self.segment_descriptions.append(segment_description)

        if self.validate is True:
            for i, segment_description in enumerate(self.segment_descriptions):
                logger.info(
                    "Segment %s: %s, %s, class_idx=%i",
                    segment_description.segment_number,
                    segment_description.segment_label,
                    segment_description.segmented_property_type.meaning,
                    i,
                )

    def to_array_if_necessary(
        self, mask: np.ndarray | sitk.Image
    ) -> np.ndarray:
        """
        Converts a mask to a numpy array if the mask is a sitk.Image.

        Args:
            mask (np.ndarray | sitk.Image): the mask to convert.

        Returns:
            np.ndarray: the mask as a numpy array.
        """
        if isinstance(mask, sitk.Image):
            mask = sitk.GetArrayFromImage(mask)
        return mask

    def make_compliant(self, f: pydicom.Dataset):
        """
        Tries to smooth some potential missing fields.

        Args:
            f (pydicom.Dataset): a pydicom Dataset.

        Returns:
            A compliant pydicom.Dataset.
        """
        if hasattr(f, "PatientSex") is False:
            f.PatientSex = "O"
        if hasattr(f, "AccessionNumber") is False:
            f.AccessionNumber = ""
        if hasattr(f, "StudyID") is False:
            f.StudyID = ""
        return f

    def write_dicom_seg(
        self,
        mask_array: np.ndarray | sitk.Image,
        source_files: list[str],
        output_path: str,
        is_fractional: bool = False,
        is_fractional_compliant: bool = False,
        class_idx: int | None = None,
    ):
        """
        Writes a DICOM segmentation file.

        Args:
            mask_array (np.ndarray | sitk.Image): the mask array or sitk image.
            source_files (list[str]): the list of DICOM source files (as returned
                by ``nnunet_serve.utils.read_dicom_as_sitk``).
            output_path (str): the output path.
            is_fractional (bool, optional): whether the mask is fractional.
                Defaults to False.
            is_fractional_compliant (bool, optional): whether the probability mask
                should be converted to a map with ``N_FRACTIONAL`` labels,
                each corresponding to a percentage.
            class_idx (int | None, optional): index used for selecting specific
                segment descriptions when saving probability maps. Only used when
                either ``is_fractional==True`` or
                ``is_fractional_compliant==True``.
        """
        mask_array = self.to_array_if_necessary(mask_array)
        sorted_source_files = sort_dicom_slices(list(source_files))
        any_fractional = is_fractional or is_fractional_compliant
        if sorted_source_files != list(source_files):
            idx_map = {p: i for i, p in enumerate(source_files)}
            try:
                order = [idx_map[p] for p in sorted_source_files]
            except KeyError:
                order = None
            if order is not None:
                mask_array = mask_array[order, ...]
            source_files = sorted_source_files
        # adjust array size and segment descriptions to the strictly necessary
        labels = np.unique(mask_array)
        labels = labels[labels > 0]
        if len(labels) == 0:
            logger.warning("Mask is empty")
            return "empty"
        segment_descriptions = []
        if any_fractional is False:
            label_dict = {label: i + 1 for i, label in enumerate(labels)}
            label_dict[0] = 0
            mask_array = np.vectorize(label_dict.get)(mask_array)
            for i, label in enumerate(labels.astype(int)):
                seg_d = deepcopy(self.segment_descriptions[label - 1])
                seg_d.SegmentNumber = i + 1
                segment_descriptions.append(seg_d)
            if len(mask_array.shape) != 4:
                mask_array = one_hot_encode(
                    mask_array, len(segment_descriptions)
                )
        else:
            if class_idx is not None:
                if isinstance(class_idx, list):
                    if len(class_idx) == 1:
                        class_idx = class_idx[0]
                    else:
                        raise ValueError(
                            f"class_idx should be int or a list with length 1 (is {class_idx})"
                        )
                segment_descriptions = [
                    self.segment_descriptions[class_idx - 1]
                ]
            else:
                segment_descriptions = self.segment_descriptions
            if mask_array.ndim == 3:
                mask_array = mask_array[..., None]
        image_datasets = [
            self.make_compliant(hd.imread(str(f))) for f in source_files
        ]

        first = image_datasets[0]
        rows, cols = int(getattr(first, "Rows")), int(getattr(first, "Columns"))
        frames = len(image_datasets)
        if (frames, rows, cols) != mask_array.shape[:-1]:
            raise Exception(
                f"Mask shape {mask_array.shape} does not match image shape {frames}x{rows}x{cols}"
            )
        if (
            mask_array.shape[-1] != len(segment_descriptions)
            and any_fractional is False
        ):
            # this check makes no sense for fractional outputs!
            raise Exception(
                f"Mask shape {mask_array.shape} does not match number of segments {len(segment_descriptions)}"
            )

        # Create the Segmentation instance
        label_meanings = []
        for s in segment_descriptions:
            ss = s.SegmentedPropertyTypeCodeSequence[0]
            if isinstance(ss, pydicom.Dataset):
                meaning = ss.CodeMeaning
            else:
                meaning = ss.meaning
            label_meanings.append(meaning.lower().replace(" structure of", ""))
        seg_series_description = "Seg of " + ", ".join(label_meanings)
        if len(seg_series_description) > 64:
            seg_series_description = seg_series_description[:61] + "..."
        if is_fractional_compliant:
            if len(segment_descriptions) > 1:
                logger.warning(
                    "Skipping saving the fractional compliant mask because len(segment_descriptions) > 1"
                )
                return "skipped"
            seg_type = hd.seg.SegmentationTypeValues.BINARY
            logger.info("Converting mask to pseudo-fractional DICOM seg")
            sd = segment_descriptions[0]
            jet_cmap = colormaps.get("jet")
            percents = np.linspace(0, 1, N_FRACTIONAL + 1, endpoint=True)
            colours = jet_cmap(percents)
            label = sd.SegmentLabel
            if hasattr(sd, "SegmentDescription"):
                desc = sd.SegmentDescription
            else:
                desc = None
            new_segment_descriptions = []
            new_mask_array = []
            for i in range(N_FRACTIONAL):
                new_sd = deepcopy(sd)
                p1, p2 = percents[i], percents[i + 1]
                new_sd.SegmentNumber = i + 1
                new_label = f"{label} ({int(p1*100)}-{int(p2*100)}%)"
                new_sd.SegmentLabel = new_label
                if desc is not None:
                    new_desc = f"{desc} ({int(p1*100)}-{int(p2*100)}%)"
                    new_sd.SegmentDescription = new_desc
                curr_mask = (mask_array > p1) & (mask_array <= p2)
                new_sd.RecommendedDisplayCIELabValue = rgb_to_cielab(
                    colours[i, :3] * 255
                )
                new_segment_descriptions.append(new_sd)
                new_mask_array.append(curr_mask)
            segment_descriptions = new_segment_descriptions
            mask_array = np.concatenate(new_mask_array, axis=-1).astype(bool)
            logger.info("Converted mask to pseudo-fractional DICOM seg")
        elif is_fractional:
            seg_type = hd.seg.SegmentationTypeValues.FRACTIONAL
        else:
            seg_type = hd.seg.SegmentationTypeValues.BINARY
        seg_dataset = hd.seg.Segmentation(
            source_images=image_datasets,
            pixel_array=mask_array,
            segmentation_type=seg_type,
            segment_descriptions=segment_descriptions,
            series_instance_uid=hd.UID(),
            series_number=999,
            sop_instance_uid=hd.UID(),
            instance_number=1,
            manufacturer=self.manufacturer,
            manufacturer_model_name=self.manufacturer_model_name,
            software_versions=self.algorithm_version,
            device_serial_number="42",
            series_description=seg_series_description,
        )
        seg_dataset.ClinicalTrialSeriesID = self.clinical_trial_series_id
        seg_dataset.ClinicalTrialTimePointID = self.clinical_trial_time_point_id
        seg_dataset.BodyPartExamined = self.body_part_examined

        seg_dataset.save_as(output_path)

        return "success"

    def write_dicom_rtstruct(
        self,
        mask_array: np.ndarray | sitk.Image,
        source_files: list[str],
        output_path: str,
    ):
        """
        Routine to write a DICOM RTstruct object.

        Args:
            mask_array (np.ndarray | sitk.Image): mask image.
            source_files (list[str]): list of source files.
            output_path (str): output path.

        Returns:
            str: "success" if the operation was successful.
        """
        mask_array = self.to_array_if_necessary(mask_array)
        segment_info = [
            [s.SegmentLabel, list(random_color_generator())]
            for s in self.segment_descriptions
        ]
        save_mask_as_rtstruct(
            mask_array,
            Path(source_files[0]).parent,
            output_path,
            segment_info,
        )

        return "success"

    @staticmethod
    def init_from_dcmqi_metadata_file(
        metadata_file: str,
        algorithm_version: str = "v1.0",
        manufacturer: str = "Algorithm",
        manufacturer_model_name: str = "AlgorithmModel",
    ):
        """
        Uses a DCMQI metadata file generated using [1] to initialize a
        SegWriter.

        [1] https://qiicr.org/dcmqi/#/seg

        Args:
            metadata_file (str): path to the DCMQI metadata file.
            algorithm_version (str): algorithm version.
            manufacturer (str): manufacturer.
            manufacturer_model_name (str): manufacturer model name.

        Returns:
            SegWriter: initialized SegWriter.
        """
        metadata = from_dcmqi_metainfo(metadata_file)
        segments = list(metadata.SegmentSequence)
        algorithm_name = segments[0].SegmentAlgorithmName
        algorithm_type = hd.seg.SegmentAlgorithmTypeValues[
            segments[0].SegmentAlgorithmType
        ]
        series_description = metadata.SeriesDescription
        clinical_trial_series_id = metadata.ClinicalTrialSeriesID
        clinical_trial_time_point_id = metadata.ClinicalTrialTimePointID
        body_part_examined = metadata.BodyPartExamined
        return SegWriter(
            segment_descriptions=segments,
            algorithm_name=algorithm_name,
            algorithm_version=algorithm_version,
            algorithm_family=codes.CID7162.ArtificialIntelligence,
            algorithm_type=algorithm_type,
            instance_number=1,
            series_number=999,
            series_description=series_description,
            manufacturer=manufacturer,
            manufacturer_model_name=manufacturer_model_name,
            clinical_trial_series_id=clinical_trial_series_id,
            clinical_trial_time_point_id=clinical_trial_time_point_id,
            body_part_examined=body_part_examined,
        )

    @staticmethod
    def init_from_metadata_dict(
        metadata: dict[str, str], validate: bool = False
    ):
        """
        Initializes a SegWriter from a metadata dictionary. It automatically
        detects the type of metadata and calls the appropriate initialization
        method. If there is a "path" key, it is assumed to be a DCMQI metadata
        file and the init_from_dcmqi_metadata_file method is called. Otherwise,
        the constructor is called directly.

        Args:
            metadata (dict[str, str]): metadata dictionary.
            validate (bool): whether to validate the metadata.

        Returns:
            SegWriter: initialized SegWriter.
        """
        if "path" in metadata:
            return SegWriter.init_from_dcmqi_metadata_file(metadata["path"])
        else:
            return SegWriter(**metadata, validate=validate)

init_from_dcmqi_metadata_file(metadata_file, algorithm_version='v1.0', manufacturer='Algorithm', manufacturer_model_name='AlgorithmModel') staticmethod

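Initializes a SegWriter from a DCMQI metadata file generated using [1]. The algorithm name and type are read from the first segment. The series description, clinical trial series and time point IDs, and body part examined are read from the file itself. The algorithm family (artificial intelligence), instance number (1), and series number (999) are fixed.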

[1] https://qiicr.org/dcmqi/#/seg

Parameters:

metadata_file (str): Path to the DCMQI metadata file. Required.
algorithm_version (str): Algorithm version. Defaults to 'v1.0'.
manufacturer (str): Manufacturer. Defaults to 'Algorithm'.
manufacturer_model_name (str): Manufacturer model name. Defaults to 'AlgorithmModel'.

Returns:

SegWriter: The initialized SegWriter.

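A minimal DCMQI-style metadata file shows which fields this method relies on. The sketch below parses such a file with the standard json module and extracts the same values. It assumes the DCMQI layout, in which segmentAttributes is a list of lists of segment dicts. The helper read_dcmqi_fields and the example values are hypothetical. The method itself goes through from_dcmqi_metainfo, which returns a dataset-like object with a SegmentSequence, rather than reading raw JSON.

```python
import json

# Minimal DCMQI-style SEG metadata (assumed layout; check against the DCMQI schema).
EXAMPLE_METADATA = """
{
  "SeriesDescription": "Liver segmentation",
  "ClinicalTrialSeriesID": "1",
  "ClinicalTrialTimePointID": "1",
  "BodyPartExamined": "ABDOMEN",
  "segmentAttributes": [[
    {
      "labelID": 1,
      "SegmentLabel": "Liver",
      "SegmentAlgorithmType": "AUTOMATIC",
      "SegmentAlgorithmName": "nnU-Net",
      "SegmentedPropertyCategoryCodeSequence": {
        "CodeValue": "85756007",
        "CodingSchemeDesignator": "SCT",
        "CodeMeaning": "Tissue"
      },
      "SegmentedPropertyTypeCodeSequence": {
        "CodeValue": "10200004",
        "CodingSchemeDesignator": "SCT",
        "CodeMeaning": "Liver"
      }
    }
  ]]
}
"""


def read_dcmqi_fields(text: str) -> dict:
    """Hypothetical helper: pull out the fields init_from_dcmqi_metadata_file uses."""
    meta = json.loads(text)
    # Flatten the list-of-lists of segment attributes into one list.
    segments = [seg for group in meta["segmentAttributes"] for seg in group]
    first = segments[0]  # algorithm name and type come from the first segment
    return {
        "algorithm_name": first["SegmentAlgorithmName"],
        # Must name a SegmentAlgorithmTypeValues member, e.g. AUTOMATIC.
        "algorithm_type": first["SegmentAlgorithmType"],
        "series_description": meta["SeriesDescription"],
        "clinical_trial_series_id": meta["ClinicalTrialSeriesID"],
        "clinical_trial_time_point_id": meta["ClinicalTrialTimePointID"],
        "body_part_examined": meta["BodyPartExamined"],
        "n_segments": len(segments),
    }


fields = read_dcmqi_fields(EXAMPLE_METADATA)
print(fields["algorithm_name"], fields["algorithm_type"], fields["n_segments"])
```

Because the algorithm type string is looked up by name in SegmentAlgorithmTypeValues, it has to match a member name exactly.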
init_from_metadata_dict(metadata, validate=False) staticmethod

Initializes a SegWriter from a metadata dictionary and chooses the initialization route from its keys. If the dictionary has a "path" key, its value is treated as the path to a DCMQI metadata file and init_from_dcmqi_metadata_file is called. In that case, validate and all other keys are ignored. Otherwise, the dictionary is unpacked as keyword arguments to the SegWriter constructor.

Parameters:

metadata (dict[str, str]): Metadata dictionary. Required.
validate (bool): Whether to validate the metadata. Passed to the constructor only. Defaults to False.

Returns:

SegWriter: The initialized SegWriter.

make_compliant(f)

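Fills in attributes that may be missing from a source dataset. PatientSex is set to "O" (other), and AccessionNumber and StudyID are set to empty strings. Attributes that are already present are left unchanged. The dataset is modified in place and returned.

Parameters:

f (Dataset): A pydicom Dataset. Required.

Returns:

Dataset: The same dataset, with the missing attributes filled in.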

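The same defaulting can be sketched on any object that supports attribute access. Here types.SimpleNamespace stands in for a pydicom Dataset, and fill_missing is a hypothetical name. The defaults are the ones make_compliant applies.

```python
from types import SimpleNamespace

# Defaults applied by make_compliant.
DEFAULTS = {"PatientSex": "O", "AccessionNumber": "", "StudyID": ""}


def fill_missing(ds, defaults=DEFAULTS):
    """Set each attribute only if it is absent; existing values are kept."""
    for name, value in defaults.items():
        if not hasattr(ds, name):
            setattr(ds, name, value)
    return ds  # same object, modified in place


ds = SimpleNamespace(PatientSex="F")
fill_missing(ds)
print(ds.PatientSex, repr(ds.AccessionNumber), repr(ds.StudyID))  # F '' ''
```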
to_array_if_necessary(mask)

Converts a mask to a NumPy array if it is a SimpleITK image. NumPy arrays are returned unchanged. The conversion uses sitk.GetArrayFromImage, which orders the axes as (z, y, x), that is, (slices, rows, columns).

Parameters:

mask (ndarray | Image): The mask to convert. Required.

Returns:

ndarray: The mask as a NumPy array.

write_dicom_rtstruct(mask_array, source_files, output_path)

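Writes the mask as a DICOM RTSTRUCT file. The reference series is the directory containing the first source file. Each segment description contributes its SegmentLabel and a randomly generated display color to the segment information passed to save_mask_as_rtstruct.

Parameters:

mask_array (ndarray | Image): The mask image. Required.
source_files (list[str]): The source DICOM files. Only the directory of the first file is used. Required.
output_path (str): The output path. Required.

Returns:

str: Always "success". Errors are raised as exceptions.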

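The segment information handed to save_mask_as_rtstruct is a list of [label, [r, g, b]] pairs, with one entry per segment description. In the sketch below, random.Random stands in for the module's random_color_generator. The 0-255 integer range is an assumption, and build_segment_info is a hypothetical name.

```python
import random


def build_segment_info(labels, seed=None):
    """Pair each segment label with a random RGB color (hypothetical helper)."""
    rng = random.Random(seed)
    # One [label, [r, g, b]] entry per segment, in segment-description order.
    return [[label, [rng.randint(0, 255) for _ in range(3)]] for label in labels]


info = build_segment_info(["Liver", "Spleen"], seed=0)
for label, rgb in info:
    print(label, rgb)
```

In the source, write_dicom_rtstruct calls random_color_generator again. RTSTRUCT colors are therefore drawn independently of the RecommendedDisplayCIELabValue colors stored in the segment descriptions. In this sketch, passing a seed makes the colors reproducible.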
write_dicom_seg(mask_array, source_files, output_path, is_fractional=False, is_fractional_compliant=False, class_idx=None)

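Writes the mask as a DICOM SEG file. The source files are first sorted into slice order, and the mask is reordered to match. For binary output, labels that do not occur in the mask are dropped and the remaining segments are renumbered consecutively. The series description is generated from the segment meanings ("Seg of ...", truncated to 64 characters).

The method returns:
- "success" once the file is written;
- "empty" if the mask contains no positive values (nothing is written);
- "skipped" if is_fractional_compliant is True but more than one segment description is selected.

It raises an exception if the mask shape does not match the number of source images and their Rows and Columns.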

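In the binary case (neither fractional flag set), write_dicom_seg drops labels that do not occur in the mask and renumbers the rest consecutively. It then one-hot encodes them into a (frames, rows, cols, segments) array, and original label k selects segment_descriptions[k - 1]. This NumPy sketch is not the module's code. It uses a lookup array instead of np.vectorize over a dict, and assumes that the internal one_hot_encode helper produces the same channel-last layout. compact_and_one_hot is a hypothetical name.

```python
import numpy as np


def compact_and_one_hot(mask: np.ndarray):
    """Renumber positive labels to 1..K and one-hot encode them channel-last.

    Returns the (frames, rows, cols, K) boolean array and the original labels.
    """
    labels = np.unique(mask)
    labels = labels[labels > 0]
    if labels.size == 0:
        return None, labels  # write_dicom_seg returns "empty" in this case
    # Lookup table: original label -> consecutive segment number (0 stays 0).
    lookup = np.zeros(int(mask.max()) + 1, dtype=np.int64)
    lookup[labels.astype(int)] = np.arange(1, labels.size + 1)
    compact = lookup[mask.astype(int)]
    # Channel c is True where the voxel belongs to segment number c + 1.
    one_hot = compact[..., None] == np.arange(1, labels.size + 1)
    return one_hot, labels


mask = np.zeros((2, 3, 3), dtype=np.uint8)
mask[0, 0, 0] = 2
mask[1, 2, 2] = 5
one_hot, labels = compact_and_one_hot(mask)
print(one_hot.shape, labels.tolist())  # (2, 3, 3, 2) [2, 5]
```

The lookup array avoids a Python call per voxel, at the cost of assuming non-negative integer labels.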
Parameters:

mask_array (ndarray | Image): The mask array or SimpleITK image, with slices along the first axis. Required.
source_files (list[str]): The DICOM source files (as returned by nnunet_serve.utils.read_dicom_as_sitk). Required.
output_path (str): The output path. Required.
is_fractional (bool): Whether the mask holds probabilities, written as a FRACTIONAL SEG. Defaults to False.
is_fractional_compliant (bool): Whether to convert the probability mask into a BINARY SEG with N_FRACTIONAL segments, each covering one probability interval. Defaults to False.
class_idx (int | None): 1-based index of the segment description to use when saving probability maps. A single-element list is also accepted. Only used when is_fractional or is_fractional_compliant is True. Defaults to None.
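With is_fractional_compliant=True, the probability map is split into N_FRACTIONAL boolean segments, one per probability interval, and written as a BINARY SEG. Each segment is labeled with its percentage range. In this sketch of the binning, n_bins stands in for N_FRACTIONAL, whose actual value is defined elsewhere in the module. bin_probabilities is a hypothetical name.

```python
import numpy as np


def bin_probabilities(prob: np.ndarray, label: str, n_bins: int):
    """Split a probability map into n_bins boolean masks over (p1, p2] intervals."""
    edges = np.linspace(0, 1, n_bins + 1)
    masks, labels = [], []
    for i in range(n_bins):
        p1, p2 = edges[i], edges[i + 1]
        # Left-open, right-closed: a probability of exactly 0 lands in no bin.
        masks.append((prob > p1) & (prob <= p2))
        labels.append(f"{label} ({int(p1 * 100)}-{int(p2 * 100)}%)")
    # Stack channel-last, matching the (frames, rows, cols, segments) layout.
    return np.stack(masks, axis=-1), labels


prob = np.array([[[0.0, 0.1, 0.5, 0.95]]])  # shape (1, 1, 4)
masks, labels = bin_probabilities(prob, "Liver", n_bins=4)
print(masks.shape)  # (1, 1, 4, 4)
print(labels)
```

The source receives the probabilities with a trailing singleton axis and concatenates the per-bin masks along it. Stacking a 3D map along a new last axis gives the same layout.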
Source code in src/nnunet_serve/seg_writers.py
def write_dicom_seg(
    self,
    mask_array: np.ndarray | sitk.Image,
    source_files: list[str],
    output_path: str,
    is_fractional: bool = False,
    is_fractional_compliant: bool = False,
    class_idx: int | None = None,
):
    """
    Writes a DICOM segmentation file.

    Args:
        mask_array (np.ndarray | sitk.Image): the mask array or sitk image.
        source_files (list[str]): the list of DICOM source files (as returned
            by ``nnunet_serve.utils.read_dicom_as_sitk``).
        output_path (str): the output path.
        is_fractional (bool, optional): whether the mask is fractional.
            Defaults to False.
        is_fractional_compliant (bool, optional): whether the probability mask
            should be converted to a map with ``N_FRACTIONAL`` labels,
            each corresponding to a percentage.
        class_idx (int | None, optional): index used for selecting specific
            segment descriptions when saving probability maps. Only used when
            either ``is_fractional==True`` or
            ``is_fractional_compliant==True``.
    """
    mask_array = self.to_array_if_necessary(mask_array)
    sorted_source_files = sort_dicom_slices(list(source_files))
    any_fractional = is_fractional or is_fractional_compliant
    if sorted_source_files != list(source_files):
        idx_map = {p: i for i, p in enumerate(source_files)}
        try:
            order = [idx_map[p] for p in sorted_source_files]
        except KeyError:
            order = None
        if order is not None:
            mask_array = mask_array[order, ...]
        source_files = sorted_source_files
    # adjust array size and segment descriptions to the strictly necessary
    labels = np.unique(mask_array)
    labels = labels[labels > 0]
    if len(labels) == 0:
        logger.warning("Mask is empty")
        return "empty"
    segment_descriptions = []
    if any_fractional is False:
        label_dict = {label: i + 1 for i, label in enumerate(labels)}
        label_dict[0] = 0
        mask_array = np.vectorize(label_dict.get)(mask_array)
        for i, label in enumerate(labels.astype(int)):
            seg_d = deepcopy(self.segment_descriptions[label - 1])
            seg_d.SegmentNumber = i + 1
            segment_descriptions.append(seg_d)
        if len(mask_array.shape) != 4:
            mask_array = one_hot_encode(
                mask_array, len(segment_descriptions)
            )
    else:
        if class_idx is not None:
            if isinstance(class_idx, list):
                if len(class_idx) == 1:
                    class_idx = class_idx[0]
                else:
                    raise ValueError(
                        f"class_idx should be int or a list with length 1 (is {class_idx})"
                    )
            segment_descriptions = [
                self.segment_descriptions[class_idx - 1]
            ]
        else:
            segment_descriptions = self.segment_descriptions
        if mask_array.ndim == 3:
            mask_array = mask_array[..., None]
    image_datasets = [
        self.make_compliant(hd.imread(str(f))) for f in source_files
    ]

    first = image_datasets[0]
    rows, cols = int(getattr(first, "Rows")), int(getattr(first, "Columns"))
    frames = len(image_datasets)
    if (frames, rows, cols) != mask_array.shape[:-1]:
        raise Exception(
            f"Mask shape {mask_array.shape} does not match image shape {frames}x{rows}x{cols}"
        )
    if (
        mask_array.shape[-1] != len(segment_descriptions)
        and any_fractional is False
    ):
        # this check makes no sense for fractional outputs!
        raise Exception(
            f"Mask shape {mask_array.shape} does not match number of segments {len(segment_descriptions)}"
        )

    # Create the Segmentation instance
    label_meanings = []
    for s in segment_descriptions:
        ss = s.SegmentedPropertyTypeCodeSequence[0]
        if isinstance(ss, pydicom.Dataset):
            meaning = ss.CodeMeaning
        else:
            meaning = ss.meaning
        label_meanings.append(meaning.lower().replace(" structure of", ""))
    seg_series_description = "Seg of " + ", ".join(label_meanings)
    if len(seg_series_description) > 64:
        seg_series_description = seg_series_description[:61] + "..."
    if is_fractional_compliant:
        if len(segment_descriptions) > 1:
            logger.warning(
                "Skipping saving the fractional compliant mask because len(segment_descriptions) > 1"
            )
            return "skipped"
        seg_type = hd.seg.SegmentationTypeValues.BINARY
        logger.info("Converting mask to pseudo-fractional DICOM seg")
        sd = segment_descriptions[0]
        jet_cmap = colormaps.get("jet")
        percents = np.linspace(0, 1, N_FRACTIONAL + 1, endpoint=True)
        colours = jet_cmap(percents)
        label = sd.SegmentLabel
        if hasattr(sd, "SegmentDescription"):
            desc = sd.SegmentDescription
        else:
            desc = None
        new_segment_descriptions = []
        new_mask_array = []
        for i in range(N_FRACTIONAL):
            new_sd = deepcopy(sd)
            p1, p2 = percents[i], percents[i + 1]
            new_sd.SegmentNumber = i + 1
            new_label = f"{label} ({int(p1*100)}-{int(p2*100)}%)"
            new_sd.SegmentLabel = new_label
            if desc is not None:
                new_desc = f"{desc} ({int(p1*100)}-{int(p2*100)}%)"
                new_sd.SegmentDescription = new_desc
            curr_mask = (mask_array > p1) & (mask_array <= p2)
            new_sd.RecommendedDisplayCIELabValue = rgb_to_cielab(
                colours[i, :3] * 255
            )
            new_segment_descriptions.append(new_sd)
            new_mask_array.append(curr_mask)
        segment_descriptions = new_segment_descriptions
        mask_array = np.concatenate(new_mask_array, axis=-1).astype(bool)
        logger.info("Converted mask to pseudo-fractional DICOM seg")
    elif is_fractional:
        seg_type = hd.seg.SegmentationTypeValues.FRACTIONAL
    else:
        seg_type = hd.seg.SegmentationTypeValues.BINARY
    seg_dataset = hd.seg.Segmentation(
        source_images=image_datasets,
        pixel_array=mask_array,
        segmentation_type=seg_type,
        segment_descriptions=segment_descriptions,
        series_instance_uid=hd.UID(),
        series_number=999,
        sop_instance_uid=hd.UID(),
        instance_number=1,
        manufacturer=self.manufacturer,
        manufacturer_model_name=self.manufacturer_model_name,
        software_versions=self.algorithm_version,
        device_serial_number="42",
        series_description=seg_series_description,
    )
    seg_dataset.ClinicalTrialSeriesID = self.clinical_trial_series_id
    seg_dataset.ClinicalTrialTimePointID = self.clinical_trial_time_point_id
    seg_dataset.BodyPartExamined = self.body_part_examined

    seg_dataset.save_as(output_path)

    return "success"
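The pseudo-fractional branch above splits a single-class probability map into `N_FRACTIONAL` binary segments, one per probability interval, and labels each one with its percentage range. The following is a minimal NumPy sketch of that binning step only. It assumes a hypothetical `n_bins` argument standing in for `N_FRACTIONAL` (whose real value is defined elsewhere in the module) and leaves out the colour assignment and DICOM metadata.

```python
import numpy as np


def bin_probabilities(proba: np.ndarray, n_bins: int) -> tuple[np.ndarray, list[str]]:
    """Hypothetical sketch: split a (..., 1) probability map into n_bins boolean channels.

    Channel i holds voxels with p1 < value <= p2, where p1 and p2 are the i-th
    and (i+1)-th edges of an even grid over [0, 1]. Voxels with probability
    exactly 0 therefore fall into no channel, as in the source above.
    """
    edges = np.linspace(0, 1, n_bins + 1, endpoint=True)
    channels = []
    labels = []
    for i in range(n_bins):
        p1, p2 = edges[i], edges[i + 1]
        channels.append((proba > p1) & (proba <= p2))
        labels.append(f"({int(p1 * 100)}-{int(p2 * 100)}%)")
    # One boolean channel per bin, stacked along the last (segment) axis.
    return np.concatenate(channels, axis=-1).astype(bool), labels


proba = np.array([0.0, 0.05, 0.5, 0.95, 1.0]).reshape(1, 1, 5, 1)
binned, labels = bin_probabilities(proba, n_bins=10)
print(binned.shape)  # (1, 1, 5, 10)
print(labels[:2])    # ['(0-10%)', '(10-20%)']
```

Because the intervals are half-open on the left, every non-zero probability lands in exactly one segment, which keeps the resulting BINARY SEG free of overlapping voxels.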

close_match(a, b, ratio=0.8)

Returns True if the similarity ratio between strings a and b, as computed by difflib.SequenceMatcher.ratio(), is greater than ratio.

Parameters:

Name Type Description Default
a str

first string.

required
b str

second string.

required
ratio float

ratio of matching characters.

0.8

Returns:

Type Description
bool

True if the ratio of matching characters between strings a and b is greater than the specified ratio.

Source code in src/nnunet_serve/seg_writers.py
def close_match(a, b, ratio: float = 0.8) -> bool:
    """
    Returns True if the similarity ratio between strings a and b, as computed
    by ``difflib.SequenceMatcher.ratio()``, is greater than ``ratio``.

    Args:
        a (str): first string.
        b (str): second string.
        ratio (float): ratio of matching characters.

    Returns:
        bool: True if the ratio of matching characters between strings a and b
            is greater than the specified ratio.
    """
    return SequenceMatcher(None, a, b).ratio() > ratio
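To make the threshold concrete, the sketch below re-declares `close_match` exactly as above and evaluates it on a few arbitrary example strings. `SequenceMatcher.ratio()` returns 2*M/T, where M is the number of matched characters and T the combined length of both strings, so "Kidney" vs "Kidneys" scores 12/13.

```python
from difflib import SequenceMatcher


def close_match(a: str, b: str, ratio: float = 0.8) -> bool:
    # Same one-liner as above; note the strict ">" comparison.
    return SequenceMatcher(None, a, b).ratio() > ratio


print(round(SequenceMatcher(None, "Kidney", "Kidneys").ratio(), 3))  # 0.923
print(close_match("Kidney", "Kidneys"))  # True
print(close_match("Kidney", "Liver"))    # False
```

Because the comparison is strict, `ratio=1.0` never matches, not even for identical strings; `get_segment_type_code` calls it with 0.8 to list spelling alternatives when a name is not found.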

export_predictions(masks, output_dir, volumes=None, proba_maps=None, good_file_paths=None, suffix=None, is_dicom=False, seg_writers=None, save_proba_map=False, save_nifti_inputs=False, save_rt_struct_output=False, class_idx=None)

Export stage-wise nnUNet outputs to NIfTI and optional DICOM artifacts.

This function writes one output subdirectory per stage (stage_<i>) under output_dir and stores prediction masks for all stages. Depending on the flags, it can also export probability maps, input NIfTI volumes, DICOM SEG, DICOM RTSTRUCT, and DICOM fractional SEG outputs.

Parameters:

Name Type Description Default
masks list[Image]

Predicted segmentation masks, one sitk.Image per stage.

required
output_dir str

Base output directory.

required
volumes list[list[Image]] | None

Optional stage-wise input volumes used to export NIfTI inputs when save_nifti_inputs is enabled.

None
proba_maps list[list[Image]] | None

Optional stage-wise probability maps used when save_proba_map is enabled.

None
good_file_paths list[str] | None

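Source DICOM file paths used to create DICOM outputs. Only the first entry (good_file_paths[0]) is passed to the SEG/RTSTRUCT writers as the reference series.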

None
suffix str | None

Optional suffix appended to exported file names.

None
is_dicom bool

Whether to export DICOM-derived outputs (SEG/RTSTRUCT).

False
seg_writers SegWriter | list[SegWriter] | None

A SegWriter or list of stage-specific SegWriter instances used for DICOM exports.

None
save_proba_map bool

If True, exports probability maps in NIfTI and (for DICOM mode) fractional DICOM SEG format.

False
save_nifti_inputs bool

If True, exports input NIfTI volumes used per stage.

False
save_rt_struct_output bool

If True and is_dicom is enabled, exports RTSTRUCT in addition to DICOM SEG.

False
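class_idx int | list[int] | None

Optional 1-based class index, or one index per stage, that selects the segment description used when exporting probability maps as fractional DICOM SEG.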

None

Returns:

Name Type Description
dict

Mapping of artifact type to exported paths. Keys may include:

  • nifti_prediction
  • nifti_proba
  • nifti_inputs
  • dicom_segmentation
  • dicom_struct
  • dicom_fractional_segmentation
Source code in src/nnunet_serve/seg_writers.py
def export_predictions(
    masks: list[sitk.Image],
    output_dir: str,
    volumes: list[list[sitk.Image]] | None = None,
    proba_maps: list[list[sitk.Image]] | None = None,
    good_file_paths: list[str] | None = None,
    suffix: str | None = None,
    is_dicom: bool = False,
    seg_writers: SegWriter | list[SegWriter] | None = None,
    save_proba_map: bool = False,
    save_nifti_inputs: bool = False,
    save_rt_struct_output: bool = False,
    class_idx: int | list[int] | None = None,
):
    """Export stage-wise nnUNet outputs to NIfTI and optional DICOM artifacts.

    This function writes one output subdirectory per stage (`stage_<i>`) under
    `output_dir` and stores prediction masks for all stages. Depending on the
    flags, it can also export probability maps, input NIfTI volumes, DICOM SEG,
    DICOM RTSTRUCT, and DICOM fractional SEG outputs.

    Args:
        masks: Predicted segmentation masks, one `sitk.Image` per stage.
        output_dir: Base output directory.
        volumes: Optional stage-wise input volumes used to export NIfTI inputs
            when `save_nifti_inputs` is enabled.
        proba_maps: Optional stage-wise probability maps used when
            `save_proba_map` is enabled.
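        good_file_paths: Source DICOM file paths used to create DICOM outputs.
            Only the first entry (`good_file_paths[0]`) is passed to the
            SEG/RTSTRUCT writers as the reference series.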
        suffix: Optional suffix appended to exported file names.
        is_dicom: Whether to export DICOM-derived outputs (SEG/RTSTRUCT).
        seg_writers: A `SegWriter` or list of stage-specific `SegWriter`
            instances used for DICOM exports.
        save_proba_map: If `True`, exports probability maps in NIfTI and (for
            DICOM mode) fractional DICOM SEG format.
        save_nifti_inputs: If `True`, exports input NIfTI volumes used per stage.
        save_rt_struct_output: If `True` and `is_dicom` is enabled, exports
            RTSTRUCT in addition to DICOM SEG.
        class_idx: Optional 1-based class index, or one index per stage, that
            selects the segment description used when exporting probability
            maps as fractional DICOM SEG.

    Returns:
        dict: Mapping of artifact type to exported paths. Keys may include:
        - `nifti_prediction`
        - `nifti_proba`
        - `nifti_inputs`
        - `dicom_segmentation`
        - `dicom_struct`
        - `dicom_fractional_segmentation`
    """
    output_names = {
        "prediction": (
            "prediction" if suffix is None else f"prediction_{suffix}"
        ),
        "probabilities": (
            "probabilities" if suffix is None else f"proba_{suffix}"
        ),
        "struct": "struct" if suffix is None else f"struct_{suffix}",
    }

    output_paths = {}
    stage_dirs = [
        os.path.join(output_dir, f"stage_{i}") for i in range(len(masks))
    ]
    for stage_dir in stage_dirs:
        os.makedirs(stage_dir, exist_ok=True)

    mask_paths = []
    for i, mask in enumerate(masks):
        output_nifti_path = (
            f"{stage_dirs[i]}/{output_names['prediction']}.nii.gz"
        )
        sitk.WriteImage(mask, output_nifti_path)
        mask_paths.append(output_nifti_path)
        logger.info("Exported prediction mask %d to %s", i, output_nifti_path)
    output_paths["nifti_prediction"] = mask_paths

    if save_proba_map is True:
        proba_map_paths = []
        for i, proba_map in enumerate(proba_maps):
            output_nifti_path = (
                f"{stage_dirs[i]}/{output_names['probabilities']}.nii.gz"
            )
            if proba_map is None:
                logger.warning(f"proba_map for stage {i} is None, skipping")
                proba_map_paths.append(None)
                continue
            sitk.WriteImage(proba_map, output_nifti_path)
            proba_map_paths.append(output_nifti_path)
            logger.info(
                "Exported probability map %d to %s", i, output_nifti_path
            )
        output_paths["nifti_proba"] = proba_map_paths

    if save_nifti_inputs is True:
        niftis = []
        for i, volume_set in enumerate(volumes):
            for j, volume in enumerate(volume_set):
                output_nifti_path = os.path.join(
                    stage_dirs[i], f"input_volume_{j}.nii.gz"
                )
                sitk.WriteImage(volume, output_nifti_path)
                niftis.append(output_nifti_path)
                logger.info(
                    "Exported input volume %d to %s for stage %d",
                    j,
                    output_nifti_path,
                    i,
                )
        output_paths["nifti_inputs"] = niftis

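    if is_dicom is True:
        # A single SegWriter is documented as valid; reuse it for every stage.
        if not isinstance(seg_writers, (list, tuple)):
            seg_writers = [seg_writers] * len(masks)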
        dicom_seg_paths = []
        dicom_struct_paths = []
        for i, mask in enumerate(masks):
            dcm_seg_output_path = (
                f"{stage_dirs[i]}/{output_names['prediction']}.dcm"
            )
            status = seg_writers[i].write_dicom_seg(
                mask,
                source_files=good_file_paths[0],
                output_path=dcm_seg_output_path,
            )
            if "empty" in status:
                logger.info("Mask %d is empty, skipping DICOMseg/RTstruct", i)
                dicom_seg_paths.append(None)
                dicom_struct_paths.append(None)
            elif save_rt_struct_output:
                dcm_rts_output_path = (
                    f"{stage_dirs[i]}/{output_names['struct']}.dcm"
                )
                status = seg_writers[i].write_dicom_rtstruct(
                    mask,
                    source_files=good_file_paths[0],
                    output_path=dcm_rts_output_path,
                )
                dicom_seg_paths.append(dcm_seg_output_path)
                dicom_struct_paths.append(dcm_rts_output_path)
                logger.info(
                    "Exported DICOM struct %d to %s", i, dicom_struct_paths[-1]
                )
            else:
                dicom_seg_paths.append(dcm_seg_output_path)
                logger.info(
                    "Exported DICOM segmentation %d to %s",
                    i,
                    dicom_seg_paths[-1],
                )
        output_paths["dicom_segmentation"] = dicom_seg_paths
        if save_rt_struct_output:
            output_paths["dicom_struct"] = dicom_struct_paths

        if save_proba_map is True:
            dicom_proba_paths = []
            logger.info("Exporting probabilities")
            for i, proba_map in enumerate(proba_maps):
                if proba_map is None:
                    dicom_proba_paths.append(None)
                    continue
                output_path = (
                    f"{stage_dirs[i]}/{output_names['probabilities']}.dcm"
                )
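                # class_idx may be a single index or one index per stage.
                stage_class_idx = (
                    class_idx[i]
                    if isinstance(class_idx, (list, tuple))
                    else class_idx
                )
                status = seg_writers[i].write_dicom_seg(
                    proba_map,
                    source_files=good_file_paths[0],
                    output_path=output_path,
                    is_fractional_compliant=True,
                    class_idx=stage_class_idx,
                )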
                if status == "empty":
                    dicom_proba_paths.append(None)
                    logger.info(
                        f"Mask {i} is empty, skipping DICOM probabilities."
                    )
                    continue
                elif status == "skipped":
                    dicom_proba_paths.append(None)
                    logger.info(
                        "Skipped saving the probabilistic DICOM for stage %d", i
                    )
                    continue
                dicom_proba_paths.append(output_path)
                logger.info(
                    "Exported DICOM fractional segmentation %d to %s",
                    i,
                    dicom_proba_paths[-1],
                )
            output_paths["dicom_fractional_segmentation"] = dicom_proba_paths

    logger.info("Finished exporting predictions")
    return output_paths
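The file names written by `export_predictions` follow a fixed pattern under `stage_<i>`. The hypothetical helper `expected_paths` below is a sketch that reproduces those naming rules without writing anything, which can be handy for checking where a given artifact will land. It is organized per stage rather than per artifact type (the real return value maps each key to a list over stages) and omits `nifti_inputs`.

```python
import os
from typing import Optional


def expected_paths(output_dir: str, n_stages: int, suffix: Optional[str] = None) -> list[dict[str, str]]:
    """Hypothetical helper mirroring the naming rules in export_predictions."""
    # Same mapping as output_names: the probability prefix shrinks from
    # "probabilities" to "proba" once a suffix is given.
    prediction = "prediction" if suffix is None else f"prediction_{suffix}"
    probabilities = "probabilities" if suffix is None else f"proba_{suffix}"
    struct = "struct" if suffix is None else f"struct_{suffix}"
    layout = []
    for i in range(n_stages):
        stage_dir = os.path.join(output_dir, f"stage_{i}")
        layout.append(
            {
                "nifti_prediction": f"{stage_dir}/{prediction}.nii.gz",
                "nifti_proba": f"{stage_dir}/{probabilities}.nii.gz",
                "dicom_segmentation": f"{stage_dir}/{prediction}.dcm",
                "dicom_struct": f"{stage_dir}/{struct}.dcm",
                "dicom_fractional_segmentation": f"{stage_dir}/{probabilities}.dcm",
            }
        )
    return layout


for entry in expected_paths("out", n_stages=2, suffix="t2"):
    print(entry["nifti_prediction"], entry["dicom_fractional_segmentation"])
```

On POSIX systems this prints `out/stage_0/prediction_t2.nii.gz out/stage_0/proba_t2.dcm` and the same pair for `stage_1`. Note that the NIfTI mask and the DICOM SEG share a base name and differ only in extension.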

get_empty_segment_description(algorithm_type, algorithm_identification, tracking_id)

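Returns a placeholder segment description (number 99, label "Empty segment") whose segmented property category and type are both coded as a physical object.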

Parameters:

Name Type Description Default
algorithm_type str

algorithm type.

required
algorithm_identification hd.AlgorithmIdentificationSequence

algorithm identification.

required
tracking_id str

tracking ID.

required

Returns:

Type Description
hd.seg.SegmentDescription

Placeholder ("Empty segment") segment description.

Source code in src/nnunet_serve/seg_writers.py
def get_empty_segment_description(
    algorithm_type, algorithm_identification, tracking_id: str
):
    """
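    Returns a placeholder segment description (number 99, label
    "Empty segment") whose segmented property category and type are both
    coded as a physical object.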

    Args:
        algorithm_type (str): algorithm type.
        algorithm_identification (hd.AlgorithmIdentificationSequence):
            algorithm identification.
        tracking_id (str): tracking ID.

    Returns:
        hd.seg.SegmentDescription: empty segment description.
    """
    return hd.seg.SegmentDescription(
        segment_number=99,
        segment_label="Empty segment",
        segmented_property_category=codes.CID7150.PhysicalObject,
        segmented_property_type=codes.CID7150.PhysicalObject,
        algorithm_type=algorithm_type,
        algorithm_identification=algorithm_identification,
        tracking_uid=hd.UID(),
        tracking_id=tracking_id,
    )

get_segment_type_code(segment, i)

Resolve and build the DICOM coded concept for a segment.

Parameters:

Name Type Description Default
segment dict | str

Segment specification. If a dict, the following keys are supported:

  • name (str): Human-readable name of the structure (e.g., "Liver").
  • number (int, optional): 1-based segment number. Defaults to i + 1 at the call site.
  • label (str, optional): Display label for the segment. Defaults to the value of name.
  • tracking_id (str, optional): Stable identifier used for DICOM Tracking ID. Defaults to "Segment{number}_{label}".
  • code (str | int | None, optional): Coded value (Code Value) for the segmented property type. If None, it will be looked up automatically from the concept dictionary based on name.
  • scheme (str, optional): Coding Scheme Designator. Only "SCT" and "EUCAIM" are supported for automatic lookup. Defaults to "SCT".
  • category_number (str | int | None, optional): Code Value for the segmented property category. If None, it is looked up in CATEGORY_MAPPING using the resolved code.value as key. Even when this is provided, it should represent a SNOMED CT category number.

If a string is provided, it is interpreted as name and the remaining fields are inferred using the defaults above, except that the scheme defaults to DEFAULT_SEGMENT_SCHEME. For automatic lookup, the name is normalized with process_name (which removes laterality indicators such as "Right" and "Left", spaces, and underscores) and to_camel_case, then used as a key into NATURAL_LANGUAGE_TO_CODE for the selected scheme. The laterality is nonetheless included in segment_dict["label"].

required
i int

0-based position of the segment; the default segment number is i + 1.

required

Returns:

Type Description
tuple[Code, dict]

  • Code: A pydicom.sr.codedict.Code representing the segment's type (value, meaning, scheme_designator).
  • dict: The normalized segment dictionary with all defaults, the resolved code, and the detected laterality populated.

Raises:

Type Description
ValueError

If segment is neither dict nor str, if automatic lookup is requested with a scheme other than "SCT" or "EUCAIM", or if the concept cannot be found in the dictionary (closest matches are reported).

Source code in src/nnunet_serve/seg_writers.py
def get_segment_type_code(segment: dict | str, i: int) -> tuple[Code, dict]:
    """
    Resolve and build the DICOM coded concept for a segment.

    Args:
        segment (dict | str): Segment specification. If a dict, the following
            keys are supported:

            - name (str): Human-readable name of the structure (e.g., "Liver").
            - number (int, optional): 1-based segment number. Defaults to
              ``i + 1``.
            - label (str, optional): Display label for the segment. Defaults to
              the value of ``name``.
            - tracking_id (str, optional): Stable identifier used for DICOM
              Tracking ID. Defaults to ``"Segment{number}_{label}"``.
            - code (str | int | None, optional): Coded value (Code Value) for
              the segmented property type. If ``None``, it will be looked up
              automatically from the concept dictionary based on ``name``.
            - scheme (str, optional): Coding Scheme Designator. Only ``"SCT"``
              and ``"EUCAIM"`` are supported for automatic lookup. Defaults
              to ``"SCT"``.
            - category_number (str | int | None, optional): Code Value for the
              segmented property category. If ``None``, it is looked up in
              ``CATEGORY_MAPPING`` using the resolved ``code.value`` as key.
              Even when this is provided, it should represent a SNOMED CT
              category number.

            If a string is provided, it is interpreted as ``name`` and the
            remaining fields are inferred using the defaults above, except
            that the scheme defaults to ``DEFAULT_SEGMENT_SCHEME``. For
            automatic lookup, the name is normalized with ``process_name``
            (which removes laterality indicators such as "Right" and "Left",
            spaces, and underscores) and ``to_camel_case``, then used as a key
            into ``NATURAL_LANGUAGE_TO_CODE`` for the selected scheme. The
            laterality is nonetheless included in ``segment_dict["label"]``.
        i (int): 0-based position of the segment; the default segment number
            is ``i + 1``.

    Returns:
        tuple[Code, dict]:
        - Code: A ``pydicom.sr.codedict.Code`` representing the segment's
            type (value, meaning, scheme_designator).
        - dict: The normalized segment dictionary with all defaults, the
            resolved code, and the detected laterality populated.

    Raises:
        ValueError: If ``segment`` is neither ``dict`` nor ``str``, if
            automatic lookup is requested with a scheme other than ``"SCT"``
            or ``"EUCAIM"``, or if the concept cannot be found in the
            dictionary (closest matches are reported).
    """
    segment_dict = {}
    if isinstance(segment, dict):
        segment_dict["name"] = segment["name"]
        segment_dict["number"] = segment.get("number", i + 1)
        segment_dict["label"] = segment.get("label", segment["name"])
        segment_dict["tracking_id"] = segment.get(
            "tracking_id",
            f"Segment{segment_dict['number']}_{segment_dict['label']}",
        )
        segment_dict["code"] = segment.get("code", None)
        segment_dict["scheme"] = segment.get("scheme", "SCT")
        segment_dict["category_number"] = segment.get("category_number", None)
    elif isinstance(segment, str):
        segment_dict["name"] = segment
        segment_dict["number"] = i + 1
        segment_dict["label"] = segment
        segment_dict[
            "tracking_id"
        ] = f"Segment{segment_dict['number']}_{segment_dict['label']}"
        segment_dict["code"] = None
        segment_dict["scheme"] = DEFAULT_SEGMENT_SCHEME
        segment_dict["category_number"] = None
    else:
        raise ValueError(f"Invalid segment: {segment}")
    if segment_dict["code"] is None:
        name = to_camel_case(process_name(segment_dict["name"]))
        if segment_dict["scheme"] not in ["SCT", "EUCAIM"]:
            raise ValueError(
                "Only SCT and EUCAIM schemes are supported for automatic retrieval"
            )
        if name not in NATURAL_LANGUAGE_TO_CODE[segment_dict["scheme"]]:
            closest_matches = []
            for k in NATURAL_LANGUAGE_TO_CODE[segment_dict["scheme"]]:
                if close_match(k, name, 0.8):
                    closest_matches.append(k)
            raise ValueError(
                f"Segment {name} not found in {segment_dict['scheme']}. Closest matches: {closest_matches}"
            )
        code_info = NATURAL_LANGUAGE_TO_CODE[segment_dict["scheme"]][name]
        segment_dict["name"] = code_info[0]
        segment_dict["code"] = code_info[1]
    segment_code = Code(
        value=segment_dict["code"],
        meaning=strip_laterality(to_camel_case(segment_dict["name"], " ")),
        scheme_designator=segment_dict["scheme"],
    )

    if segment_dict["category_number"] is None:
        segment_dict["category_number"] = CATEGORY_MAPPING[
            segment_dict["scheme"]
        ]["type"][str(segment_code.value)]

    laterality = get_laterality(segment_dict["label"])
    if not laterality:
        laterality = get_laterality(segment_dict["name"])
    segment_dict["laterality"] = laterality
    if laterality:
        if str(segment_dict["category_number"]) == "49755003":
            pass
        elif laterality.lower() not in segment_dict["label"].lower():
            segment_dict["label"] = f'{laterality} {segment_dict["label"]}'
    return segment_code, segment_dict
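The first half of `get_segment_type_code` only fills in defaults, so a segment can be given either as a bare name or as a partial dict. The hypothetical `normalize_segment` below sketches just that step under stated assumptions: `default_scheme="SCT"` stands in for `DEFAULT_SEGMENT_SCHEME`, whose real value is not shown here, and the code lookup, category mapping, and laterality handling are omitted.

```python
from typing import Union


def normalize_segment(segment: Union[dict, str], i: int, default_scheme: str = "SCT") -> dict:
    """Hypothetical sketch of the default-filling step of get_segment_type_code.

    ``i`` is the 0-based position of the segment, so the default number is i + 1.
    """
    if isinstance(segment, str):
        # A bare string is treated as the name; its scheme comes from the default.
        segment = {"name": segment, "scheme": default_scheme}
    elif not isinstance(segment, dict):
        raise ValueError(f"Invalid segment: {segment}")
    number = segment.get("number", i + 1)
    label = segment.get("label", segment["name"])
    return {
        "name": segment["name"],
        "number": number,
        "label": label,
        "tracking_id": segment.get("tracking_id", f"Segment{number}_{label}"),
        "code": segment.get("code"),  # None triggers automatic lookup upstream
        "scheme": segment.get("scheme", "SCT"),
        "category_number": segment.get("category_number"),
    }


print(normalize_segment("Liver", 0)["tracking_id"])  # Segment1_Liver
print(normalize_segment({"name": "Kidney", "label": "Left kidney"}, 2)["tracking_id"])  # Segment3_Left kidney
```

The tracking ID is built from the label, not the name, so a custom label (for example one carrying laterality) also changes the default tracking ID.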

one_hot_encode(arr, n_labels)

Converts an integer label array to a one-hot encoded float array: channel k - 1 is 1 where arr == k, for k = 1..n_labels; background (0) has no channel.

Parameters:

Name Type Description Default
arr ndarray

numpy array to be converted.

required
n_labels int

Number of labels for encoding.

required

Returns:

Type Description
ndarray

np.ndarray: one-hot encoded numpy array of shape (*arr.shape, n_labels).

Source code in src/nnunet_serve/seg_writers.py
def one_hot_encode(arr: np.ndarray, n_labels: int) -> np.ndarray:
    """
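    Converts an integer label array to a one-hot encoded float array:
    channel ``k - 1`` is 1 where ``arr == k``, for ``k = 1..n_labels``;
    background (0) has no channel.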

    Args:
        arr (np.ndarray): numpy array to be converted.
        n_labels (int): Number of labels for encoding.

    Returns:
        np.ndarray: one-hot encoded numpy array of shape
            (*arr.shape, n_labels).
    """
    output_arr = np.zeros([*arr.shape, n_labels])
    for i in range(1, n_labels + 1):
        output_arr[..., i - 1] = arr == i
    return output_arr
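The per-label loop above can also be written as a single broadcast comparison. The sketch below shows both forms side by side so their equivalence is easy to check; `one_hot_broadcast` is an illustrative alternative, not the module's implementation. Both return float64 arrays of shape `(*arr.shape, n_labels)` and give background voxels an all-zero vector.

```python
import numpy as np


def one_hot_loop(arr: np.ndarray, n_labels: int) -> np.ndarray:
    # Same logic as one_hot_encode above: channel k - 1 marks voxels equal to k.
    out = np.zeros([*arr.shape, n_labels])
    for k in range(1, n_labels + 1):
        out[..., k - 1] = arr == k
    return out


def one_hot_broadcast(arr: np.ndarray, n_labels: int) -> np.ndarray:
    # Compare every voxel against labels 1..n_labels in one broadcast step;
    # background (0) matches no channel, exactly as in the loop version.
    labels = np.arange(1, n_labels + 1)
    return (arr[..., None] == labels).astype(np.float64)


mask = np.array([[0, 1], [2, 2]])
print(one_hot_broadcast(mask, 2)[..., 1])
# [[0. 0.]
#  [1. 1.]]
```

The broadcast version allocates a boolean intermediate of the same shape as the output, so its peak memory is comparable to the loop; the gain is mostly readability and fewer Python-level iterations.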

process_name(name)

Normalizes a name by stripping laterality indicators and removing all spaces and underscores.

Parameters:

Name Type Description Default
name str

The name to process.

required

Returns:

Type Description
str

The processed name.

Source code in src/nnunet_serve/seg_writers.py
def process_name(name: str) -> str:
    """
    Normalizes a name by stripping laterality indicators and removing all
    spaces and underscores.

    Args:
        name (str): The name to process.

    Returns:
        str: The processed name.
    """
    name = strip_laterality(name)
    name = re.sub("[ _]+", "", name)
    name = name.strip()
    return name
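`process_name` produces the compact form used as a lookup key (before `to_camel_case`, which is not reproduced here). The sketch below copies `strip_laterality` and `process_name` verbatim so it runs on its own, and shows that every space and underscore is removed, not just repeated ones. The input names are arbitrary examples.

```python
import re


def strip_laterality(name: str) -> str:
    # Copied from the module: drop "left"/"right" (either initial case) plus
    # any surrounding spaces or underscores.
    name = re.sub("[ _]*[lL]eft[ _]*", "", name)
    name = re.sub("[ _]*[rR]ight[ _]*", "", name)
    return name.strip()


def process_name(name: str) -> str:
    # Copied from the module: strip laterality, then delete all spaces/underscores.
    name = strip_laterality(name)
    name = re.sub("[ _]+", "", name)
    return name.strip()


for raw in ["Left Kidney", "adrenal_gland_left", "Urinary bladder"]:
    print(repr(process_name(raw)))
# 'Kidney'
# 'adrenalgland'
# 'Urinarybladder'
```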

random_color_generator()

Returns a random color as a tuple of RGB values.

Returns:

Type Description
tuple

tuple of RGB values.

Source code in src/nnunet_serve/seg_writers.py
def random_color_generator():
    """
    Returns a random color as a tuple of RGB values.

    Returns:
        tuple: tuple of RGB values.
    """
    r = random.randint(0, 255)
    g = random.randint(0, 255)
    b = random.randint(0, 255)
    return (r, g, b)
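`random_color_generator` draws from the global `random` state, so segment colours change between runs unless `random.seed` is called somewhere. If reproducible colours are wanted, one option is a private `random.Random` instance. The sketch below uses a hypothetical `seeded_color_generator` that is not part of the module.

```python
import random


def seeded_color_generator(seed: int):
    """Hypothetical variant: yields reproducible (r, g, b) tuples from a private RNG."""
    rng = random.Random(seed)  # independent of the global random state
    while True:
        yield (rng.randint(0, 255), rng.randint(0, 255), rng.randint(0, 255))


gen_a = seeded_color_generator(42)
gen_b = seeded_color_generator(42)
colors_a = [next(gen_a) for _ in range(3)]
colors_b = [next(gen_b) for _ in range(3)]
print(colors_a == colors_b)  # True
```

Using a dedicated generator keeps colour selection deterministic without reseeding the global RNG that other code may rely on.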

save_mask_as_rtstruct(img_data, dcm_reference_file, output_path, segment_info)

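Converts a numpy array to an RT (radiotherapy) struct object and saves it to output_path. The array can be multi-class (each value n > 0 corresponds to a class); with background (0) present, the number of classes is np.unique(img_data).shape[0] - 1. Nothing is written if the array has no non-zero values.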

Parameters:

Name Type Description Default
img_data ndarray

numpy array with n non-zero unique values, each of which corresponds to a class.

required
dcm_reference_file str

path to the reference DICOM series, passed to RTStructBuilder.create_new as dicom_series_path.

required
output_path str

path of the output RT struct file.

required
segment_info list[tuple[str, list[int]]]

segment information, with one entry per class label: entry k - 1 describes label k. Each entry is a tuple whose first element is the ROI name and whose second element is a list of RGB values.

required
Source code in src/nnunet_serve/seg_writers.py
def save_mask_as_rtstruct(
    img_data: np.ndarray,
    dcm_reference_file: str,
    output_path: str,
    segment_info: list[tuple[str, list[int]]],
) -> None:
    """
    Converts a numpy array to an RT (radiotherapy) struct object and saves it
    to ``output_path``. The array can be multi-class (each value n > 0
    corresponds to a class); with background (0) present, the number of
    classes is ``np.unique(img_data).shape[0] - 1``. Nothing is written if
    the array has no non-zero values.

    Args:
        img_data (np.ndarray): numpy array with n non-zero unique values, each
            of which corresponds to a class.
        dcm_reference_file (str): path to the reference DICOM series, passed
            to ``RTStructBuilder.create_new`` as ``dicom_series_path``.
        output_path (str): path of the output RT struct file.
        segment_info (list[tuple[str, list[int]]]): segment information, with
            one entry per class label: entry ``k - 1`` describes label ``k``.
            Each entry is a tuple whose first element is the ROI name and
            whose second element is a list of RGB values.
    """
    try:
        from rt_utils import RTStructBuilder
    except ImportError:
        raise ImportError(
            "rt_utils is required to save masks in RT struct format"
        )
    # based on the TotalSegmentator implementation

    logging.basicConfig(level=logging.WARNING)  # avoid messages from rt_utils

    # create new RT Struct - requires original DICOM
    rtstruct = RTStructBuilder.create_new(dicom_series_path=dcm_reference_file)

    # retrieve selected classes
    img_data = img_data.swapaxes(0, 2)
    selected_classes = np.unique(img_data)
    selected_classes = selected_classes[selected_classes > 0].tolist()
    if len(selected_classes) == 0:
        return None

    # add mask to RT Struct
    for class_idx in tqdm(selected_classes):
        class_name, class_colour = segment_info[class_idx - 1]
        binary_img = img_data == class_idx
        if binary_img.sum() > 0:  # only save non-empty images
            # add segmentation to RT Struct
            rtstruct.add_roi(
                mask=binary_img,  # has to be a binary numpy array
                name=class_name,
                color=class_colour,
            )

    rtstruct.save(str(output_path))
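The core of `save_mask_as_rtstruct` is the mapping from label values to `(name, colour, binary mask)` triples; `rt_utils` is only called at the end. The hypothetical `rois_from_label_map` below sketches that mapping so it can run without a reference DICOM series. It returns the triples instead of calling `add_roi` and skips the `swapaxes(0, 2)` reorientation.

```python
import numpy as np


def rois_from_label_map(img_data: np.ndarray, segment_info: list[tuple[str, list[int]]]):
    """Hypothetical sketch of the per-class loop in save_mask_as_rtstruct."""
    classes = np.unique(img_data)
    rois = []
    for class_idx in classes[classes > 0].tolist():
        # segment_info is indexed by class_idx - 1, so label k uses entry k - 1.
        name, colour = segment_info[class_idx - 1]
        binary = img_data == class_idx
        if binary.any():  # mirrors the "only save non-empty images" check
            rois.append((name, colour, binary))
    return rois


labels = np.zeros((2, 3, 3), dtype=np.uint8)
labels[0, 0, 0] = 1
labels[1, 2, 2] = 3
info = [("Liver", [255, 0, 0]), ("Spleen", [0, 255, 0]), ("Kidney", [0, 0, 255])]
for name, colour, mask in rois_from_label_map(labels, info):
    print(name, colour, int(mask.sum()))
# Liver [255, 0, 0] 1
# Kidney [0, 0, 255] 1
```

Because the lookup is positional, `segment_info` needs an entry for every label up to the largest one present, even for labels (here 2, "Spleen") that do not occur in the mask.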

strip_laterality(name)

Strips laterality indicators (left/right) from a string.

Parameters:

Name Type Description Default
name str

The string to process.

required

Returns:

Type Description
str

The string without laterality indicators.

Source code in src/nnunet_serve/seg_writers.py
def strip_laterality(name: str) -> str:
    """
    Strips laterality indicators (left/right) from a string.

    Args:
        name (str): The string to process.

    Returns:
        str: The string without laterality indicators.
    """

    name = re.sub("[ _]*[lL]eft[ _]*", "", name)
    name = re.sub("[ _]*[rR]ight[ _]*", "", name)
    name = name.strip()
    return name
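The regular expressions in `strip_laterality` match "left" and "right" anywhere in the string, including inside other words, so a name that merely contains those letters is also altered. The sketch below contrasts the module's version with a hypothetical `strip_laterality_whole_word` variant. The variant is an illustration, not the module's behaviour: it uses lookarounds so the indicator is only removed when no other letter is attached to it, while still treating spaces and underscores as separators.

```python
import re


def strip_laterality(name: str) -> str:
    # Module version: matches "left"/"right" anywhere, even inside a word.
    name = re.sub("[ _]*[lL]eft[ _]*", "", name)
    name = re.sub("[ _]*[rR]ight[ _]*", "", name)
    return name.strip()


def strip_laterality_whole_word(name: str) -> str:
    # Hypothetical variant: only strip "left"/"right" when they are not glued
    # to other letters (spaces, underscores or string ends around them).
    for side in ("[lL]eft", "[rR]ight"):
        name = re.sub(rf"[ _]*(?<![A-Za-z]){side}(?![A-Za-z])[ _]*", "", name)
    return name.strip()


for raw in ["Right lung", "kidney_left", "Bright spot"]:
    print(raw, "->", repr(strip_laterality(raw)), repr(strip_laterality_whole_word(raw)))
# Right lung -> 'lung' 'lung'
# kidney_left -> 'kidney' 'kidney'
# Bright spot -> 'Bspot' 'Bright spot'
```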