天天看點

點雲深度學習系列2——PointNet/PointCNN代碼比較(變換矩陣部分)

PointNet與PointCNN從文章到代碼都有很多相似之處,兩者對比看待,或許更有助于我們了解。

衆所周知,PointNet中使用了maxpooling和T-net,作者文章中起到關鍵作用的是maxpooling,而T-net對性能的提升作用也還是有的(兩個T-net加上regularization 貢獻了2.1個百分點),但奇怪的是在PointNet++的代碼中,已經看不到T-net了(這一點論文沒有提及,github上也有人提問,但是作者沒有回複)。

但是,與之相似的PointCNN中有個X變換矩陣,但X變換對于PointCNN的作用可就非常重要了,因為它連maxpooling都沒有用。下面我們就對兩者進行比較。

首先是PointNet中的T-net代碼:

def feature_transform_net(inputs, is_training, bn_decay=None, K=64):
    """ Feature Transform Net, input is BxNx1xK
        Return:
            Transformation matrix of size KxK """
    batch_size = inputs.get_shape()[0].value
    num_point = inputs.get_shape()[1].value

    net = tf_util.conv2d(inputs, 64, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='tconv1', bn_decay=bn_decay)
    net = tf_util.conv2d(net, 128, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='tconv2', bn_decay=bn_decay)
    net = tf_util.conv2d(net, 1024, [1,1],
                         padding='VALID', stride=[1,1],
                         bn=True, is_training=is_training,
                         scope='tconv3', bn_decay=bn_decay)
    net = tf_util.max_pool2d(net, [num_point,1],#池化視窗是[num_point,1]
                             padding='VALID', scope='tmaxpool')

    net = tf.reshape(net, [batch_size, -1])#變成兩維
    net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training,
                                  scope='tfc1', bn_decay=bn_decay)
    net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training,
                                  scope='tfc2', bn_decay=bn_decay)

    with tf.variable_scope('transform_feat') as sc:
        weights = tf.get_variable('weights', [256, K*K],
                                  initializer=tf.constant_initializer(0.0),
                                  dtype=tf.float32)
        biases = tf.get_variable('biases', [K*K],
                                 initializer=tf.constant_initializer(0.0),
                                 dtype=tf.float32)
        biases += tf.constant(np.eye(K).flatten(), dtype=tf.float32)
        transform = tf.matmul(net, weights)
        transform = tf.nn.bias_add(transform, biases)

    transform = tf.reshape(transform, [batch_size, K, K])
    return transform
           

代碼主題部分,前三個conv2d用來升維,接着一個max_pool2d把1024個點的特征做了maxpooling,融合成一點。然後跟兩個fully_connected把次元降到256,再然後是跟[256, K*K]的權值相乘再加K*K維的偏移,達到[batch_size, K*K],最後變形成[batch_size, K, K],大功告成,不容易啊。

接下來看PointCNN的X變換:

######################## X-transformation #########################
        X_0 = pf.conv2d(nn_pts_local_bn, K * K, tag + 'X_0', is_training, (1, K), with_bn=False)
        #kernal size(1, K, 3), kernal num=K*K, so the output size is (N, P, 1, K*K). so this operator is in the neighbor point dimentional.
        X_1 = pf.dense(X_0, K * K, tag + 'X_1', is_training, with_bn=False)#in the center point dimensional ,P decrease to 1.
        X_2 = pf.dense(X_1, K * K, tag + 'X_2', is_training, with_bn=False, activation=None)#(N, P, 1, K*K)
        X = tf.reshape(X_2, (N, P, K, K), name=tag + 'X')
        fts_X = tf.matmul(X, nn_fts_input, name=tag + 'fts_X')
        ###################################################################
           

第一層是卷積層,讓人很吃驚,卷積核是1*k的,也就是在鄰域次元上,直接把k個鄰域點彙聚到一個點上,且用了K×K個卷積層,把特征次元升高到k*k,次元從(P,K,C)變成了(P,1,K×K);然後,作者用了兩個dense層,保持了這個結構;最後reshape成(P,K,K),這就得到了X-transporm矩陣。

從體量和複雜程度上來看,後者勝出。

從作用效果來看,不太好評價。因為PointCNN是有局部特征的,這點和pointnet++思想一緻。是以即便PointCNN性能超過了PointNet,也不能直接證明X-transporm就一定優于T-net了。

代碼方面,其實T-net的前四層和X變換的第一層做的事情差不多,都是為了把多個點的特征融合到一組特征,為訓練變換矩陣提供素材。但接下來就不同了,T-net隻有一組K*K的weights權值,而PointCNN後面跟了兩個dense層,次元都是K*K的,參數更多,是以猜測PointCNN訓練變換矩陣應該會更加充分。後期可以通過實驗驗證以下。

最後歪個樓,我在測試PointNet++的代碼時,ModelNet40的分類結果一直徘徊在90.1左右,達不到論文裡提的90.7,跟作者郵件聯系也沒有得到很好的答案。是以我再想是不是作者本來的PointNet++代碼裡是有T-net的,但是放到github裡的版本沒加上。但這隻是猜測,有待驗證。

附上相關代碼的連結:

PointNet2:https://github.com/charlesq34/pointnet2

PointCNN:https://github.com/yangyanli/PointCNN

繼續閱讀