lhk lhk - 28 days ago 13
Python Question

Keras: Weight sharing doesn't work

I would like to use a ConvNet to segment image data. The same network should be fed with different (but very similar) inputs, and its outputs should then be merged.

There is one trick involved: my data is 3D, so I slice it into 2D images and pass them to the ConvNet via TimeDistributed.

It is important that the same ConvNet is used for all the images, so the weights must be shared.
Here is the code:

# Keras 1.x API (merge, and Model(input=..., output=...))
from keras.layers import Input, Permute, TimeDistributed, merge
from keras.models import Model

# convmodel is the 2D segmentation ConvNet, defined elsewhere

dim_x, dim_y, dim_z = 40, 40, 40

inputs = Input((1, dim_x, dim_y, dim_z))

# slice the volume along different axes
x_perm = Permute((2, 1, 3, 4))(inputs)
y_perm = Permute((3, 1, 2, 4))(inputs)
z_perm = Permute((4, 1, 2, 3))(inputs)

# apply the segmentation to each slice, for each slice direction
x_dist = TimeDistributed(convmodel)(x_perm)
y_dist = TimeDistributed(convmodel)(y_perm)
z_dist = TimeDistributed(convmodel)(z_perm)

# now undo the permutation
x_dist = Permute((2, 1, 3, 4))(x_dist)
y_dist = Permute((2, 3, 1, 4))(y_dist)
z_dist = Permute((2, 3, 4, 1))(z_dist)

# now merge the predictions
segmentation = merge([x_dist, y_dist, z_dist], mode="concat")

temp_model = Model(input=inputs, output=segmentation)

temp_model.summary()
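
The permutation patterns can be checked without Keras. The following is a minimal sketch in plain Python. The `permute` helper is hypothetical and only mimics how `Permute` reorders a shape (1-indexed, batch axis excluded). It also assumes that convmodel returns slices of the same shape it receives, and it uses three distinct sizes instead of 40, 40, 40 so that a mixed-up axis would be visible.

```python
def permute(shape, dims):
    """Mimic Keras Permute on a shape tuple (batch axis excluded).

    dims is 1-indexed: output axis i takes its size from input axis dims[i].
    """
    return tuple(shape[d - 1] for d in dims)


# Distinct sizes (an assumption for this check; the question uses 40, 40, 40)
dim_x, dim_y, dim_z = 40, 41, 42
volume = (1, dim_x, dim_y, dim_z)

# Patterns copied from the question
forward = {"x": (2, 1, 3, 4), "y": (3, 1, 2, 4), "z": (4, 1, 2, 3)}
inverse = {"x": (2, 1, 3, 4), "y": (2, 3, 1, 4), "z": (2, 3, 4, 1)}

sliced = {}
restored = {}
for axis in forward:
    # shape seen by TimeDistributed: (number of slices, 1, height, width)
    sliced[axis] = permute(volume, forward[axis])
    # shape after undoing the permutation
    restored[axis] = permute(sliced[axis], inverse[axis])

for axis in ("x", "y", "z"):
    print(axis, sliced[axis], "->", restored[axis])
```

In this sketch each slice direction ends up back at the original shape, so the three outputs line up before they are merged.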


The convnet model has about 3.3 million parameters. The permutations and the TimeDistributed layer don't have parameters of their own.
So the complete model should have the same number of parameters as the convnet.

It doesn't: it has three times as many, about 9.9 million.

Apparently the weights are not shared.
But as far as I understand, reusing the same model like this is how weight sharing is supposed to work.

Does the model share the weights and merely report the number of parameters incorrectly?
Or do I have to change the setup to enable weight sharing?

lhk lhk
Answer

Thanks to the Keras-Users Google Group, this is now answered: https://groups.google.com/forum/#!topic/keras-users/P-BMpdyJfXI

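The trick is to create the TimeDistributed layer once and then apply that same layer instance to each input. Calling one layer instance on several inputs is how Keras shares weights between them. Here is working code:

# create the TimeDistributed wrapper once
time_dist = TimeDistributed(convmodel)

# apply the same instance to each slice direction
x_dist = time_dist(x_perm)
y_dist = time_dist(y_perm)
z_dist = time_dist(z_perm)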
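
To illustrate the general pattern of "one instance, many calls", here is a toy sketch in plain Python. It is not Keras: `ToyLayer` and `count_unique_params` are hypothetical names made up for this example, and the toy layer allocates its own weights, which is a simplification of a wrapper around an existing convmodel. It only shows that constructing a layer three times gives three layer objects, while reusing one instance gives a single object with a single set of weights.

```python
class ToyLayer:
    """Hypothetical stand-in for a layer that owns its weights."""

    def __init__(self, n_weights):
        # weights are allocated once, when the instance is created
        self.weights = [0.5] * n_weights

    def __call__(self, xs):
        # scale every input value by the sum of this layer's weights
        scale = sum(self.weights)
        return [scale * x for x in xs]


def count_unique_params(layers):
    """Count weights, counting each distinct layer object only once."""
    unique = {id(layer): layer for layer in layers}
    return sum(len(layer.weights) for layer in unique.values())


# Pattern from the question: a new layer object per call site
separate = [ToyLayer(4) for _ in range(3)]

# Pattern from the answer: one layer object, applied three times
shared_layer = ToyLayer(4)
shared = [shared_layer, shared_layer, shared_layer]
outputs = [layer([1.0, 2.0]) for layer in shared]

separate_count = count_unique_params(separate)
shared_count = count_unique_params(shared)
print(separate_count, shared_count)
```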