import os
import glob
import collections
import sct_dbtool
from sct_dbtool import api
from tqdm import tqdm
from jinja2 import Environment, PackageLoader
import nibabel
from typing import Type, Generic, List, Dict, Set
def class_inheritors(klass: Type) -> Set[Type]:
    """Collect every direct and indirect subclass of a class.

    The hierarchy is walked iteratively through ``__subclasses__()``;
    a class that can be reached through several parents is reported
    only once.

    :param klass: the root class whose descendants are wanted.
    :return: a set with all the descendants found (starts out empty and
             only ever receives classes returned by ``__subclasses__()``).
    """
    found = set()
    pending = [klass]
    while pending:
        current = pending.pop()
        for descendant in current.__subclasses__():
            if descendant in found:
                # Already seen through another parent; do not walk it twice.
                continue
            found.add(descendant)
            pending.append(descendant)
    return found
class Diagnostic(object):
    """Plain record describing one finding produced by a check.

    :param category: the category of the finding (warning, error, info).
    :param check_name: the name of the check that produced the finding.
    :param message: the text generated by the check.
    :param item: the item the finding refers to.
    """

    def __init__(self, category: str, check_name: str,
                 message: str, item: Dict):
        # Values are stored untouched; no validation happens here.
        self.category, self.check_name = category, check_name
        self.message, self.item = message, item
class SanityCheck(object):
    """Base class for sanity checks; subclass it to add a new check.

    :param category: category given to every diagnostic of this check;
                     must be one of ``CATEGORIES``.
    :param db_path: the path to the database.
    :raises RuntimeError: if ``category`` is not listed in ``CATEGORIES``.
    """
    CATEGORIES = ["warning", "error", "info"]

    def __init__(self, category: str, db_path: str):
        is_known = category in SanityCheck.CATEGORIES
        if not is_known:
            raise RuntimeError("Invalid category name.")
        self.category = category
        # The concrete class name is what labels the diagnostics.
        self.name = type(self).__name__
        self.diagnostics = []
        self.db_path = db_path

    def add_diagnostic(self, msg: str, item: Dict):
        """Record a new diagnostic for an item.

        The diagnostic carries this check's category and name.

        :param msg: message with details about the diagnostic.
        :param item: the item that triggered the diagnostic.
        """
        self.diagnostics.append(
            Diagnostic(self.category, self.name, msg, item)
        )

    def check(self, item: Dict):
        """Inspect one item; sub-classes must override this method.

        :param item: one item of the database.
        :raises NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def teardown(self):
        """Hook invoked once all items were checked, before the
        diagnostics are collected. Does nothing by default."""
        pass

    def collect_diagnostics(self) -> List[Diagnostic]:
        """Hand over the diagnostics accumulated so far.

        The internal list is emptied in place, so a second call returns
        only diagnostics added in the meantime.

        :return: a new list holding the accumulated diagnostics.
        """
        collected = list(self.diagnostics)
        self.diagnostics.clear()
        return collected
##############################################################################
# Implement more checks here
##############################################################################
class CheckEmptyDemographics(SanityCheck):
    """Warn about items whose demographics field is missing or ``None``."""

    def __init__(self, db_path: str):
        super().__init__("warning", db_path)

    def check(self, item: Dict):
        """Report a missing ``demographic`` key or a ``None`` value.

        :param item: one item of the database.
        """
        if "demographic" not in item:
            self.add_diagnostic("Demographic field not present.", item)
        elif item["demographic"] is None:
            self.add_diagnostic("Demographic field is empty.", item)
class CheckMainVolumePath(SanityCheck):
    """Report main volumes whose file is absent from the file system."""

    def __init__(self, db_path: str):
        super().__init__("error", db_path)

    def check(self, item: Dict):
        """Verify that the file of each image of the item exists.

        File names are resolved against ``db_path``.

        :param item: one item of the database.
        """
        for image in item["images"]:
            relative_name = image["filename"]
            if os.path.exists(os.path.join(self.db_path, relative_name)):
                continue
            message = "File {} doesn't exists.".format(relative_name)
            self.add_diagnostic(message, item)
class CheckLabelVolumePath(SanityCheck):
    """Report label volumes whose file is absent from the file system."""

    def __init__(self, db_path: str):
        super().__init__("error", db_path)

    def check(self, item: Dict):
        """Verify that the file of each labeled image of the item exists.

        File names are resolved against ``db_path``.

        :param item: one item of the database.
        """
        for image in item["images"]:
            for labeled_image in image["labeled_images"]:
                relative_name = labeled_image["filename"]
                full_path = os.path.join(self.db_path, relative_name)
                if os.path.exists(full_path):
                    continue
                message = "File {} doesn't exists.".format(relative_name)
                self.add_diagnostic(message, item)
class CheckLabelVolumeConsistency(SanityCheck):
    """This check will verify if the label is consistent with the volume
    (shapes)."""

    def __init__(self, db_path: str):
        super().__init__("error", db_path)

    @staticmethod
    def _spatial_shape(path: str):
        """Return the first three dimensions of the volume stored at *path*.

        :param path: full path of the file to open with nibabel.
        :return: a tuple with at most three entries, or ``None`` when
                 nibabel fails to open the file.
        """
        try:
            # ``shape`` is the supported accessor; the old ``get_shape()``
            # method is deprecated in nibabel.
            return tuple(nibabel.load(path).shape[0:3])
        except Exception:
            # Best effort on purpose: a file that cannot be read cannot be
            # compared, so it is skipped here instead of aborting the run.
            # Unlike a bare ``except``, KeyboardInterrupt and SystemExit
            # still propagate.
            return None

    def check(self, item: Dict):
        """Compare the shape of every label with the shape of its volume.

        Only files that exist on disk and that nibabel can open are
        compared; only the first three dimensions are taken into account.

        :param item: one item of the database.
        """
        for image in item["images"]:
            image_path = os.path.join(self.db_path, image["filename"])
            if not os.path.exists(image_path):
                continue

            existing_labels = [
                labeled_image for labeled_image in image["labeled_images"]
                if os.path.exists(os.path.join(self.db_path,
                                               labeled_image["filename"]))
            ]
            if not existing_labels:
                continue

            # The main volume is opened once, not once per label.
            volume_shape = self._spatial_shape(image_path)
            if volume_shape is None:
                continue

            for labeled_image in existing_labels:
                label_path = os.path.join(self.db_path,
                                          labeled_image["filename"])
                label_shape = self._spatial_shape(label_path)
                if label_shape is None or label_shape == volume_shape:
                    continue
                # Each file name is paired with its own shape.
                msg = "Label {} has shape {}, while {} has shape {}.".format(
                    labeled_image["filename"], label_shape,
                    image["filename"], volume_shape
                )
                self.add_diagnostic(msg, item)
class CheckMainVolumeNibabel(SanityCheck):
    """Warn about main volumes that nibabel is unable to open."""

    def __init__(self, db_path: str):
        super().__init__("warning", db_path)

    def check(self, item: Dict):
        """Try to open every existing main volume of the item.

        A diagnostic is added when ``nibabel.load`` returns ``None`` or
        raises; files missing from disk are skipped.

        :param item: one item of the database.
        """
        none_template = "Nibabel couldn't open the file {} (None return)."
        error_template = "Nibabel couldn't open the file {} (exception {})."
        for image in item["images"]:
            full_path = os.path.join(self.db_path, image["filename"])
            if not os.path.exists(full_path):
                continue
            try:
                loaded = nibabel.load(full_path)
                if loaded is None:
                    self.add_diagnostic(
                        none_template.format(image["filename"]), item)
            except Exception as error:
                # Any failure is reported with the exception class name.
                failure = type(error).__name__
                self.add_diagnostic(
                    error_template.format(image["filename"], failure), item)
class CheckFilesystemDB(SanityCheck):
    """This check will verify if all files in file system are present in the
    database.

    The ``.nii`` and ``.nii.gz`` files found under ``db_path`` are listed at
    construction time; ``check`` records the files the database mentions and
    ``teardown`` reports the files found on disk but never mentioned.
    """

    def __init__(self, db_path: str):
        super().__init__("info", db_path)
        # Escape the root so that special glob characters in the directory
        # name ('[', '*', '?') are matched literally.
        root = glob.escape(db_path)
        found = []
        for pattern in ("**/*.nii", "**/*.nii.gz"):
            found += glob.glob(os.path.join(root, pattern), recursive=True)
        # Paths relative to the database root, as produced by relpath.
        self.all_files = [os.path.relpath(filename, db_path)
                          for filename in found]
        self.db_files = []

    def _register(self, filename: str) -> bool:
        """Record *filename* as mentioned by the database if it exists.

        :param filename: file name as stored in the database.
        :return: ``True`` if the file exists under ``db_path``.
        """
        if not os.path.exists(os.path.join(self.db_path, filename)):
            return False
        # Normalise the spelling ('./a.nii', 'a//b.nii', ...) so that it can
        # be compared with the relpath-normalised names found on disk.
        self.db_files.append(os.path.normpath(filename))
        return True

    def check(self, item: Dict):
        """Record the existing image and label files of the item.

        The labels of an image whose own file is missing are not recorded.

        :param item: one item of the database.
        """
        for image in item["images"]:
            if not self._register(image["filename"]):
                continue
            for labeled_image in image["labeled_images"]:
                self._register(labeled_image["filename"])

    def teardown(self):
        """Add one diagnostic per file found on disk but never recorded."""
        missing_files = set(self.all_files) - set(self.db_files)
        # Sorted so that the diagnostics come out in the same order on
        # every run (set iteration order is not stable).
        for filename in sorted(missing_files):
            msg = "File {} isn't mentioned in the database.".format(filename)
            # Placeholder item: such a file belongs to no database entry.
            item = {
                'id': -1,
                'study': 'unknown',
                'center': 'unknown',
            }
            self.add_diagnostic(msg, item)
def sanity_check(api_client: api.APIClient, arguments: Dict):
    """Main entry-point for the sanity checking.

    Retrieves the dataset through the API client, runs every subclass of
    ``SanityCheck`` on each item, and renders the collected diagnostics
    into a ``report.html`` file in the current working directory.

    :param api_client: a configured instance of APIClient.
    :param arguments: the CLI arguments; the ``"<db_path>"`` entry is read.
    :raises RuntimeError: if the database path is not an existing directory.
    """
    db_path = arguments["<db_path>"]
    if not os.path.isdir(db_path):
        raise RuntimeError("Database path doesn't exists !")

    print("Retrieving dataset from [{api.hostname}:{api.port}]..."
          .format(api=api_client))
    dataset = api_client.get_dataset()
    print("{} items retrieved.".format(len(dataset)))

    # 1. Initialization
    print("Initializing checks...")
    # class_inheritors returns a set, whose iteration order can change
    # between runs; sorting keeps the report reproducible.
    check_classes = sorted(
        class_inheritors(SanityCheck),
        key=lambda klass: (klass.__name__, klass.__module__),
    )
    sanity_checks_name = [klass.__name__ for klass in check_classes]
    sanity_checks = [klass(db_path) for klass in check_classes]
    print("{} sanity checks loaded.".format(len(sanity_checks)))

    # 2. Checking
    for item in tqdm(dataset, desc="Processing items"):
        for check in sanity_checks:
            check.check(item)

    # 3. Teardown
    for check in sanity_checks:
        check.teardown()

    # 4. Collect messages
    diagnostics = []
    for check in sanity_checks:
        diagnostics += check.collect_diagnostics()
    print("{} diagnostics found.".format(len(diagnostics)))

    category_counter = collections.Counter(
        diagnostic.category for diagnostic in diagnostics
    )

    # 5. Render the report
    env = Environment(loader=PackageLoader('sct_dbtool',
                                           'templates'))
    template = env.get_template('report.html')
    report_stream = template.stream(category_counter=category_counter,
                                    diagnostics=diagnostics,
                                    api=api_client,
                                    version=sct_dbtool.__version__,
                                    sanity_checks_name=sanity_checks_name)
    report_stream.dump("report.html")
    print("Results were written into 'report.html' file.")