Commit 436cea71 authored by Felix Bragman's avatar Felix Bragman

multi-task vgg16

parent aac3fc2b
Pipeline #12594 failed with stages
in 9 seconds
......@@ -39,31 +39,86 @@ class MT1_VGG16Net(BaseNet):
name=name)
self.layers = [
{'name': 'layer_1', 'n_features': 64, 'kernel_size': 3, 'repeat': 2},
{'name': 'layer_1', 'n_features': 64*2, 'kernel_size': 3, 'repeat': 2},
{'name': 'maxpool_1'},
{'name': 'layer_2', 'n_features': 128, 'kernel_size': 3, 'repeat': 2},
{'name': 'layer_2', 'n_features': 128*2, 'kernel_size': 3, 'repeat': 2},
{'name': 'maxpool_2'},
{'name': 'layer_3', 'n_features': 256, 'kernel_size': 3, 'repeat': 3},
{'name': 'layer_3', 'n_features': 256*2, 'kernel_size': 3, 'repeat': 3},
{'name': 'maxpool_3'},
{'name': 'layer_4', 'n_features': 512, 'kernel_size': 3, 'repeat': 3},
{'name': 'layer_4', 'n_features': 512*2, 'kernel_size': 3, 'repeat': 3},
{'name': 'maxpool_4'},
{'name': 'layer_5', 'n_features': 512, 'kernel_size': 3, 'repeat': 3},
{'name': 'layer_5', 'n_features': 512*2, 'kernel_size': 3, 'repeat': 3},
{'name': 'maxpool_5'}]
self.task1_layers = [
{'name': 'fc_1', 'n_features': 4096},
{'name': 'fc_2', 'n_features': 4096},
{'name': 'fc_1', 'n_features': 4096*2},
{'name': 'fc_2', 'n_features': 4096*2},
{'name': 'fc_3', 'n_features': task1_classes}]
self.task2_layers = [
{'name': 'fc_1', 'n_features': 4096},
{'name': 'fc_2', 'n_features': 4096},
{'name': 'fc_1', 'n_features': 4096*2},
{'name': 'fc_2', 'n_features': 4096*2},
{'name': 'fc_3', 'n_features': task2_classes}]
def layer_op(self, images, is_training=True, layer_id=-1, **unused_kwargs):
#assert layer_util.check_spatial_dims(
# images, lambda x: x % 224 == 0)
# main network graph
flow, layer_instances = self.create_main_network_graph(images, is_training)
# add task 1 output
flow_t1 = flow
for layer_iter, layer in enumerate(self.task1_layers):
if layer_iter == len(self.layers)-1:
fc_layer = FullyConnectedLayer(
n_output_chns=layer['n_features'],
w_initializer=self.initializers['w'],
w_regularizer=self.regularizers['w'],
)
task1_out = fc_layer(flow_t1)
layer_instances.append((fc_layer, task1_out))
else:
fc_layer = FullyConnectedLayer(
n_output_chns=layer['n_features'],
acti_func=self.acti_func,
w_initializer=self.initializers['w'],
w_regularizer=self.regularizers['w'],
)
flow_t1 = fc_layer(flow_t1)
layer_instances.append((fc_layer, flow_t1))
# add task 1 output
flow_t2 = flow
for layer_iter, layer in enumerate(self.task2_layers):
if layer_iter == len(self.layers) - 1:
fc_layer = FullyConnectedLayer(
n_output_chns=layer['n_features'],
w_initializer=self.initializers['w'],
w_regularizer=self.regularizers['w'],
)
task2_out = fc_layer(flow_t2)
layer_instances.append((fc_layer, task2_out))
else:
fc_layer = FullyConnectedLayer(
n_output_chns=layer['n_features'],
acti_func=self.acti_func,
w_initializer=self.initializers['w'],
w_regularizer=self.regularizers['w'],
)
flow_t2 = fc_layer(flow_t2)
layer_instances.append((fc_layer, flow_t2))
if is_training:
self._print(layer_instances)
return [task1_out, task2_out]
return layer_instances[layer_id][1]
def create_main_network_graph(self, images, is_training):
layer_instances = []
for layer_iter, layer in enumerate(self.layers):
......@@ -89,14 +144,12 @@ class MT1_VGG16Net(BaseNet):
# last layer
elif layer_iter == len(self.layers)-1:
fc_layer = FullyConnectedLayer(
n_output_chns=layer['n_features'],
acti_func=self.acti_func,
w_initializer=self.initializers['w'],
w_regularizer=self.regularizers['w'],
)
flow = fc_layer(flow)
layer_instances.append((fc_layer, flow))
downsample_layer = DownSampleLayer(
kernel_size=2,
func='MAX',
stride=2)
flow = downsample_layer(flow)
layer_instances.append((downsample_layer, flow))
# all other
else:
......@@ -145,13 +198,10 @@ class MT1_VGG16Net(BaseNet):
w_initializer=self.initializers['w'],
w_regularizer=self.regularizers['w'],
)
flow = fc_layer(flow, keep_prob=0.5)
flow = fc_layer(flow)
layer_instances.append((fc_layer, flow))
if is_training:
self._print(layer_instances)
return flow
return layer_instances[layer_id][1]
return flow, layer_instances
@staticmethod
def _print(list_of_layers):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment