Source code for sct_dbtool.sanity

import os
import glob
import collections

import sct_dbtool
from sct_dbtool import api

from tqdm import tqdm
from jinja2 import Environment, PackageLoader
import nibabel

from typing import Type, Generic, List, Dict, Set


def class_inheritors(klass: Type) -> Set[Type]:
    """Collect every direct and indirect subclass of a class.

    The class hierarchy is walked iteratively through ``__subclasses__()``;
    the starting class itself is not part of the result.

    :param klass: the class whose descendants are searched.
    :return: a set with all the classes inheriting from ``klass``.
    """
    found: Set[Type] = set()
    pending = [klass]
    while pending:
        current = pending.pop()
        # Only descendants not seen yet need to be recorded and explored;
        # this also copes with classes reachable through several parents.
        fresh = [sub for sub in current.__subclasses__() if sub not in found]
        found.update(fresh)
        pending.extend(fresh)
    return found
class Diagnostic:
    """Plain record describing one finding produced by a sanity check.

    :param category: the diagnostic category (warning, error, info).
    :param check_name: the name of the check that generated the diagnostic.
    :param message: message generated by the check.
    :param item: the item related to the diagnostic.
    """

    def __init__(self, category: str, check_name: str,
                 message: str, item: Dict):
        # What was found and on which item ...
        self.item = item
        self.message = message
        # ... and which check reported it, with which severity.
        self.check_name = check_name
        self.category = category
class SanityCheck:
    """Base class for sanity checks; subclass it to implement a new check.

    :param category: category of the check (warning, error or info).
    :param db_path: the path to the database.
    """

    CATEGORIES = ["warning", "error", "info"]

    def __init__(self, category: str, db_path: str):
        known_category = category in SanityCheck.CATEGORIES
        if not known_category:
            raise RuntimeError("Invalid category name.")

        self.db_path = db_path
        self.category = category
        # The concrete class name identifies the check in its diagnostics.
        self.name = type(self).__name__
        self.diagnostics = []

    def add_diagnostic(self, msg: str, item: Dict):
        """Record a new diagnostic message for the item.

        The diagnostic carries this check's category and name.

        :param msg: message with details about the diagnostic.
        :param item: item that generated the diagnostic.
        """
        self.diagnostics.append(
            Diagnostic(self.category, self.name, msg, item)
        )

    def check(self, item: Dict):
        """Inspect one database item; sub-classes must override this method.

        :param item: each item of the database.
        :raises NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def teardown(self):
        """Hook called once all database items were checked, before the
        diagnostic messages are collected. Does nothing by default."""

    def collect_diagnostics(self) -> List["Diagnostic"]:
        """Hand over the diagnostics accumulated so far and reset the check.

        :return: a list (a copy) of the diagnostics; the internal list is
                 emptied in place.
        """
        collected = list(self.diagnostics)
        del self.diagnostics[:]
        return collected
############################################################################## # Implement more checks here ##############################################################################
class CheckEmptyDemographics(SanityCheck):
    """Warn when the ``demographic`` field of an item is absent or empty."""

    def __init__(self, db_path: str):
        super().__init__("warning", db_path)

    def check(self, item: Dict):
        """Report an item whose ``demographic`` key is missing or ``None``.

        :param item: each item of the database.
        """
        if "demographic" not in item:
            self.add_diagnostic("Demographic field not present.", item)
        elif item["demographic"] is None:
            self.add_diagnostic("Demographic field is empty.", item)
class CheckMainVolumePath(SanityCheck):
    """Report main volumes whose file is not found in the file system."""

    def __init__(self, db_path: str):
        super().__init__("error", db_path)

    def check(self, item: Dict):
        """Emit one diagnostic per image of the item whose file is missing.

        :param item: each item of the database.
        """
        for image in item["images"]:
            # Filenames are resolved relative to the database path.
            relative_name = image["filename"]
            if os.path.exists(os.path.join(self.db_path, relative_name)):
                continue
            self.add_diagnostic(
                "File {} doesn't exists.".format(relative_name), item)
class CheckLabelVolumePath(SanityCheck):
    """Report label volumes whose file is not found in the file system."""

    def __init__(self, db_path: str):
        super().__init__("error", db_path)

    def check(self, item: Dict):
        """Emit one diagnostic per labeled image whose file is missing.

        :param item: each item of the database.
        """
        # Lazily flatten the labeled images of every image of the item.
        labels = (
            label
            for image in item["images"]
            for label in image["labeled_images"]
        )
        for label in labels:
            relative_name = label["filename"]
            if os.path.exists(os.path.join(self.db_path, relative_name)):
                continue
            self.add_diagnostic(
                "File {} doesn't exists.".format(relative_name), item)
class CheckLabelVolumeConsistency(SanityCheck):
    """This check will verify if the label is consistent with the volume
    (shapes)."""

    def __init__(self, db_path: str):
        super().__init__("error", db_path)

    def check(self, item: Dict):
        """Compare the first three dimensions of each label with its volume.

        Images or labels whose file is missing are skipped. A main volume
        that nibabel cannot open is skipped as well (that situation is
        reported by ``CheckMainVolumeNibabel``); a label that cannot be
        opened produces a diagnostic instead of being silently ignored.

        :param item: each item of the database.
        """
        for image in item["images"]:
            image_path = os.path.join(self.db_path, image["filename"])
            if not os.path.exists(image_path):
                continue

            # Loaded lazily, and only once per image, when the first label
            # that exists on disk is found.
            volume_shape = None

            for labeled_image in image["labeled_images"]:
                label_path = os.path.join(self.db_path,
                                          labeled_image["filename"])
                if not os.path.exists(label_path):
                    continue

                if volume_shape is None:
                    try:
                        # Just check the first three dimensions. The
                        # ``shape`` attribute replaces the deprecated
                        # ``get_shape()`` accessor.
                        volume_shape = tuple(
                            nibabel.load(image_path).shape[0:3])
                    except Exception:
                        # Unreadable main volume: nothing to compare with.
                        break

                try:
                    label_shape = tuple(nibabel.load(label_path).shape[0:3])
                except Exception as exception:
                    self.add_diagnostic(
                        "Nibabel couldn't open the label {} (exception {})."
                        .format(labeled_image["filename"],
                                type(exception).__name__), item)
                    continue

                if volume_shape != label_shape:
                    # Each filename is paired with its own shape.
                    msg = "Label {} has shape {}, while {} has shape {}.".format(
                        labeled_image["filename"], label_shape,
                        image["filename"], volume_shape
                    )
                    self.add_diagnostic(msg, item)
class CheckMainVolumeNibabel(SanityCheck):
    """Warn about main volumes that nibabel is not able to open."""

    def __init__(self, db_path: str):
        super().__init__("warning", db_path)

    def check(self, item: Dict):
        """Try to load every existing main volume of the item with nibabel.

        A diagnostic is added when loading raises an exception or returns
        ``None``. Images whose file does not exist are skipped.

        :param item: each item of the database.
        """
        for image in item["images"]:
            relative_name = image["filename"]
            full_path = os.path.join(self.db_path, relative_name)
            if not os.path.exists(full_path):
                continue

            try:
                loaded = nibabel.load(full_path)
            except Exception as error:
                self.add_diagnostic(
                    "Nibabel couldn't open the file {} (exception {})."
                    .format(relative_name, type(error).__name__), item)
            else:
                if loaded is None:
                    self.add_diagnostic(
                        "Nibabel couldn't open the file {} (None return)."
                        .format(relative_name), item)
class CheckFilesystemDB(SanityCheck):
    """This check will verify if all files in file system are present in the
    database.

    On construction the database path is scanned recursively for ``*.nii``
    and ``*.nii.gz`` files. ``check`` records the files referenced by the
    database items, and ``teardown`` reports the scanned files that were
    never referenced.
    """

    def __init__(self, db_path: str):
        super().__init__("info", db_path)

        found = []
        for pattern in ("**/*.nii", "**/*.nii.gz"):
            found += glob.glob(os.path.join(db_path, pattern), recursive=True)

        # Paths relative to the database root (normalized by ``relpath``).
        self.all_files = [os.path.relpath(path, db_path) for path in found]
        # Normalized relative paths of the files referenced by the database.
        self.db_files = []

    def _register(self, filename: str) -> bool:
        """Record ``filename`` if it exists on disk; return whether it did."""
        if not os.path.exists(os.path.join(self.db_path, filename)):
            return False
        # Normalize so that names such as "./a.nii" or "a//b.nii" compare
        # equal to the ``relpath`` results stored in ``all_files``.
        self.db_files.append(os.path.normpath(filename))
        return True

    def check(self, item: Dict):
        """Record the existing image and label files referenced by the item.

        Labels of an image whose own file is missing are not inspected.

        :param item: each item of the database.
        """
        for image in item["images"]:
            if not self._register(image["filename"]):
                continue
            for labeled_image in image["labeled_images"]:
                self._register(labeled_image["filename"])

    def teardown(self):
        """Add one diagnostic per file found on disk that no item mentions."""
        missing_files = set(self.all_files) - set(self.db_files)

        # Sorted so that the report is identical from one run to the next.
        for f in sorted(missing_files):
            msg = "File {} isn't mentioned in the database.".format(f)
            # Placeholder item: the file belongs to no database entry.
            item = {
                'id': -1,
                'study': 'unknown',
                'center': 'unknown',
            }
            self.add_diagnostic(msg, item)
def sanity_check(api_client: api.APIClient, arguments: Dict):
    """Main entry-point for the sanity checking.

    Retrieves the dataset through the API client, runs every ``SanityCheck``
    subclass on each item, and renders the collected diagnostics with the
    ``report.html`` template into a ``report.html`` file in the current
    working directory.

    :param api_client: a configured instance of APIClient.
    :param arguments: the CLI arguments; ``"<db_path>"`` is the database path.
    :raises RuntimeError: if the database path is not an existing directory.
    """
    db_path = arguments["<db_path>"]

    if not os.path.isdir(db_path):
        raise RuntimeError("Database path doesn't exists !")

    print("Retrieving dataset from [{api.hostname}:{api.port}]..."
          .format(api=api_client))
    dataset = api_client.get_dataset()
    print("{} items retrieved.".format(len(dataset)))

    # 1. Initialization
    print("Initializing checks...")
    # class_inheritors() returns a set, whose iteration order can change
    # between runs; sort so that checks and report are deterministic.
    check_classes = sorted(
        class_inheritors(SanityCheck),
        key=lambda klass: (klass.__name__, klass.__module__),
    )
    sanity_checks_name = [klass.__name__ for klass in check_classes]
    sanity_checks = [klass(db_path) for klass in check_classes]
    print("{} sanity checks loaded.".format(len(sanity_checks)))

    # 2. Checking
    for item in tqdm(dataset, desc="Processing items"):
        for check in sanity_checks:
            check.check(item)

    # 3. Teardown
    for check in sanity_checks:
        check.teardown()

    # 4. Collect messages
    diagnostics = []
    for check in sanity_checks:
        diagnostics.extend(check.collect_diagnostics())

    print("{} diagnostics found.".format(len(diagnostics)))

    category_counter = collections.Counter(
        diagnostic.category for diagnostic in diagnostics
    )

    # 5. Render the report
    env = Environment(loader=PackageLoader('sct_dbtool', 'templates'))
    template = env.get_template('report.html')
    report_stream = template.stream(category_counter=category_counter,
                                    diagnostics=diagnostics,
                                    api=api_client,
                                    version=sct_dbtool.__version__,
                                    sanity_checks_name=sanity_checks_name)
    report_stream.dump("report.html")
    print("Results were written into 'report.html' file.")