...
 
labellabelfrom 0.0 1.0 2.0 3.0 4.0 5.0 12.0 16.0 24.0 31.0 32.0 33.0 35.0 36.0 37.0 38.0 39.0 40.0 41.0 42.0 43.0 44.0 47.0 48.0 49.0 50.0 51.0 52.0 53.0 56.0 57.0 58.0 59.0 60.0 61.0 62.0 63.0 64.0 65.0 66.0 67.0 70.0 72.0 73.0 74.0 76.0 77.0 81.0 82.0 83.0 84.0 85.0 86.0 87.0 89.0 90.0 91.0 92.0 93.0 94.0 96.0 97.0 101.0 102.0 103.0 104.0 105.0 106.0 107.0 108.0 109.0 110.0 113.0 114.0 115.0 116.0 117.0 118.0 119.0 120.0 121.0 122.0 123.0 124.0 125.0 126.0 129.0 130.0 133.0 134.0 135.0 136.0 137.0 138.0 139.0 140.0 141.0 142.0 143.0 144.0 145.0 146.0 147.0 148.0 149.0 150.0 151.0 152.0 153.0 154.0 155.0 156.0 157.0 158.0 161.0 162.0 163.0 164.0 165.0 166.0 167.0 168.0 169.0 170.0 171.0 172.0 173.0 174.0 175.0 176.0 177.0 178.0 179.0 180.0 181.0 182.0 183.0 184.0 185.0 186.0 187.0 188.0 191.0 192.0 193.0 194.0 195.0 196.0 197.0 198.0 199.0 200.0 201.0 202.0 203.0 204.0 205.0 206.0 207.0 208.0
labellabelto 0.0 0.0 0.0 0.0 1.0 1.0 1.0 1.0 3.0 3.0 2.0 2.0 0.0 0.0 3.0 3.0 2.0 2.0 3.0 3.0 1.0 1.0 1.0 2.0 2.0 1.0 1.0 1.0 1.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 0.0 0.0 0.0 0.0 0.0 2.0 2.0 2.0 0.0 0.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 3.0 0.0 0.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0 2.0
\ No newline at end of file
......@@ -57,6 +57,7 @@ class MultitaskPIMMS3D(BaseNet):
#### Do this conditionally? #####
scope.reuse_variables()
out = modality_classifier(tf.expand_dims(input_tensor[..., z_size//2, i], -1), True)
out = tf.check_numerics(out, message='Modality classifier outputs NaNs')
modality_scores.append(out)
modality_tensor = tf.expand_dims(tf.expand_dims(tf.expand_dims(tf.stack(modality_scores, axis=-1), axis=2), axis=2), axis=2)
......@@ -98,7 +99,7 @@ class MultitaskPIMMS3D(BaseNet):
w_regularizer=self.regularizers['w'])
brain_parcellation_tensor = brain_parcellation_op(abstraction_tensor, is_training)
tf.logging.info('Brain Parcellation frontend output dims: %s' % brain_parcellation_tensor.shape)
classification_tensor = tf.reshape(tf.transpose(tf.stack(modality_scores, axis=-1), [0, 2, 1]), shape=[n_subj_in_batch*n_ims_per_subj, n_modalities])
classification_tensor = tf.reshape(tf.transpose(tf.stack(modality_scores, axis=-1), [0, 2, 1]), shape=[n_subj_in_batch, n_ims_per_subj, n_modalities])
tf.logging.info('Classification tensor output dims: %s' % classification_tensor.shape)
return segmentation_tensor, brain_parcellation_tensor, classification_tensor
......@@ -299,7 +300,7 @@ class HighRes3dFrontendBlock(BaseNet):
fc_layer = ConvolutionalLayer(
n_output_chns=params['n_features'],
kernel_size=params['kernel_size'],
acti_func='softmax',
acti_func=None,
w_initializer=self.initializers['w'],
w_regularizer=self.regularizers['w'],
name=params['name'])
......
......@@ -72,6 +72,7 @@ class ResNet(BaseNet):
layers = self.create()
out = layers.conv1(images, is_training)
for block in layers.blocks:
out = tf.check_numerics(out, message='NaNs in the blocks')
out = block(out, is_training)
out = tf.expand_dims(tf.reduce_mean(tf.nn.relu(layers.bn(out, is_training)), axis=[1, 2, 3]), axis=[-1])
tf.logging.info('{} shape: {}'.format(out.name, out.shape))
......
......@@ -93,6 +93,8 @@ SUPPORTED_LOSS_SEGMENTATION = {
'niftynet.layer.loss_segmentation.tversky',
"GDSC":
'niftynet.layer.loss_segmentation.generalised_dice_loss',
"DicePlusXEnt":
'niftynet.layer.loss_segmentation.dice_plus_xent_loss',
"WGDL":
'niftynet.layer.loss_segmentation.generalised_wasserstein_dice_loss',
"SensSpec":
......
......@@ -182,6 +182,9 @@ class OutputsCollector(object):
This function should be called in
`ApplicationDriver.create_graph` function
"""
all_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
for var in all_vars:
tf.summary.histogram(var.name, var)
self._average_variables_over_devices(self.console_vars, False)
self._average_variables_over_devices(self.output_vars, False)
self._average_variables_over_devices(self.summary_vars, True)
......
......@@ -60,12 +60,7 @@ class UniformSampler(ImageWindowDataset):
``{image_modality: data_array, image_location: n_samples * 7}``
"""
image_id, data, _ = self.reader(idx=idx, shuffle=True)
print('range of values for BinLesion', np.min(data['label'][..., 0]), np.max(data['label'][..., 0]))
print('range of values for Parcellation', np.min(data['label'][..., 1]), np.max(data['label'][..., 1]))
print('unique values for Parcellation: \n')
print(np.unique(data['label'][..., 1]))
##### Randomly drop modalities according to params #####
tf.logging.info('Image_id %s' % image_id)
modalities_to_drop = int(np.random.choice([0, 1, 2], 1, p=[0.5, 0.3, 0.2]))
data_shape_without_modality = list(data['image'].shape)[:-1]
random_indices = np.random.permutation([0, 1, 2])
......
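The sampler hunk above randomly drops 0, 1 or 2 input modalities per sample (with probabilities 0.5/0.3/0.2) before windows are drawn. A minimal sketch of that pattern, assuming the dropped channels are simply replaced with zeros (the tail of the hunk is cut off here, so the replacement value is an assumption):

import numpy as np

def randomly_drop_modalities(image, p=(0.5, 0.3, 0.2)):
    # `image` is assumed to be shaped (..., n_modalities); zeroing the
    # dropped channels is an assumption -- the original hunk ends before
    # the replacement value is visible.
    n_to_drop = int(np.random.choice([0, 1, 2], 1, p=list(p)))
    dropped = np.random.permutation(image.shape[-1])[:n_to_drop]
    out = image.copy()
    out[..., dropped] = 0
    return out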
......@@ -383,29 +383,26 @@ class ImageReader(Layer):
raise
def _filename_to_image_list(file_list, mod_dict, data_param):
def _filename_to_image_list(file_list, mod_dict, data_param, num_threads=10):
"""
Converting a list of filenames to a list of image objects;
properties (e.g. interp_order) are added to each object
"""
from functools import partial
import multiprocessing
volume_list = []
valid_idx = []
for idx in range(len(file_list)):
valid_idxs = []
tf.logging.info(f'Converting a list of filenames to a list of image objects with num_threads={num_threads}')
pool = multiprocessing.Pool(num_threads) # Should work if num_threads = 1
func = partial(_create_image_multiprocessing_wrapper, file_list=file_list, data_param=data_param, mod_dict=mod_dict)
for idx, (_dict, valid_idx) in enumerate(pool.imap_unordered(func=func, iterable=range(len(file_list)))):
# create image instance for each subject
print_progress_bar(idx, len(file_list),
prefix='reading datasets headers',
decimals=1, length=10, fill='*')
# combine fieldnames and volumes as a dictionary
_dict = {}
for field, modalities in mod_dict.items():
_dict[field] = _create_image(
file_list, idx, modalities, data_param)
# skip the subject if there are missing image components
if _dict and None not in list(_dict.values()):
volume_list.append(_dict)
valid_idx.append(idx)
valid_idxs.append(valid_idx)
if not volume_list:
tf.logging.fatal(
......@@ -421,6 +418,13 @@ def _filename_to_image_list(file_list, mod_dict, data_param):
raise IOError
return volume_list, file_list.iloc[valid_idx]
def _create_image_multiprocessing_wrapper(idx, file_list, data_param, mod_dict):
_dict = {}
for field, modalities in mod_dict.items():
_dict[field] = _create_image(file_list, idx, modalities, data_param)
# the caller skips the subject if any image component is missing
return _dict, idx
def _create_image(file_list, idx, modalities, data_param):
"""
......
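The image-reader change above moves header loading into a multiprocessing.Pool and consumes results with imap_unordered, so one slow subject no longer serialises the whole scan. A stripped-down sketch of the same pattern, with a hypothetical load_subject helper standing in for _create_image_multiprocessing_wrapper:

import multiprocessing
from functools import partial

def load_subject(idx, file_list, data_param, mod_dict):
    # hypothetical stand-in for _create_image_multiprocessing_wrapper:
    # build one image object per field, or None when a component is missing
    return {field: file_list[idx] for field in mod_dict}, idx

def load_all(file_list, data_param, mod_dict, num_threads=10):
    # Pool(1) still works; it simply runs every job in a single worker
    pool = multiprocessing.Pool(num_threads)
    func = partial(load_subject, file_list=file_list,
                   data_param=data_param, mod_dict=mod_dict)
    volume_list, valid_idxs = [], []
    try:
        # imap_unordered yields results as each worker finishes, so
        # progress can be reported without waiting for the full list
        for subject_dict, idx in pool.imap_unordered(func, range(len(file_list))):
            if subject_dict and None not in subject_dict.values():
                volume_list.append(subject_dict)
                valid_idxs.append(idx)
    finally:
        pool.close()
        pool.join()
    return volume_list, sorted(valid_idxs)

One caveat worth keeping in mind: with imap_unordered the valid indices come back in completion order, hence the sort before returning in this sketch.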
......@@ -2,6 +2,7 @@
from __future__ import absolute_import, print_function, division
import os
import multiprocessing
import numpy as np
import tensorflow as tf
......@@ -18,6 +19,7 @@ class DiscreteLabelNormalisationLayer(DataDependentLayer, Invertible):
image_name,
modalities,
model_filename=None,
num_threads=1,
name='label_norm'):
super(DiscreteLabelNormalisationLayer, self).__init__(name=name)
......@@ -25,6 +27,7 @@ class DiscreteLabelNormalisationLayer(DataDependentLayer, Invertible):
# modalities are listed in self.modalities
self.image_name = image_name
self.modalities = None
self.num_threads = num_threads
if isinstance(modalities, (list, tuple)):
if len(modalities) > 1:
raise NotImplementedError(
......@@ -130,11 +133,12 @@ class DiscreteLabelNormalisationLayer(DataDependentLayer, Invertible):
self.image_name,
self.modalities,
len(self.label_map[self.key[0]])))
print(self.label_map)
return
tf.logging.info(
"Looking for the set of unique discrete labels from input {}"
" using {} subjects".format(self.image_name, len(image_list)))
label_map = find_set_of_labels(image_list, self.image_name, self.key)
label_map = find_set_of_labels(image_list, self.image_name, self.key, num_threads=self.num_threads)
# merging trained_mapping dict and self.mapping dict
self.label_map.update(label_map)
all_maps = hs.read_mapping_file(self.model_file)
......@@ -142,23 +146,27 @@ class DiscreteLabelNormalisationLayer(DataDependentLayer, Invertible):
hs.write_all_mod_mapping(self.model_file, all_maps)
def find_set_of_labels(image_list, field, output_key):
def get_unique_labels(image, field):
assert field in image, \
"label normalisation layer requires {} input, " \
"however it is not provided in the config file.\n" \
"Please consider setting " \
"label_normalisation to False.".format(field)
unique_label = np.unique(image[field].get_data())
return unique_label
def find_set_of_labels(image_list, field, output_key, num_threads=1):
from functools import partial
tf.logging.info(f'Finding set of labels with num_threads={num_threads}')
pool = multiprocessing.Pool(num_threads) # Should work if num_threads = 1
label_set = set()
if field in image_list[0] :
for idx, image in enumerate(image_list):
assert field in image, \
"label normalisation layer requires {} input, " \
"however it is not provided in the config file.\n" \
"Please consider setting " \
"label_normalisation to False.".format(field)
for idx, unique_label in enumerate(pool.imap_unordered(partial(get_unique_labels, field=field), image_list)):
# for idx, image in enumerate(image_list):
print_progress_bar(idx, len(image_list),
prefix='searching unique labels from files',
decimals=1, length=10, fill='*')
unique_label = np.unique(image[field].get_data())
if len(unique_label) > 500 or len(unique_label) <= 1:
tf.logging.warning(
'unusual discrete values: number of unique '
'labels to normalise %s', len(unique_label))
label_set.update(set(unique_label))
label_set = list(label_set)
label_set.sort()
......@@ -170,4 +178,5 @@ def find_set_of_labels(image_list, field, output_key):
tf.logging.fatal("unable to create mappings keys: %s, image name %s",
output_key, field)
raise
pool.close()
return mapping_from_to
......@@ -36,7 +36,9 @@ class LossFunction(Layer):
self._loss_func_params = \
loss_func_params if loss_func_params is not None else dict()
if self._data_loss_func.__name__.startswith('cross_entropy'):
data_loss_function_name = self._data_loss_func.__name__
if data_loss_function_name.startswith('cross_entropy')\
or 'xent' in data_loss_function_name:
tf.logging.info(
'Cross entropy loss function calls '
'tf.nn.sparse_softmax_cross_entropy_with_logits '
......@@ -205,11 +207,11 @@ def generalised_dice_loss(prediction,
one_hot = labels_to_one_hot(ground_truth, tf.shape(prediction)[-1])
if weight_map is not None:
n_classes = prediction.shape[1].value
num_classes = prediction.shape[1].value
# weight_map_nclasses = tf.reshape(
# tf.tile(weight_map, [n_classes]), prediction.get_shape())
# tf.tile(weight_map, [num_classes]), prediction.get_shape())
weight_map_nclasses = tf.tile(
tf.expand_dims(tf.reshape(weight_map, [-1]), 1), [1, n_classes])
tf.expand_dims(tf.reshape(weight_map, [-1]), 1), [1, num_classes])
ref_vol = tf.sparse_reduce_sum(
weight_map_nclasses * one_hot, reduction_axes=[0])
......@@ -239,7 +241,7 @@ def generalised_dice_loss(prediction,
# generalised_dice_denominator = \
# tf.reduce_sum(tf.multiply(weights, seg_vol + ref_vol)) + 1e-6
generalised_dice_denominator = tf.reduce_sum(
tf.multiply(weights, tf.maximum(seg_vol + ref_vol, 1)))
tf.multiply(weights, tf.maximum(seg_vol + ref_vol, 1)))
generalised_dice_score = \
generalised_dice_numerator / generalised_dice_denominator
generalised_dice_score = tf.where(tf.is_nan(generalised_dice_score), 1.0,
......@@ -247,6 +249,42 @@ def generalised_dice_loss(prediction,
return 1 - generalised_dice_score
def dice_plus_xent_loss(prediction, ground_truth, weight_map=None):
"""
Function to calculate the loss used in https://arxiv.org/pdf/1809.10486.pdf,
no-new-Net, Isensee et al. (used to win the Medical Segmentation Decathlon).
It is the sum of the cross-entropy and the Dice-loss.
:param prediction: the logits
:param ground_truth: the segmentation ground truth
:param weight_map:
:return: the loss (cross_entropy + Dice)
"""
if weight_map is not None:
raise NotImplementedError
num_classes = tf.shape(prediction)[-1]
prediction = tf.cast(prediction, tf.float32)
loss_xent = cross_entropy(prediction, ground_truth)
# Dice, as defined in the paper:
one_hot = labels_to_one_hot(ground_truth, num_classes=num_classes)
softmax_of_logits = tf.nn.softmax(prediction)
dice_numerator = -2.0 * tf.sparse_reduce_sum(one_hot * softmax_of_logits,
reduction_axes=[0])
dice_denominator = tf.reduce_sum(softmax_of_logits, reduction_indices=[0]) + \
tf.sparse_reduce_sum(one_hot, reduction_axes=[0])
epsilon_denominator = 0.00001
loss_dice = dice_numerator / (dice_denominator + epsilon_denominator)
return loss_dice + loss_xent
def sensitivity_specificity_loss(prediction,
ground_truth,
weight_map=None,
......@@ -344,7 +382,7 @@ def wasserstein_disagreement_map(
assert M is not None, "Distance matrix is required."
# pixel-wise Wasserstein distance (W) between flat_pred_proba and flat_labels
# wrt the distance matrix on the label space M
n_classes = prediction.shape[1].value
num_classes = prediction.shape[1].value
ground_truth.set_shape(prediction.shape)
unstack_labels = tf.unstack(ground_truth, axis=-1)
unstack_labels = tf.cast(unstack_labels, dtype=tf.float64)
......@@ -354,8 +392,8 @@ def wasserstein_disagreement_map(
# "unstacked pred" ,unstack_pred)
# W is a weighting sum of all pairwise correlations (pred_ci x labels_cj)
pairwise_correlations = []
for i in range(n_classes):
for j in range(n_classes):
for i in range(num_classes):
for j in range(num_classes):
pairwise_correlations.append(
M[i, j] * tf.multiply(unstack_pred[i], unstack_labels[j]))
wass_dis_map = tf.add_n(pairwise_correlations)
......@@ -382,7 +420,7 @@ def generalised_wasserstein_dice_loss(prediction,
tf.logging.warning('Weight map specified but not used.')
prediction = tf.cast(prediction, tf.float32)
n_classes = prediction.shape[1].value
num_classes = prediction.shape[1].value
one_hot = labels_to_one_hot(ground_truth, tf.shape(prediction)[-1])
one_hot = tf.sparse_tensor_to_dense(one_hot)
......@@ -395,7 +433,7 @@ def generalised_wasserstein_dice_loss(prediction,
# compute generalisation of true positives for multi-class seg
one_hot = tf.cast(one_hot, dtype=tf.float64)
true_pos = tf.reduce_sum(
tf.multiply(tf.constant(M[0, :n_classes], dtype=tf.float64), one_hot),
tf.multiply(tf.constant(M[0, :num_classes], dtype=tf.float64), one_hot),
axis=1)
true_pos = tf.reduce_sum(tf.multiply(true_pos, 1. - delta), axis=0)
WGDL = 1. - (2. * true_pos) / (2. * true_pos + all_error)
......@@ -423,9 +461,9 @@ def dice(prediction, ground_truth, weight_map=None):
one_hot = labels_to_one_hot(ground_truth, tf.shape(prediction)[-1])
if weight_map is not None:
n_classes = prediction.shape[1].value
num_classes = prediction.shape[1].value
weight_map_nclasses = tf.tile(tf.expand_dims(
tf.reshape(weight_map, [-1]), 1), [1, n_classes])
tf.reshape(weight_map, [-1]), 1), [1, num_classes])
dice_numerator = 2.0 * tf.sparse_reduce_sum(
weight_map_nclasses * one_hot * prediction, reduction_axes=[0])
dice_denominator = \
......@@ -442,7 +480,7 @@ def dice(prediction, ground_truth, weight_map=None):
epsilon_denominator = 0.00001
dice_score = dice_numerator / (dice_denominator + epsilon_denominator)
# dice_score.set_shape([n_classes])
# dice_score.set_shape([num_classes])
# minimising (1 - dice_coefficients)
return 1.0 - tf.reduce_mean(dice_score)
......@@ -463,9 +501,9 @@ def dice_nosquare(prediction, ground_truth, weight_map=None):
# dice
if weight_map is not None:
n_classes = prediction.shape[1].value
num_classes = prediction.shape[1].value
weight_map_nclasses = tf.tile(tf.expand_dims(
tf.reshape(weight_map, [-1]), 1), [1, n_classes])
tf.reshape(weight_map, [-1]), 1), [1, num_classes])
dice_numerator = 2.0 * tf.sparse_reduce_sum(
weight_map_nclasses * one_hot * prediction, reduction_axes=[0])
dice_denominator = \
......@@ -481,7 +519,7 @@ def dice_nosquare(prediction, ground_truth, weight_map=None):
epsilon_denominator = 0.00001
dice_score = dice_numerator / (dice_denominator + epsilon_denominator)
# dice_score.set_shape([n_classes])
# dice_score.set_shape([num_classes])
# minimising (1 - dice_coefficients)
return 1.0 - tf.reduce_mean(dice_score)
......
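For reference, the DicePlusXEnt loss added above is the sparse softmax cross-entropy of the logits plus a per-class soft Dice term, -2*sum(softmax*one_hot) / (sum(softmax) + sum(one_hot) + eps); the per-class terms are presumably averaged by the surrounding LossFunction wrapper, since the tests below compare against that mean. A plain-NumPy sketch of the combination, with names of my own choosing:

import numpy as np

def dice_plus_xent_reference(logits, labels, eps=1e-5):
    # Plain-NumPy reference for the combined loss; `logits` is
    # (n_voxels, n_classes), `labels` is (n_voxels,) integer class ids.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    one_hot = np.eye(logits.shape[-1])[labels]

    # sparse softmax cross-entropy, averaged over voxels
    xent = -np.mean(np.log(probs[np.arange(len(labels)), labels] + 1e-12))

    # per-class soft Dice (negated), then averaged over classes
    numerator = -2.0 * np.sum(one_hot * probs, axis=0)
    denominator = np.sum(probs, axis=0) + np.sum(one_hot, axis=0) + eps
    return xent + np.mean(numerator / denominator)

With the logits and labels of test_dice_plus_non_zeros below this evaluates to roughly 0.5*ln(2) - 2/3, matching the comment in that test.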
......@@ -9,6 +9,7 @@ from __future__ import print_function
import argparse
import math
import os
import re
import shutil
import tarfile
import tempfile
......@@ -58,8 +59,8 @@ def download(example_ids,
return False
# Check if the server is running by looking for a known file
remote_base_url_test = gitlab_raw_file_url(
global_config.get_download_server_url(), 'README.md')
remote_base_url_test = raw_file_url(
global_config.get_download_server_url())
server_ok = url_exists(remote_base_url_test)
if verbose:
print("Accessing: {}".format(global_config.get_download_server_url()))
......@@ -94,14 +95,13 @@ def download_file(url, download_path):
:param url: URL of the file to download
:param download_path: location where the file should be saved
"""
# Extract the filename from the URL
parsed = urlparse(url)
filename = os.path.basename(parsed.path)
filename = os.path.basename(download_path)
# Ensure the output directory exists
if not os.path.exists(download_path):
os.makedirs(download_path)
output_directory = os.path.dirname(download_path)
if not os.path.exists(output_directory):
os.makedirs(output_directory)
# Get a temporary file path for the compressed file download
temp_folder = tempfile.mkdtemp()
......@@ -111,8 +111,7 @@ def download_file(url, download_path):
urlretrieve(url, downloaded_file, reporthook=progress_bar_wrapper)
# Move the file to the destination folder
destination_path = os.path.join(download_path, filename)
shutil.move(downloaded_file, destination_path)
shutil.move(downloaded_file, download_path)
shutil.rmtree(temp_folder, ignore_errors=True)
......@@ -338,7 +337,9 @@ class ConfigStoreCache(object):
Returns the full path to the locally cached configuration file
"""
return os.path.join(self._cache_folder, example_id + CONFIG_FILE_EXT)
return os.path.join(self._cache_folder,
example_id + '_main'+ CONFIG_FILE_EXT)
# return os.path.join(self._cache_folder, example_id + CONFIG_FILE_EXT)
def get_local_cache_folder(self):
"""
......@@ -351,7 +352,6 @@ class ConfigStoreCache(object):
"""
Returns the local configuration file for this example_id
"""
config_filename = self.get_local_path(example_id)
parser = NiftyNetLaunchConfig()
......@@ -392,7 +392,7 @@ class RemoteProxy(object):
"""
download_file(self._remote.get_url(example_id),
self._cache.get_local_cache_folder())
self._cache.get_local_path(example_id))
def get_download_params(self, example_id):
"""
......@@ -428,17 +428,21 @@ class RemoteConfigStore(object):
"""
Gets the URL for the record for this example_id
"""
return raw_file_url(self._base_url, example_id)
return gitlab_raw_file_url(self._base_url,
example_id + CONFIG_FILE_EXT)
def gitlab_raw_file_url(base_url, file_name):
def raw_file_url(base_url, example_id=None):
"""
Returns the url for the raw file on a GitLab server
"""
return base_url + '/raw/new_dataset_api/' + file_name
_branch_name = '5-reorganising-with-lfs'
if not example_id:
return '{}/raw/{}/README.md'.format(base_url, _branch_name)
example_id = re.sub('_model_zoo', '', example_id, 1)
return '{}/raw/{}/{}/main{}'.format(
base_url, _branch_name, example_id, CONFIG_FILE_EXT)
# return base_url + '/raw/new_dataset_api/' + file_name
# return base_url + '/raw/master/' + file_name
# return base_url + '/raw/revising-config/' + file_name
......
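The download changes above point the model-zoo client at a per-example 'main' config on the '5-reorganising-with-lfs' branch, stripping a trailing '_model_zoo' from the example id. Roughly, assuming CONFIG_FILE_EXT is '.ini' and using a placeholder server URL and an illustrative example id:

# placeholder base URL -- the real value comes from global_config
base = 'https://gitlab.example.com/niftynet/model-zoo'

raw_file_url(base)
# -> 'https://gitlab.example.com/niftynet/model-zoo/raw/5-reorganising-with-lfs/README.md'
#    (used only to check that the server is reachable)

raw_file_url(base, 'dense_vnet_abdominal_ct_model_zoo')
# -> 'https://gitlab.example.com/niftynet/model-zoo/raw/5-reorganising-with-lfs/dense_vnet_abdominal_ct/main.ini'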
......@@ -339,6 +339,13 @@ def add_network_args(parser):
type=int,
default=5)
parser.add_argument(
"--num_threads",
help="Set number of threads used for image loading",
metavar='',
type=int,
default=5)
parser.add_argument(
"--multimod_foreground_type",
choices=list(
......
......@@ -7,6 +7,60 @@ import tensorflow as tf
from niftynet.layer.loss_segmentation import LossFunction, labels_to_one_hot
class DicePlusXEntTest(tf.test.TestCase):
def test_dice_plus(self):
with self.test_session():
predicted = tf.constant(
[[0, 9999], [9999, 0], [9999, 0], [9999, 0]],
dtype=tf.float32, name='predicted')
labels = tf.constant([1, 0, 0, 0], dtype=tf.int16, name='labels')
predicted, labels = [tf.expand_dims(x, axis=0) for x in (predicted, labels)]
test_loss_func = LossFunction(2, loss_type='DicePlusXEnt', softmax=False)
loss_value = test_loss_func(predicted, labels)
# cross-entropy of ~0, Dice loss of -1, so sum \approx -1
self.assertAllClose(loss_value.eval(), -1.0, atol=1e-3)
def test_dice_plus_multilabel(self):
with self.test_session():
predicted = tf.constant(
[[0, 0, 9999], [9999, 0, 0], [0, 9999, 0], [9999, 0, 0]],
dtype=tf.float32, name='predicted')
labels = tf.constant([2, 0, 1, 0], dtype=tf.int16, name='labels')
predicted, labels = [tf.expand_dims(x, axis=0) for x in (predicted, labels)]
test_loss_func = LossFunction(3, loss_type='DicePlusXEnt', softmax=False)
loss_value = test_loss_func(predicted, labels)
# cross-ent of zero, Dice loss of -1, so sum \approx -1
self.assertAllClose(loss_value.eval(), -1.0, atol=1e-3)
def test_dice_plus_non_zeros(self):
with self.test_session():
predicted = tf.constant(
[[0, 9999, 9999], [9999, 0, 0], [0, 9999, 9999], [9999, 0, 0]],
dtype=tf.float32, name='predicted')
labels = tf.constant([2, 0, 1, 0], dtype=tf.int16, name='labels')
predicted, labels = [tf.expand_dims(x, axis=0) for x in (predicted, labels)]
test_loss_func = LossFunction(3, loss_type='DicePlusXEnt', softmax=False)
loss_value = test_loss_func(predicted, labels)
# cross-ent of mean(ln(2), 0, 0, ln(2)) = .5*ln(2)
# Dice loss of -mean(1, .5, .5)=-2/3
self.assertAllClose(loss_value.eval(), .5 * np.log(2) - 2. / 3., atol=1e-3)
def test_dice_plus_wrong_softmax(self):
with self.test_session():
predicted = tf.constant(
[[0, 9999, 9999], [9999, 0, 0], [0, 9999, 9999], [9999, 0, 0]],
dtype=tf.float32, name='predicted')
labels = tf.constant([2, 0, 1, 0], dtype=tf.int16, name='labels')
predicted, labels = [tf.expand_dims(x, axis=0) for x in (predicted, labels)]
test_loss_func = LossFunction(3, loss_type='DicePlusXEnt', softmax=True)
loss_value = test_loss_func(predicted, labels)
# cross-ent of mean(ln(2), 0, 0, ln(2)) = .5*ln(2)
# Dice loss of -mean(1, .5, .5)=-2/3
self.assertAllClose(loss_value.eval(), .5 * np.log(2) - 2. / 3., atol=1e-3)
class OneHotTester(tf.test.TestCase):
def test_vs_tf_onehot(self):
with self.test_session():
......@@ -17,13 +71,13 @@ class OneHotTester(tf.test.TestCase):
def test_one_hot(self):
ref = np.asarray(
[[[ 0., 1., 0., 0., 0.], [ 0., 0., 1., 0., 0.]],
[[ 0., 0., 0., 1., 0.], [ 0., 0., 0., 0., 1.]]],
[[[0., 1., 0., 0., 0.], [0., 0., 1., 0., 0.]],
[[0., 0., 0., 1., 0.], [0., 0., 0., 0., 1.]]],
dtype=np.float32)
with self.test_session():
labels = tf.constant([[1, 2], [3, 4]])
#import pdb; pdb.set_trace()
# import pdb; pdb.set_trace()
one_hot = tf.sparse_tensor_to_dense(
labels_to_one_hot(labels, 5)).eval()
self.assertAllEqual(one_hot, ref)
......
......@@ -47,38 +47,39 @@ interp_order = 0
############################## system configuration sections
[SYSTEM]
cuda_devices = ""
num_threads = 10
num_gpus = 1
model_dir = /home/tom/phd/3D_HeMIS_parcellation_lesions/v1
queue_length = 80
model_dir = /home/tom/data/3D_HeMIS_parcellation_lesions/v1/v1/
num_threads = 30
[NETWORK]
name = niftynet.contrib.pimms.multitask_pimms_3D.MultitaskPIMMS3D
activation_function = relu
decay = 1e-4
decay = 0.0
reg_type = L2
batch_size = 2
batch_size = 1
volume_padding_size=(0,0,0)
normalisation = False
histogram_ref_file = /home/tom/phd/NiftyNet-dev/NiftyNet/demos/BRATS17/label_mapping_whole_tumor.txt
histogram_ref_file = /home/tom/phd/NiftyNet-dev/NiftyNet/brain_parcellation_ref_file.txt
normalise_foreground_only = False
foreground_type = threshold_plus
num_threads = 30
multimod_foreground_type = and
window_sampling = uniform
[TRAINING]
optimiser = adam
sample_per_volume = 1
lr = 3e-4
loss_type = Dice
starting_iter = 0
lr = 1e-2
loss_type = DicePlusXEnt
starting_iter = 9421
save_every_n = 500
max_iter = 10000
max_checkpoints = 20
validation_every_n = 100
validation_max_iter = 22
exclude_fraction_for_validation = 0.1
tensorboard_every_n = 1
exclude_fraction_for_inference = 0.6
tensorboard_every_n = 10
############################ custom configuration sections
[SEGMENTATION]
......@@ -86,7 +87,7 @@ image = Flair,T1,T2
label = label,Parcellation
output_prob = False
num_classes = 2
label_normalisation = False
label_normalisation = True
[EVALUATION]
......