Commit 29f098d9 by Wenqi Li

Merge branch 'fix-network-output-variable-name' into 'dev'

Update simulator gan and fix network output variable name

See merge request !75
parents 290ea1c9 55149b95
Pipeline #8531 passed with stages
in 2 minutes 39 seconds
......@@ -4,7 +4,7 @@ from niftynet.application.base_application import BaseApplication
from niftynet.engine.application_factory import ApplicationNetFactory
from niftynet.engine.application_factory import OptimiserFactory
from niftynet.engine.application_variables import CONSOLE
from niftynet.engine.application_variables import NETORK_OUTPUT
from niftynet.engine.application_variables import NETWORK_OUTPUT
from niftynet.engine.application_variables import TF_SUMMARIES
from niftynet.engine.sampler_grid import GridSampler
from niftynet.engine.sampler_uniform import UniformSampler
......@@ -188,10 +188,10 @@ class BRATSApp(BaseApplication):
outputs_collector.add_to_collection(
var=net_out, name='window',
average_over_devices=False, collection=NETORK_OUTPUT)
average_over_devices=False, collection=NETWORK_OUTPUT)
outputs_collector.add_to_collection(
var=data_dict['image_location'], name='location',
average_over_devices=False, collection=NETORK_OUTPUT)
average_over_devices=False, collection=NETWORK_OUTPUT)
init_aggregator = \
self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
init_aggregator()
......
......@@ -30,12 +30,8 @@ model_dir = ./models/model_us_simulator_gan_demo
queue_length = 20
[NETWORK]
#name = simulator_gan
name = niftynet.network.simulator_gan.SimulatorGAN
activation_function = prelu
batch_size = 36
#decay = 1e-7
reg_type = L2
histogram_ref_file = ./example_volumes/monomodal_parcellation/standardisation_models.txt
norm_type = percentile
......@@ -43,12 +39,10 @@ cutoff = (0.01, 0.99)
[TRAINING]
sample_per_volume = 1
#rotation_angle = (-10.0, 10.0)
#scaling_percentage = (-10.0, 10.0)
lr = 0.0001
loss_type = CrossEntropy
starting_iter = 0
save_every_n = 2000
save_every_n = 1000
max_iter = 10000
max_checkpoints = 20
......
......@@ -4,7 +4,7 @@ from niftynet.application.base_application import BaseApplication
from niftynet.engine.application_factory import ApplicationNetFactory
from niftynet.engine.application_factory import OptimiserFactory
from niftynet.engine.application_variables import CONSOLE
from niftynet.engine.application_variables import NETORK_OUTPUT
from niftynet.engine.application_variables import NETWORK_OUTPUT
from niftynet.engine.application_variables import TF_SUMMARIES
from niftynet.engine.windows_aggregator_identity import WindowAsImageAggregator
from niftynet.engine.sampler_linear_interpolate import LinearInterpolateSampler
......@@ -163,16 +163,16 @@ class AutoencoderApplication(BaseApplication):
outputs_collector.add_to_collection(
var=data_dict['image_location'], name='location',
average_over_devices=True, collection=NETORK_OUTPUT)
average_over_devices=True, collection=NETWORK_OUTPUT)
if self._infer_type == 'encode-decode':
outputs_collector.add_to_collection(
var=net_output[2], name='generated_image',
average_over_devices=True, collection=NETORK_OUTPUT)
average_over_devices=True, collection=NETWORK_OUTPUT)
if self._infer_type == 'encode':
outputs_collector.add_to_collection(
var=net_output[7], name='embedded',
average_over_devices=True, collection=NETORK_OUTPUT)
average_over_devices=True, collection=NETWORK_OUTPUT)
self.output_decoder = WindowAsImageAggregator(
image_reader=self.reader,
......@@ -196,7 +196,7 @@ class AutoencoderApplication(BaseApplication):
outputs_collector.add_to_collection(
var=decoder_output, name='generated_image',
average_over_devices=True, collection=NETORK_OUTPUT)
average_over_devices=True, collection=NETWORK_OUTPUT)
self.output_decoder = WindowAsImageAggregator(
image_reader=None,
output_path=self.action_param.save_seg_dir)
......@@ -217,10 +217,10 @@ class AutoencoderApplication(BaseApplication):
outputs_collector.add_to_collection(
var=decoder_output, name='generated_image',
average_over_devices=True, collection=NETORK_OUTPUT)
average_over_devices=True, collection=NETWORK_OUTPUT)
outputs_collector.add_to_collection(
var=data_dict['feature_location'], name='location',
average_over_devices=True, collection=NETORK_OUTPUT)
average_over_devices=True, collection=NETWORK_OUTPUT)
self.output_decoder = WindowAsImageAggregator(
image_reader=self.reader,
output_path=self.action_param.save_seg_dir)
......
......@@ -6,7 +6,7 @@ from niftynet.application.base_application import BaseApplication
from niftynet.engine.application_factory import ApplicationNetFactory
from niftynet.engine.application_factory import OptimiserFactory
from niftynet.engine.application_variables import CONSOLE
from niftynet.engine.application_variables import NETORK_OUTPUT
from niftynet.engine.application_variables import NETWORK_OUTPUT
from niftynet.engine.application_variables import TF_SUMMARIES
from niftynet.engine.windows_aggregator_identity import WindowAsImageAggregator
from niftynet.engine.sampler_random_vector import RandomVectorSampler
......@@ -215,12 +215,12 @@ class GANApplication(BaseApplication):
var=net_output[0],
name='image',
average_over_devices=False,
collection=NETORK_OUTPUT)
collection=NETWORK_OUTPUT)
outputs_collector.add_to_collection(
var=conditioning_dict['conditioning_location'],
name='location',
average_over_devices=False,
collection=NETORK_OUTPUT)
collection=NETWORK_OUTPUT)
self.output_decoder = WindowAsImageAggregator(
image_reader=self.reader,
......
......@@ -4,7 +4,7 @@ from niftynet.application.base_application import BaseApplication
from niftynet.engine.application_factory import ApplicationNetFactory
from niftynet.engine.application_factory import OptimiserFactory
from niftynet.engine.application_variables import CONSOLE
from niftynet.engine.application_variables import NETORK_OUTPUT
from niftynet.engine.application_variables import NETWORK_OUTPUT
from niftynet.engine.application_variables import TF_SUMMARIES
from niftynet.engine.windows_aggregator_grid import GridSamplesAggregator
from niftynet.engine.windows_aggregator_resize import ResizeSamplesAggregator
......@@ -246,10 +246,10 @@ class SegmentationApplication(BaseApplication):
outputs_collector.add_to_collection(
var=net_out, name='window',
average_over_devices=False, collection=NETORK_OUTPUT)
average_over_devices=False, collection=NETWORK_OUTPUT)
outputs_collector.add_to_collection(
var=data_dict['image_location'], name='location',
average_over_devices=False, collection=NETORK_OUTPUT)
average_over_devices=False, collection=NETWORK_OUTPUT)
init_aggregator = \
self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
init_aggregator()
......
......@@ -21,7 +21,7 @@ import tensorflow as tf
from niftynet.engine.application_factory import ApplicationFactory
from niftynet.engine.application_variables import CONSOLE
from niftynet.engine.application_variables import GradientsCollector
from niftynet.engine.application_variables import NETORK_OUTPUT
from niftynet.engine.application_variables import NETWORK_OUTPUT
from niftynet.engine.application_variables import OutputsCollector
from niftynet.engine.application_variables import TF_SUMMARIES
from niftynet.engine.application_variables import \
......@@ -313,9 +313,9 @@ class ApplicationDriver(object):
# variables to the graph
vars_to_run = dict(train_op=train_op)
vars_to_run[CONSOLE], vars_to_run[NETORK_OUTPUT] = \
vars_to_run[CONSOLE], vars_to_run[NETWORK_OUTPUT] = \
self.outputs_collector.variables(CONSOLE), \
self.outputs_collector.variables(NETORK_OUTPUT)
self.outputs_collector.variables(NETWORK_OUTPUT)
if self.tensorboard_every_n > 0 and \
(iter_i % self.tensorboard_every_n == 0):
# adding tensorboard summary
......@@ -326,7 +326,7 @@ class ApplicationDriver(object):
graph_output = sess.run(vars_to_run)
# process graph outputs
self.app.interpret_output(graph_output[NETORK_OUTPUT])
self.app.interpret_output(graph_output[NETWORK_OUTPUT])
console_str = self._console_vars_to_str(graph_output[CONSOLE])
summary = graph_output.get(TF_SUMMARIES, {})
if summary:
......@@ -352,15 +352,15 @@ class ApplicationDriver(object):
# build variables to run
vars_to_run = dict()
vars_to_run[NETORK_OUTPUT], vars_to_run[CONSOLE] = \
self.outputs_collector.variables(NETORK_OUTPUT), \
vars_to_run[NETWORK_OUTPUT], vars_to_run[CONSOLE] = \
self.outputs_collector.variables(NETWORK_OUTPUT), \
self.outputs_collector.variables(CONSOLE)
# evaluate the graph variables
graph_output = sess.run(vars_to_run)
# process the graph outputs
if not self.app.interpret_output(graph_output[NETORK_OUTPUT]):
if not self.app.interpret_output(graph_output[NETWORK_OUTPUT]):
tf.logging.info('processed all batches.')
loop_status['all_saved_flag'] = True
break
......
......@@ -17,7 +17,7 @@ from niftynet.utilities.restore_initializer import restore_initializer
from niftynet.utilities.util_common import look_up_operations
RESTORABLE = 'NiftyNetObjectsToRestore'
NETORK_OUTPUT = 'niftynetout'
NETWORK_OUTPUT = 'niftynetout'
CONSOLE = 'niftynetconsole'
TF_SUMMARIES = tf.GraphKeys.SUMMARIES
SUPPORTED_SUMMARY = {'scalar': tf.summary.scalar,
......@@ -82,7 +82,7 @@ class OutputsCollector(object):
"""
Collect all tf.Tensor object, to be evaluated by tf.Session.run()
These objects are grouped into
NETORK_OUTPUT: to be decoded by an aggregator
NETWORK_OUTPUT: to be decoded by an aggregator
CONSOLE: to be printed on command line
TF_SUMMARIES: to be added to tensorboard visualisation
"""
......@@ -151,7 +151,7 @@ class OutputsCollector(object):
:param name: name of the variable (for displaying purposes)
:param average_over_devices
:param collection: in choices of
[CONSOLE, TF_SUMMARIES, NETORK_OUTPUT]
[CONSOLE, TF_SUMMARIES, NETWORK_OUTPUT]
:param summary_type if adding to TF_SUMMARIES, there are
a few possible ways to visualise the Tensor value
see SUPPORTED_SUMMARY
......@@ -159,7 +159,7 @@ class OutputsCollector(object):
"""
if collection == CONSOLE:
self._add_to_console(var, name, average_over_devices)
elif collection == NETORK_OUTPUT:
elif collection == NETWORK_OUTPUT:
self._add_to_network_output(var, name, average_over_devices)
elif collection == TF_SUMMARIES:
self._add_to_tf_summary(
......@@ -172,14 +172,14 @@ class OutputsCollector(object):
"""
get tf.Tensors to be evaluated by tf.Session().run()
:param collection: in choices of
[CONSOLE, TF_SUMMARIES, NETORK_OUTPUT]
[CONSOLE, TF_SUMMARIES, NETWORK_OUTPUT]
:return: a variable dictionary
"""
if collection == CONSOLE:
return self.console_vars
elif collection == TF_SUMMARIES:
return self._merge_op if self._merge_op is not None else {}
elif collection == NETORK_OUTPUT:
elif collection == NETWORK_OUTPUT:
return self.output_vars
else:
tf.logging.fatal("unknown output %s", collection)
......
......@@ -77,7 +77,7 @@ def load_image(filename):
# continue to next loader
pass
raise nib.filebasedimages.ImageFileError(
'No loader could load the file') # Throw last error
'No loader could load the file {}'.format(filename))
def correct_image_if_necessary(img):
......
......@@ -20,8 +20,8 @@ def selu(x, name):
def leakyRelu(x, name):
alpha = 0.01
return tf.maximum(alpha * x, x, name)
half_alpha = 0.01
return (0.5 + half_alpha) * x + (0.5 - half_alpha) * abs(x)
SUPPORTED_OP = {'relu': tf.nn.relu,
......
......@@ -15,7 +15,8 @@ def default_w_initializer():
def _initializer(shape, dtype, partition_info):
stddev = np.sqrt(2.0 / np.prod(shape[:-1]))
from tensorflow.python.ops import random_ops
return random_ops.truncated_normal(shape, 0.0, stddev, dtype=tf.float32)
return random_ops.truncated_normal(
shape, 0.0, stddev, dtype=tf.float32)
# return tf.truncated_normal_initializer(
# mean=0.0, stddev=stddev, dtype=tf.float32)
......
......@@ -6,8 +6,6 @@ import tensorflow as tf
from niftynet.layer.base_layer import TrainableLayer
# import niftynet.engine.logging as logging
class GANImageBlock(TrainableLayer):
def __init__(self,
generator,
......@@ -25,25 +23,22 @@ class GANImageBlock(TrainableLayer):
conditioning,
is_training):
shape_to_generate = training_image.get_shape().as_list()[1:]
fake_image = self._generator(
random_source, shape_to_generate, conditioning, is_training)
fake_logits = self._discriminator(
fake_image, conditioning, is_training)
fake_image = self._generator(random_source,
shape_to_generate,
conditioning,
is_training)
fake_logits = self._discriminator(fake_image,
conditioning,
is_training)
if self.clip:
with tf.name_scope('clip_real_images'):
training_image = tf.maximum(
-self.clip,
tf.minimum(self.clip, training_image))
real_logits = self._discriminator(
training_image, conditioning, is_training)
# with tf.name_scope('summaries_images'):
# if len(fake_image.get_shape()) - 2 == 3:
# logging.image3_axial('fake', (fake_image / 2 + 1) * 127, 2, [logging.LOG])
# logging.image3_axial('real', tf.maximum(0., tf.minimum(255., (training_image / 2 + 1) * 127)), 2, [logging.LOG])
# if len(fake_image.get_shape()) - 2 == 2:
# tf.summary.fake_image('fake', (fake_image / 2 + 1) * 127, 2, [logging.LOG])
# tf.summary.fake_image('real', tf.maximum(0., tf.minimum(255., (training_image / 2 + 1) * 127)), 2, [logging.LOG])
real_logits = self._discriminator(training_image,
conditioning,
is_training)
return fake_image, real_logits, fake_logits
......
......@@ -3,7 +3,7 @@ from __future__ import absolute_import, print_function
import tensorflow as tf
from niftynet.engine.application_variables import NETORK_OUTPUT
from niftynet.engine.application_variables import NETWORK_OUTPUT
from niftynet.engine.application_variables import TF_SUMMARIES
from niftynet.engine.application_variables import OutputsCollector
from niftynet.network.toynet import ToyNet
......@@ -31,7 +31,7 @@ class OutputCollectorTest(tf.test.TestCase):
average_over_devices=False)
collector.add_to_collection(name='bar',
var=bar,
collection=NETORK_OUTPUT,
collection=NETWORK_OUTPUT,
average_over_devices=False)
self.assertDictEqual(collector.console_vars,
{'image': image, 'foo': foo})
......@@ -72,11 +72,11 @@ class OutputCollectorTest(tf.test.TestCase):
foo = tf.zeros([2, 2])
collector.add_to_collection(name='image',
var=image,
collection=NETORK_OUTPUT,
collection=NETWORK_OUTPUT,
average_over_devices=False)
collector.add_to_collection(name='foo',
var=foo,
collection=NETORK_OUTPUT,
collection=NETWORK_OUTPUT,
average_over_devices=False)
self.assertDictEqual(collector.output_vars,
{'image': image, 'foo': foo})
......@@ -91,15 +91,15 @@ class OutputCollectorTest(tf.test.TestCase):
bar = tf.zeros([42])
collector.add_to_collection(name='image',
var=image,
collection=NETORK_OUTPUT,
collection=NETWORK_OUTPUT,
average_over_devices=False)
collector.add_to_collection(name='foo',
var=foo,
collection=NETORK_OUTPUT,
collection=NETWORK_OUTPUT,
average_over_devices=False)
collector.add_to_collection(name='bar',
var=bar,
collection=NETORK_OUTPUT,
collection=NETWORK_OUTPUT,
average_over_devices=True)
self.assertEqual(
set(collector.output_vars),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment