分類任務與分割任務有什麼聯系嗎?
答案是肯定的。
分割其實就是對每一個像素進行分類。在代碼上,分割與分類的差別就更小了,都是用全連接配接層輸出一定的數目,這個數目就是你要分類/分割的個數。
以PointNet為例,先看看網絡架構:

可以看到網絡在得到global feature之前,分類和分割是公用一套網絡的。它們的代碼自然也一樣。這部分代碼位于pointnet_cls.py和pointnet_seg.py中,完全相同。
def get_model(point_cloud, is_training, bn_decay=None):
""" Classification PointNet, input is BxNx3, output Bx40 """
batch_size = point_cloud.get_shape()[0].value
num_point = point_cloud.get_shape()[1].value
end_points = {} with tf.variable_scope('transform_net1') as sc:
transform = input_transform_net(point_cloud, is_training, bn_decay, K=3)
point_cloud_transformed = tf.matmul(point_cloud, transform)
input_image = tf.expand_dims(point_cloud_transformed, -1)
net = tf_util.conv2d(input_image, 64, [1,3],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='conv1', bn_decay=bn_decay)
net = tf_util.conv2d(net, 64, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='conv2', bn_decay=bn_decay) with tf.variable_scope('transform_net2') as sc:
transform = feature_transform_net(net, is_training, bn_decay, K=64)
end_points['transform'] = transform
net_transformed = tf.matmul(tf.squeeze(net, axis=[2]), transform)
net_transformed = tf.expand_dims(net_transformed, [2])
net = tf_util.conv2d(net_transformed, 64, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='conv3', bn_decay=bn_decay)
net = tf_util.conv2d(net, 128, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='conv4', bn_decay=bn_decay)
net = tf_util.conv2d(net, 1024, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='conv5', bn_decay=bn_decay)
再往後看,就出現一些差別了。
分類任務:
# Symmetric function: max pooling
net = tf_util.max_pool2d(net, [num_point,1],
padding='VALID', scope='maxpool')
net = tf.reshape(net, [batch_size, -1])
net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training,
scope='fc1', bn_decay=bn_decay)
net = tf_util.dropout(net, keep_prob=0.7, is_training=is_training,
scope='dp1')
net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training,
scope='fc2', bn_decay=bn_decay)
net = tf_util.dropout(net, keep_prob=0.7, is_training=is_training,
scope='dp2')
net = tf_util.fully_connected(net, 40, activation_fn=None, scope='fc3') return net, end_points
代碼布局如同網絡中描繪的一樣。池化操作後,做全連接配接層,最後輸出40,對應40類物體分類。
再來看分割:
global_feat = tf_util.max_pool2d(net, [num_point,1],
padding='VALID', scope='maxpool')
print(global_feat)
global_feat_expand = tf.tile(global_feat, [1, num_point, 1, 1])
concat_feat = tf.concat(3, [point_feat, global_feat_expand])
print(concat_feat)
net = tf_util.conv2d(concat_feat, 512, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='conv6', bn_decay=bn_decay)
net = tf_util.conv2d(net, 256, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='conv7', bn_decay=bn_decay)
net = tf_util.conv2d(net, 128, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='conv8', bn_decay=bn_decay)
net = tf_util.conv2d(net, 128, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='conv9', bn_decay=bn_decay)
net = tf_util.conv2d(net, 50, [1,1],
padding='VALID', stride=[1,1], activation_fn=None,
scope='conv10')
net = tf.squeeze(net, [2]) # BxNxC
return net, end_points
除了增加全局特征與點特征的拼接外,也是做了全連接配接操作,注意此處的全連接配接使用1*1的卷積實作的,但是本質上和使用fully_connect效果一樣。最後的輸出是50,對應的是分割任務的50個parts。
最後的損失函數也是一樣的。這裡就不貼出來了。
是以,總的來說,分割就是一種特殊的分類。當然,為了提高分割效果,可以對損失函數做相應的改進,如平滑等