- 相关论文:A deep graph neural network architecture for modelling spatio-temporal dynamics in resting-state functional MRI data
- 相关repo:https://github.com/tjiagoM/spatio-temporal-brain
- 笔记人:陈亦新
主函数中生成了这样的模型:
# Build the model without a pre-trained encoder and move it to the configured device.
model = (
    SpatioTemporalModel(encoding_model=None, run_cfg=run_cfg)
    .to(run_cfg['device_run'])
)
这个SpatioTemporalModel的代码非常长,和以前解读工程时一样,我们只看forward函数就行。下面片段中的注释是我的理解:
class SpatioTemporalModel(nn.Module):
    """Spatio-temporal graph model from the referenced paper/repo.

    forward() runs: per-node temporal encoding (LSTM / 1D-CNN / TCN, or the
    STATS / VAE / AE encoders) -> optional concatenation with multimodal node
    features -> graph layers (GAT/GCN or a meta layer) -> graph pooling ->
    final linear head. Only forward() is reproduced in this note; the
    constructor and the other methods are omitted.
    """

    def forward(self, data):
        """Run one forward pass over a batch of graphs.

        Args:
            data: batched graph object (torch_geometric-style) exposing ``x``,
                ``edge_index``, ``edge_attr`` and ``batch`` -- the only
                attributes read here.

        Returns:
            The output of ``self.final_linear`` (passed through
            ``torch.sigmoid`` when ``self.final_sigmoid`` is set). For the
            DiffPool-family pooling strategies (DIFFPOOL, DP_MAX, DP_ADD,
            DP_MEAN, DP_IMPROVED) a tuple ``(output, link_loss, ent_loss)`` is
            returned instead, presumably so the training loop can add
            DiffPool's auxiliary losses.
        """
        # These three fields are the ones described in the previous section of these notes.
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        if self.multimodal_size > 0:
            # Split off the first multimodal_size columns (presumably the extra,
            # non-time-series modality). They go through their own
            # Linear -> activation -> batch-norm -> dropout branch, and x keeps
            # only the remaining (time-series) columns.
            xn, x = x[:, :self.multimodal_size], x[:, self.multimodal_size:]
            xn = self.multimodal_lin(xn)
            xn = self.activation(xn)
            xn = self.multimodal_batch(xn)
            xn = F.dropout(xn, p=self.dropout, training=self.training)
        # Processing temporal part
        if self.conv_strategy != ConvStrategy.NONE:
            # The LSTM is treated here as just another "conv" strategy.
            if self.conv_strategy == ConvStrategy.LSTM:
                # LSTM feature extraction: each node's series becomes a
                # (num_time_length, 1) sequence (the LSTM is batch_first).
                x = x.view(-1, self.num_time_length, 1)
                # Zero initial hidden and cell states (see "LSTM-补充1" below).
                h0, c0 = self.init_lstm_hidden(x)
                # self.temporal_conv is an nn.LSTM here (see "LSTM-补充2" below).
                x, (_, _) = self.temporal_conv(x, (h0, c0))
                x = x.contiguous()
            else:
                # Not an LSTM, so a convolutional strategy: either the plain 1D
                # CNN or the TCN (see "CNN-补充1" and "TCN-补充1" below).
                # Input layout is (nodes, 1 channel, num_time_length).
                x = x.view(-1, 1, self.num_time_length)
                x = self.temporal_conv(x)
            # Concatenating for the final embedding per node
            # Flatten to (nodes, size_before_lin_temporal), i.e.
            # (output channels) x (output time length). For the plain CNN this
            # is channels_conv * 8 * final_feature_size: four stride-2 convs give
            # 8x the base channels and roughly 1/16 of the time length. For the
            # TCN the size is multiplied by the full num_time_length instead (no
            # downsampling).
            x = x.view(x.size()[0], self.size_before_lin_temporal)
            # A single Linear layer or the small MLP built by _get_lin_temporal
            # (see below), producing NODE_EMBED_SIZE - multimodal_size features.
            x = self.lin_temporal(x)
            x = self.activation(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        elif self.encoding_strategy == EncodingStrategy.STATS:
            # Fully connected layer (stats_lin), activation, batch-norm (stats_batch), dropout.
            x = self.stats_lin(x)
            x = self.activation(x)
            x = self.stats_batch(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        elif self.encoding_strategy == EncodingStrategy.VAE3layers:
            # Variational autoencoder as feature extractor: encode to
            # (mu, logvar), then draw the latent code with reparameterize().
            mu, logvar = self.encoder_model.encode(x)
            x = self.encoder_model.reparameterize(mu, logvar)
        elif self.encoding_strategy == EncodingStrategy.AE3layers:
            # Same idea with a plain autoencoder: use its encoder output directly.
            x = self.encoder_model.encode(x)
        # If no branch above matched, x passes through unchanged.
        if self.multimodal_size > 0:
            # Prepend the multimodal features to the temporal embedding.
            x = torch.cat((xn, x), dim=1)
        # At this point x holds the features extracted from the time series.
        # The graph part uses two classics, GAT and GCN (the note-taker used GCN
        # in an earlier ISBI paper).
        if self.sweep_type in [SweepType.GAT, SweepType.GCN]:
            # Message passing here is conceptually close to a transformer's
            # attention map. The GNN/TCN implementation details are left for a
            # later section, after the overall architecture.
            if self.edge_weights:
                # Pass edge weights, i.e. how strongly two nodes are connected.
                x = self.gnn_conv1(x, edge_index, edge_weight=edge_attr.view(-1))
            else:
                # Without edge weights the layer only knows that two nodes are
                # connected, not how strongly.
                x = self.gnn_conv1(x, edge_index)
            x = self.activation(x)
            # NOTE(review): unlike the other F.dropout calls in forward(), this
            # one (and the one after gnn_conv2) omits p=self.dropout, so PyTorch's
            # default p=0.5 applies -- confirm this is intended.
            x = F.dropout(x, training=self.training)
            # The GNN is shallow: one layer, or two when num_gnn_layers == 2.
            if self.num_gnn_layers == 2:
                if self.edge_weights:
                    x = self.gnn_conv2(x, edge_index, edge_weight=edge_attr.view(-1))
                else:
                    x = self.gnn_conv2(x, edge_index)
                x = self.activation(x)
                x = F.dropout(x, training=self.training)
        # META_NODE: according to the note-taker, a PNANodeModel-based feature
        # extractor (its definition is not shown here).
        elif self.sweep_type == SweepType.META_NODE:
            x = self.meta_layer(x, edge_index, edge_attr)
        # META_EDGE_NODE: a MetaLayer that returns updated node and edge
        # features (its third output is unused).
        elif self.sweep_type == SweepType.META_EDGE_NODE:
            x, edge_attr, _ = self.meta_layer(x, edge_index, edge_attr)
        # Graph pooling, as in the previous section: mean, sum, the DiffPool
        # family, or concatenation of all node embeddings (CONCAT).
        if self.pooling == PoolingStrategy.MEAN:
            x = global_mean_pool(x, data.batch)
        elif self.pooling == PoolingStrategy.ADD:
            x = global_add_pool(x, data.batch)
        elif self.pooling in [PoolingStrategy.DIFFPOOL, PoolingStrategy.DP_MAX, PoolingStrategy.DP_ADD, PoolingStrategy.DP_MEAN, PoolingStrategy.DP_IMPROVED]:
            # The question left open in the previous section: DiffPool needs a
            # dense adjacency matrix, while our graph is stored sparsely. The
            # conversion happens here (see the "to_dense_adj" item below).
            adj_tmp = pyg_utils.to_dense_adj(edge_index, data.batch, edge_attr=edge_attr)
            if edge_attr is not None: # Because edge_attr only has 1 feature per edge
                adj_tmp = adj_tmp[:, :, :, 0]
            # Dense per-graph node features plus a mask of the real node slots.
            x_tmp, batch_mask = pyg_utils.to_dense_batch(x, data.batch)
            # self.diff_pool is the DiffPool module (covered in the next
            # section); it also returns the link and entropy losses.
            x, link_loss, ent_loss = self.diff_pool(x_tmp, adj_tmp, batch_mask)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = self.activation(self.pre_final_linear(x))
        elif self.pooling == PoolingStrategy.CONCAT:
            # NOTE(review): bare to_dense_batch (vs pyg_utils.to_dense_batch
            # above) relies on a separate import not shown in this note.
            x, _ = to_dense_batch(x, data.batch)
            # Flatten each graph's node embeddings into one vector; assumes every
            # graph has exactly num_nodes nodes -- TODO confirm.
            x = x.view(-1, self.NODE_EMBED_SIZE * self.num_nodes)
            x = self.activation(self.pre_final_linear(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.final_linear(x)
        # link_loss / ent_loss exist only on the DiffPool path. The same
        # membership test below makes sure they are returned only there.
        if self.final_sigmoid:
            return torch.sigmoid(x) if self.pooling not in [PoolingStrategy.DIFFPOOL, PoolingStrategy.DP_MAX, PoolingStrategy.DP_ADD, PoolingStrategy.DP_MEAN, PoolingStrategy.DP_IMPROVED] else (
                torch.sigmoid(x), link_loss, ent_loss)
        else:
            return x if self.pooling not in [PoolingStrategy.DIFFPOOL, PoolingStrategy.DP_MAX, PoolingStrategy.DP_ADD, PoolingStrategy.DP_MEAN, PoolingStrategy.DP_IMPROVED] else (x, link_loss, ent_loss)
对上述代码段的补充说明:
- LSTM-补充1
def init_lstm_hidden(x):
    """Return zero initial states ``[h0, c0]`` for the LSTM temporal encoder.

    Both tensors have shape ``(tcn_depth, x.size(0), tcn_hidden_units)``, i.e.
    (num_layers, batch, hidden_size). This is the layout ``nn.LSTM`` expects
    for its initial states even when the LSTM itself is ``batch_first``.

    ``run_cfg`` is not a parameter. It is looked up in the enclosing scope
    (presumably the model's run configuration, defined outside this snippet).

    Args:
        x: the LSTM input; ``forward`` reshapes it to ``(N, num_time_length, 1)``
            before calling this helper.

    Returns:
        A two-element list ``[h0, c0]`` of zero tensors on ``x``'s device and
        with ``x``'s dtype.
    """
    shape = (run_cfg['tcn_depth'], x.size(0), run_cfg['tcn_hidden_units'])
    # Allocate directly on x's device and with x's dtype (no CPU round-trip).
    # nn.LSTM requires its initial states to match the input dtype, so
    # float32-only states would fail for float64/half inputs.
    return [torch.zeros(shape, device=x.device, dtype=x.dtype) for _ in range(2)]
- LSTM-补充2
# LSTM temporal encoder: one scalar per time step as input, tcn_hidden_units
# features per step as output, tcn_depth stacked layers with dropout_perc
# dropout between them. batch_first=True means inputs are laid out as
# (batch, time, feature).
self.temporal_conv = nn.LSTM(
    batch_first=True,
    input_size=1,
    hidden_size=run_cfg['tcn_hidden_units'],
    num_layers=run_cfg['tcn_depth'],
    dropout=dropout_perc,
)
- CNN-补充1
# Plain 1D-CNN temporal encoder: four Conv1d layers (kernel 7, padding 3,
# stride 2). With these settings each conv roughly halves the time axis, so the
# output is about num_time_length / 16 long. The channels go
# 1 -> c -> 2c -> 4c -> 8c, with c = channels_conv.
stride = 2
padding = 3
# Flattened size fed to lin_temporal: 8 * channels_conv output channels times
# final_feature_size, presumably the time length left after the four stride-2
# convs (final_feature_size is defined outside this snippet).
self.size_before_lin_temporal = self.channels_conv * 8 * self.final_feature_size
self.lin_temporal = nn.Linear(self.size_before_lin_temporal, self.NODE_EMBED_SIZE - self.multimodal_size)
self.conv1d_1 = nn.Conv1d(1, self.channels_conv, 7, padding=padding, stride=stride)
self.conv1d_2 = nn.Conv1d(self.channels_conv, self.channels_conv * 2, 7, padding=padding, stride=stride)
self.conv1d_3 = nn.Conv1d(self.channels_conv * 2, self.channels_conv * 4, 7, padding=padding, stride=stride)
self.conv1d_4 = nn.Conv1d(self.channels_conv * 4, self.channels_conv * 8, 7, padding=padding, stride=stride)
# NOTE(review): BatchNorm1d is used unqualified (unlike nn.Conv1d), so it relies
# on a direct import not shown in this snippet.
self.batch1 = BatchNorm1d(self.channels_conv)
self.batch2 = BatchNorm1d(self.channels_conv * 2)
self.batch3 = BatchNorm1d(self.channels_conv * 4)
self.batch4 = BatchNorm1d(self.channels_conv * 8)
# Each stage is Conv1d -> activation -> BatchNorm -> Dropout. The same
# self.activation module object is reused in all four stages.
self.temporal_conv = nn.Sequential(self.conv1d_1, self.activation, self.batch1, nn.Dropout(dropout_perc),
                                   self.conv1d_2, self.activation, self.batch2, nn.Dropout(dropout_perc),
                                   self.conv1d_3, self.activation, self.batch3, nn.Dropout(dropout_perc),
                                   self.conv1d_4, self.activation, self.batch4, nn.Dropout(dropout_perc))
# init_weights() is defined elsewhere in the class (not shown in this note).
self.init_weights()
- TCN-补充1
# TCN temporal encoder. tcn_hidden_units == 8 acts as a switch: each TCN level
# then has channels_conv * 2**level channels. Any other value gives every level
# tcn_hidden_units channels. The flattened size is multiplied by the full
# num_time_length, presumably because the TCN keeps the time length unchanged
# (so the CNN variant's channels_conv * 8 * final_feature_size does not apply).
doubling_channels = run_cfg['tcn_hidden_units'] == 8
if doubling_channels:
    self.size_before_lin_temporal = self.channels_conv * (2 ** (run_cfg['tcn_depth'] - 1)) * self.num_time_length
else:
    self.size_before_lin_temporal = run_cfg['tcn_hidden_units'] * self.num_time_length
# Must run after size_before_lin_temporal is set, because _get_lin_temporal reads it.
self.lin_temporal = self._get_lin_temporal(run_cfg)
if doubling_channels:
    tcn_layers = [self.channels_conv * (2 ** level) for level in range(run_cfg['tcn_depth'])]
else:
    tcn_layers = [run_cfg['tcn_hidden_units']] * run_cfg['tcn_depth']
self.temporal_conv = TemporalConvNet(
    1,
    tcn_layers,
    kernel_size=run_cfg['tcn_kernel'],
    dropout=self.dropout,
    norm_strategy=run_cfg['tcn_norm_strategy'],
)
- _get_lin_temporal
def _get_lin_temporal(self, run_cfg):
if run_cfg['tcn_final_transform_layers'] == 1:
lin_temporal = nn.Linear(self.size_before_lin_temporal,
self.NODE_EMBED_SIZE - self.multimodal_size)
elif run_cfg['tcn_final_transform_layers'] == 2:
lin_temporal = nn.Sequential(
nn.Linear(self.size_before_lin_temporal, int(self.size_before_lin_temporal / 2)),
self.activation, nn.Dropout(self.dropout),
nn.Linear(int(self.size_before_lin_temporal / 2), self.NODE_EMBED_SIZE - self.multimodal_size))
elif run_cfg['tcn_final_transform_layers'] == 3:
lin_temporal = nn.Sequential(
nn.Linear(self.size_before_lin_temporal, int(self.size_before_lin_temporal / 2)),
self.activation, nn.Dropout(self.dropout),
nn.Linear(int(self.size_before_lin_temporal / 2), int(self.size_before_lin_temporal / 3)),
self.activation, nn.Dropout(self.dropout),
nn.Linear(int(self.size_before_lin_temporal / 3), self.NODE_EMBED_SIZE - self.multimodal_size))
return lin_temporal
- to_dense_adj
# Helper used by forward() on the DiffPool path. DiffPool needs a dense
# adjacency matrix, while the graphs are stored sparsely. to_dense_adj turns the
# sparse edge list (edge_index, optional edge_attr) plus the batch vector into a
# batched dense adjacency tensor. When edge_attr is given, the result carries a
# trailing edge-feature axis, which forward() drops with [:, :, :, 0] because
# each edge has a single feature.
import torch_geometric.utils as pyg_utils
pyg_utils.to_dense_adj