
AI + Wireless Communication: Top 7 (Baseline) Solution Summary

The competition has come to a close. As promised, our team is open-sourcing our solution so that everyone can learn from each other and improve together.

Team Introduction

Our team name is Baseline. We got to know each other through sharing baseline solutions, so the name stuck.

Team leader: 方曦 (Fang Xi), Shanghai Jiao Tong University, third-year master's student.

Member: 呂曉欣 (Lyu Xiaoxin), NetEase, AI engineer.

Member: 王浩 (Wang Hao), Beijing Xinghe Liangdian (星河亮點), software R&D.

Member: 楊新達 (Yang Xinda), AI engineer at a company in Guangzhou.

Solution

Abstract

For today's communication systems, the physical layer is the foundation on which communication services are guaranteed; for the physical layer, MIMO is the fundamental supporting technology; and for MIMO, accurately determining channel quality, then feeding it back and exploiting it effectively, is an indispensable key problem.

Within the discussions of 3GPP, the international standardization organization, this work is currently carried out through CSI reference signal design and the CSI feedback mechanism. Today's CSI feedback designs rely mainly on vector quantization and codebook design to extract and feed back channel features, for example the Type I and Type II CSI feedback schemes. Practice shows that this kind of feedback works, but because its core idea is information extraction plus codebook feedback, the information actually fed back is a lossy version of the channel information.

In this competition we modeled the problem from a computer-vision perspective and designed a CNN-based autoencoder. We use BottleneckCSP (BCSP) modules equipped with an SE structure as the network's basic building blocks, which perform well in both computational efficiency and accuracy. We use a quantization module with error-recovery capability, which both reduces the quantization error and improves encoder training. By analyzing the competition data in the spirit of Fast-AutoAugment, we found four data augmentation methods that neatly solved the overfitting problem at bit widths around 384. Finally, we used pruning and reduced quantization precision to substantially accelerate training. We finished in 7th place.

Keywords

wireless communication, CSI feedback, convolutional neural network, attention mechanism, data augmentation

1 Use of the Attention Mechanism

The attention mechanism we use is SE-Net (Squeeze-and-Excitation Networks) [1], winner of the classification task of ImageNet 2017, the last edition of that competition. Its basic idea is to predict a scalar weight for each output channel and reweight that channel accordingly. The structure is shown below:


Figure 1: SE attention mechanism

First, the H×W values of each channel are globally average-pooled into a single scalar; this is the Squeeze step. Two FC layers then produce a weight between 0 and 1 for each channel, and every element of the original H×W map is multiplied by its channel's weight to give the new feature map; this is the Excitation step. Any existing network structure can be given this kind of Squeeze-and-Excitation feature recalibration, as shown below.


Figure 2: Basic SENet structure

Concretely, the implementation is Global Average Pooling, FC, ReLU, FC, Sigmoid: the first FC layer reduces the number of channels and the second raises it back, producing C weights, one per channel, each used to scale the corresponding channel. The r in the figure is the reduction ratio; experiments settled on r = 16, which gives good performance at comparatively small computational cost. The core idea of SENet is to let the network learn feature weights from the loss, so that informative feature maps receive large weights and uninformative or weak ones receive small weights, training the model to a better result.
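For reference, a minimal sketch of such an SE block, using linear layers as in the original paper (the competition code below, SEBlock in modelDesign.py, uses 1x1 convolutions instead, which is equivalent on the pooled 1x1 input):

import torch
import torch.nn as nn

class SE(nn.Module):
    # Minimal Squeeze-and-Excitation block: GAP -> FC -> ReLU -> FC -> Sigmoid.
    def __init__(self, channels, r=16):
        super().__init__()
        self.fc1 = nn.Linear(channels, channels // r)
        self.fc2 = nn.Linear(channels // r, channels)

    def forward(self, x):                   # x: (N, C, H, W)
        w = x.mean(dim=(2, 3))              # Squeeze: one scalar per channel
        w = torch.sigmoid(self.fc2(torch.relu(self.fc1(w))))  # weights in (0, 1)
        return x * w[:, :, None, None]      # Excitation: rescale each channel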

We embed an SENet sub-structure into the last layer of the C3 and BottleneckCSP modules, as shown in Figure 3.

In this challenge the SE structure, that is, the attention mechanism, greatly improved the model's fitting capacity and let our model hit the 432-bit target, but the overfitting that came with it troubled us for a long time.

2 量化誤差恢複子產品

During quantized encoding, the quantize-dequantize operation loses part of the information in the original code: the quantization error. This error not only makes the final NMSE higher than without quantization, it also slows down decoder training and hurts its quality. We therefore propose a quantization error recovery module, which refines the dequantized code to bring it closer to the lossless code.


Figure 3: Basic module diagram

Concretely, the quantized-then-dequantized code is passed through two fully connected layers (with BN and a nonlinearity) and normalized so that its range matches that of the quantization error (via a sigmoid and a scale, the range is adjusted to [-1/2^(B+1), 1/2^(B+1)]); the result is added to the dequantized code as a residual, recovering the quantization error. To help the module work as intended, we attach an additional loss to its output, pulling the recovered code closer to the pre-quantization code.

Let X be the code before quantization, X' the code after quantization and dequantization, and R our error recovery module; the extra supervision can then be written as L(X' + R(X'), X).
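A minimal sketch of the module and its auxiliary loss, assuming B = 3 bits. The actual module is ende_refinement inside the Decoder in modelDesign.py, where the loss is computed from the enq_data/deq_data pair returned by the model; the exact scaling there differs slightly.

import torch
import torch.nn as nn

B = 3  # quantization bits

class ErrorRecovery(nn.Module):
    # Refine the dequantized code X' so that X' + R(X') approximates X.
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim), nn.BatchNorm1d(dim), nn.ReLU(True),
            nn.Linear(dim, dim), nn.Sigmoid(),  # output in (0, 1)
        )

    def forward(self, x_deq):
        # shift/scale the sigmoid output into the quantization-error range
        r = (self.net(x_deq) - 0.5) / 2 ** B    # roughly [-1/2^(B+1), 1/2^(B+1)]
        return x_deq + r

recovery = ErrorRecovery(128)
x = torch.rand(8, 128)                          # X: encoder output before quantization
q = torch.round(x * 2 ** B - 0.5).clamp(0, 2 ** B - 1)
x_deq = (q + 0.5) / 2 ** B                      # X': after quantize-dequantize
aux_loss = nn.MSELoss()(recovery(x_deq), x)     # L(X' + R(X'), X)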


圖4:誤差恢複子產品與Simsiam結構對比圖

The figure above also compares the error recovery module with the SimSiam [2] structure from self-supervised learning. If we regard the quantization error as a form of data augmentation, the quantization error recovery part can be viewed as a self-supervised network, while a separate branch still propagates exact gradients during backpropagation, giving us a better encoder.

3 Data Augmentation

The organizers' data, 200*3000 samples, is laid out in order. Through analysis we found that the samples along the 3000 dimension appear to share some similarity, and none of our augmentations break this pattern. Data augmentation can certainly relieve overfitting to some degree, but if it is badly designed the network learns a lot of useless information and cannot be trained to a very low NMSE. We therefore borrowed the idea of Fast-AutoAugment: for each candidate augmentation, we take a model trained on the original data and measure its NMSE on the validation set with that augmentation applied; if the NMSE is too high, the augmentation has most likely changed the original data distribution and should not be adopted (a sketch of this check follows the list below). In this way we selected four augmentation methods:

1-X (replace each sample x with 1 - x)

Real/imaginary shuffle (swap the real and imaginary channels)

MixUp

CutMix
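A minimal sketch of the screening step described above, assuming a model pretrained on clean data, a val_loader over the validation split, a candidate transform augment(x), and a mean-reduced NMSE criterion such as NMSELoss1(reduction='mean') from modelTrain.py; the threshold is illustrative.

import torch

@torch.no_grad()
def screen_augmentation(model, val_loader, augment, criterion, threshold=0.1):
    # Fast-AutoAugment-style check: measure NMSE on augmented validation data
    # with a model trained on clean data. A high NMSE suggests the candidate
    # augmentation changes the data distribution and should be rejected.
    model.eval()
    total, n = 0.0, 0
    for x in val_loader:
        x_aug = augment(x).cuda()
        out = model(x_aug)[0]             # AutoEncoder returns (out, feature, enq, deq)
        total += criterion(out, x_aug).item() * x.shape[0]
        n += x.shape[0]
    nmse = total / n
    return nmse, nmse < threshold         # adopt only distribution-preserving candidates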

Traditional MixUp and CutMix would destroy the data's inherent pattern, so we modified them. During sampling we only fuse two samples that belong to the same pattern, and we never break the numerical relationship along the 16 dimension; therefore, during CutMix we randomly replace a subset of the 24 rows, since the 24 dimension has a pattern but does not seem to carry a particularly strong numerical relationship. With these modifications we were able to train a 384-bit model successfully; a sketch follows.
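The sketch below mirrors the row-replacement branch in DatasetFolder.__getitem__ later in this post (data is the N x 24 x 16 x 2 array; pattern_len = 3000 comes from the data layout):

import random

def pattern_preserving_cutmix(data, index, pattern_len=3000):
    # Pick a partner sample at the same offset within its 3000-sample block,
    # then replace a random subset of the 24 rows only; the 16 dimension,
    # which carries a clear numerical relationship, is never broken.
    y = data[index].copy()                          # (24, 16, 2)
    group = random.randint(0, data.shape[0] // pattern_len - 1)
    index_ = group * pattern_len + index % pattern_len
    rows = random.sample(range(24), max(int(24 * random.random()), 1))
    y[rows] = data[index_][rows]                    # whole-row swap
    return y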

Figure 5: Data augmentation effect

4 Pruning and Quantization

For the quantization layer we chose simple uniform quantization. In choosing the number of quantization bits, note that the task rewards transmitting fewer bits more than it rewards extreme precision (the lowest possible NMSE), so a smaller bit width is preferable; but with too few bits the quantization error becomes too large, making the decoder harder to train and more prone to overfitting. Weighing the two, we chose B = 3.
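For intuition, a minimal round-trip through this uniform quantizer (matching the Quantization/Dequantization layers in modelDesign.py; the clamp is added here for boundary safety): with B = 3 there are 2^3 = 8 levels on [0, 1), and the reconstruction error is at most 1/2^(B+1) = 1/16.

import torch

B = 3
step = 2 ** B                                       # 8 uniform levels on [0, 1)

x = torch.rand(1000)                                # encoder outputs after sigmoid
q = torch.round(x * step - 0.5).clamp(0, step - 1)  # integer level 0..7
x_hat = (q + 0.5) / step                            # dequantize to the level midpoint
print((x - x_hat).abs().max().item())               # <= 1/16 = 1/2^(B+1)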

To train the initial model, we first built and trained a model with a 432-bit bitstream. After training, we pruned the last fully connected layer of the encoder and the first fully connected layer of the decoder to obtain a 384-bit autoencoder (3 bits × 128 codes), which we then fine-tuned. In the final stage of the competition we compressed the quantization of 6 of the 128 codes from 3 bits to 2 bits and fine-tuned further, giving the final 378-bit submission: 122 codes quantized with 3 bits and 6 codes with 2 bits.
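A sketch of the bit accounting and the FC pruning step; the centered-slice pruning mirrors the (disabled) dim_verify branch in modelTrain.py, and keeping the centered slice rather than importance-ranked rows is that code's convention:

import torch.nn as nn

# bit budget of the final submission
codes_3bit, codes_2bit = 122, 6
assert codes_3bit + codes_2bit == 128               # the 384-bit model has 128 codes * 3 bits
assert codes_3bit * 3 + codes_2bit * 2 == 378       # final submission: 378 bits

def prune_fc(fc_432, out_dim=384 // 3):
    # Shrink Linear(in, 144) (432 bits / 3) to Linear(in, 128) (384 bits / 3)
    # by keeping a centered slice of its output rows, then fine-tune.
    in_dim, long_dim = fc_432.in_features, fc_432.out_features
    start = (long_dim - out_dim) // 2
    fc_384 = nn.Linear(in_dim, out_dim)
    fc_384.weight.data = fc_432.weight.data[start:start + out_dim, :].clone()
    fc_384.bias.data = fc_432.bias.data[start:start + out_dim].clone()
    return fc_384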

Acknowledgments

We thank the organizers for providing the data, and the DataFountain platform for its support and timely responses to our questions!

Code

     modelDesign.py

# =======================================================================================================================
# =======================================================================================================================
import numpy as np
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from collections import OrderedDict
 
NUM_FEEDBACK_BITS_STARTS = 768
NUM_FEEDBACK_BITS = 384  # the PyTorch version of the submission must define this constant
channel_last = 1
CR_dim = 128
REFINEMENT = 1
 
 
class Mish(torch.nn.Module):  # alternative activation, unused below (ACT = nn.SiLU())
    def __init__(self):
        super().__init__()
 
    def forward(self, x):
        x = x * (torch.tanh(torch.nn.functional.softplus(x)))
        return x
 
 
ACT = nn.SiLU()
 
 
# =======================================================================================================================
# =======================================================================================================================
# Number to Bit Defining Function Defining
def Num2Bit(Num, B):
    Num_ = Num.type(torch.uint8)
 
    def integer2bit(integer, num_bits=B * 2):
        dtype = integer.type()
        exponent_bits = -torch.arange(-(num_bits - 1), 1).type(dtype)
        exponent_bits = exponent_bits.repeat(integer.shape + (1,))
        out = integer.unsqueeze(-1) // 2 ** exponent_bits
        return (out - (out % 1)) % 2
 
    bit = integer2bit(Num_)
    bit = (bit[:, :, B:]).reshape(-1, Num_.shape[1] * B)
    return bit.type(torch.float32)
 
 
def Bit2Num(Bit, B):
    Bit_ = Bit.type(torch.float32)
    Bit_ = torch.reshape(Bit_, [-1, int(Bit_.shape[1] / B), B])
    num = torch.zeros(Bit_[:, :, 1].shape).cuda()  # assumes a GPU, matching the training setup
    for i in range(B):
        num = num + Bit_[:, :, i] * 2 ** (B - 1 - i)
    return num
 
 
# =======================================================================================================================
# =======================================================================================================================
# Quantization and Dequantization Layers Defining
class Quantization(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, B):
        ctx.constant = B
        step = 2 ** B
        out = torch.round(x * step - 0.5)
        out = Num2Bit(out, B)
        return out
 
    @staticmethod
    def backward(ctx, grad_output):
        # return as many input gradients as there were arguments.
        # Gradients of constant arguments to forward must be None.
        # Gradient of a number is the sum of its B bits.
        b, _ = grad_output.shape
        grad_num = torch.sum(grad_output.reshape(b, -1, ctx.constant), dim=2) / ctx.constant
        return grad_num, None
 
 
class Dequantization(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, B):
        ctx.constant = B
        step = 2 ** B
        out = Bit2Num(x, B)
        out = (out + 0.5) / step
        return out
 
    @staticmethod
    def backward(ctx, grad_output):
        # return as many input gradients as there were arguments.
        # Gradients of non-Tensor arguments to forward must be None.
        # repeat the gradient of a Num for B time.
        b, c = grad_output.shape
        grad_output = grad_output.unsqueeze(2) / ctx.constant
        grad_bit = grad_output.expand(b, c, ctx.constant)
        return torch.reshape(grad_bit, (-1, c * ctx.constant)), None
 
 
class QuantizationLayer(nn.Module):
    def __init__(self, B):
        super(QuantizationLayer, self).__init__()
        self.B = B
 
    def forward(self, x):
        out = Quantization.apply(x, self.B)
        return out
 
 
class DequantizationLayer(nn.Module):
    def __init__(self, B):
        super(DequantizationLayer, self).__init__()
        self.B = B
 
    def forward(self, x):
        out = Dequantization.apply(x, self.B)
        return out
 
 
# =======================================================================================================================
# =======================================================================================================================
# Encoder and Decoder Class Defining
def autopad(k, p=None):  # kernel, padding
    # Pad to 'same'
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p
 
 
class SEBlock(nn.Module):
 
    def __init__(self, input_channels, internal_neurons):
        super(SEBlock, self).__init__()
        self.down = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons, kernel_size=1, stride=1,
                              bias=True, padding_mode='circular')
        self.up = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels, kernel_size=1, stride=1,
                            bias=True, padding_mode='circular')
 
    def forward(self, inputs):
        x = F.avg_pool2d(inputs, kernel_size=inputs.size(3))
        x = self.down(x)
        x = F.leaky_relu(x)
        x = self.up(x)
        x = torch.sigmoid(x)
        x = x.repeat(1, 1, inputs.size(2), inputs.size(3))
        return inputs * x
 
 
class Conv(nn.Module):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = ACT
 
    def forward(self, x):
        return self.act(self.bn(self.conv(x)))
 
 
class Bottleneck(nn.Module):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super(Bottleneck, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2
 
    def forward(self, x):
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
 
 
class BottleneckCSP(nn.Module):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(BottleneckCSP, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
        self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
        self.cv4 = Conv(2 * c_, c2, 1, 1)
        self.bn = nn.BatchNorm2d(2 * c_)  # applied to cat(cv2, cv3)
        self.act = nn.LeakyReLU(0.1, inplace=True)
        self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
        self.att = SEBlock(c2, c2 // 2)
 
    def forward(self, x):
        y1 = self.cv3(self.m(self.cv1(x)))
        y2 = self.cv2(x)
        return self.att(self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1)))))
 
 
class C3(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(C3, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
        self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
        # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])
        self.att = SEBlock(c2, c2 // 2)
 
    def forward(self, x):
        return self.att(self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1)))
 
 
class Focus(nn.Module):
    # Focus wh information into c-space
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Focus, self).__init__()
        self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
        # self.contract = Contract(gain=2)
 
    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
        return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
        # return self.conv(self.contract(x))
 
 
class Contract(nn.Module):
    # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
    def __init__(self, gain=2):
        super().__init__()
        self.gain = gain
 
    def forward(self, x):
        N, C, H, W = x.size()  # assert (H / s == 0) and (W / s == 0), 'Indivisible gain'
        s = self.gain
        x = x.view(N, C, H // s, s, W // s, s)  # x(1,64,40,2,40,2)
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # x(1,2,2,64,40,40)
        return x.view(N, C * s * s, H // s, W // s)  # x(1,256,40,40)
 
 
class Expand(nn.Module):
    # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
    def __init__(self, c1, c2, gain=2, k=1, s=1, p=None, g=1, act=True):
        super().__init__()
        self.gain = gain
        self.conv = Conv(c1 // 4, c2, k, s, p, g, act)
 
    def forward(self, x):
        N, C, H, W = x.size()  # assert C / s ** 2 == 0, 'Indivisible gain'
        s = self.gain
        x = x.view(N, s, s, C // s ** 2, H, W)  # x(1,2,2,16,80,80)
        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # x(1,16,80,2,80,2)
        return self.conv(x.view(N, C // s ** 2, H * s, W * s))  # x(1,16,160,160)
 
 
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=True)
 
 
class WLBlock(nn.Module):
    def __init__(self, paths, in_c, k=16, n=[1, 1], e=[1.0, 1.0], quantization=True):
 
        super(WLBlock, self).__init__()
        self.paths = paths
        self.n = n
        self.e = e
        self.k = k
        self.in_c = in_c
        for i in range(self.paths):
            self.__setattr__(str(i), nn.Sequential(OrderedDict([
                ("Conv0", Conv(self.in_c, self.k, 3)),
                ("BCSP_1", BottleneckCSP(self.k, self.k, n=self.n[i], e=self.e[i])),
                ("C3_1", C3(self.k, self.k, n=self.n[i], e=self.n[i])),
                ("Conv1", Conv(self.k, self.k, 3)),
            ])))
        self.conv1 = conv3x3(self.k * self.paths, self.k)
 
    def forward(self, x):
        outs = []
        for i in range(self.paths):
            _ = self.__getattr__(str(i))(x)
            outs.append(_)
        out = torch.cat(tuple(outs), dim=1)
        out = self.conv1(out)
        out = out + x if self.in_c == self.k else out
        return out
 
 
class Encoder(nn.Module):
    B = 3
 
    def __init__(self, feedback_bits, quantization=True):
        super(Encoder, self).__init__()
        self.feedback_bits = feedback_bits
        self.k = 256
        self.encoder1 = nn.Sequential(OrderedDict([
            ("Conv0", Conv(2, 16, 5)),
            ("BCSP_1", BottleneckCSP(16, 16, n=2, e=0.5)),
            ("C3_1", C3(16, 16, n=1, e=2.0)),
            ("Conv1", Conv(16, self.k, 3))
        ]))
        self.encoder2 = nn.Sequential(OrderedDict([
            ("Focus0", Focus(2, 16)),
            ("BCSP_1", BottleneckCSP(16, 16, n=1, e=1.0)),
            ("C3_1", C3(16, 16, n=2, e=2.0)),
            ("Expand0", Expand(16, 16)),
            ("Conv1", Conv(16, self.k, 3))
        ]))
        self.encoder3 = nn.Sequential(OrderedDict([
            ("Conv0", Conv(2, 32, 3)),
            ("WLBlock1", WLBlock(3, 32, 32, [1, 2, 3], [0.5, 1, 1.5])),
            ("WLBlock2", WLBlock(2, 32, 32, [2, 4], [1, 2])),
            ("Conv1", Conv(32, self.k, 3)),
        ]))
        self.encoder_conv = nn.Sequential(OrderedDict([
            ("conv1x1", Conv(self.k * 3, 2, 1)),
        ]))
        self.fc = nn.Linear(768, int(NUM_FEEDBACK_BITS_STARTS / self.B))
        self.dim_verify = nn.Linear(int(NUM_FEEDBACK_BITS_STARTS / self.B), int(self.feedback_bits / self.B))
 
        self.sig = nn.Sigmoid()
        self.quantize = QuantizationLayer(self.B)
        self.quantization = quantization
 
    def forward(self, x):
        if channel_last:
            x = x.permute(0, 3, 1, 2).contiguous()
        x0 = x.view(-1, 768)
        encoder1 = self.encoder1(x)
        encoder2 = self.encoder2(x)
        encoder3 = self.encoder3(x)
        out = torch.cat((encoder1, encoder2, encoder3), dim=1)
        out = self.encoder_conv(out)
        out = out.view(-1, 768) + x0
        out = self.fc(out)
        out = self.dim_verify(out)
        out = self.sig(out)
        enq_data = out
        if self.quantization == 'check':
            out = out  # debug mode: pass the unquantized code through ('check' is truthy, so test it first)
        elif self.quantization:
            out = self.quantize(out)
        else:
            out = self.fake_quantize(out)  # note: fake_quantize is not defined in this file
        return out, enq_data
 
 
class Decoder(nn.Module):
    B = 3
 
    def __init__(self, feedback_bits, quantization=True):
        super(Decoder, self).__init__()
        self.k = 64
        self.feedback_bits = feedback_bits
        self.dequantize = DequantizationLayer(self.B)
        self.dim_verify = nn.Linear(int(self.feedback_bits / self.B), int(NUM_FEEDBACK_BITS_STARTS / self.B))
        self.fc = nn.Linear(int(NUM_FEEDBACK_BITS_STARTS / self.B), 768)
        self.ende_refinement = nn.Sequential(
            nn.Linear(int(self.feedback_bits / self.B), int(self.feedback_bits / self.B)),
            nn.BatchNorm1d(int(self.feedback_bits / self.B)),
            nn.ReLU(True),
            nn.Linear(int(self.feedback_bits / self.B), int(self.feedback_bits / self.B), bias=False),
            nn.Sigmoid(),
        )
        self.decoder1 = nn.Sequential(OrderedDict([
            ("Conv0", Conv(2, 16, 3)),
            ("BCSP_1", BottleneckCSP(16, 16, n=1, e=1.0)),
            ("Conv1", Conv(16, self.k, 1)),
        ]))
        self.decoder2 = nn.Sequential(OrderedDict([
            ("Conv0", Conv(2, 32, 5)),
            ("BCSP_1", BottleneckCSP(32, 32, n=4, e=2.0)),
            ("Conv1", Conv(32, self.k, 1)),
        ]))
        self.decoder3 = nn.Sequential(OrderedDict([
            ("Conv0", Conv(2, 32, (1, 3))),
            ("BCSP_1", BottleneckCSP(32, 32, n=4, e=2.0)),
            ("Conv1", Conv(32, self.k, 1)),
        ]))
        self.decoder4 = nn.Sequential(OrderedDict([
            ("Conv0", Conv(2, 32, (3, 1))),
            ("BCSP_1", BottleneckCSP(32, 32, n=4, e=2.0)),
            ("Conv1", Conv(32, self.k, 1)),
        ]))
        self.decoder5 = nn.Sequential(OrderedDict([
            ("Focus0", Focus(2, self.k)),
            ("WLBlock1", WLBlock(3, self.k, self.k, [1, 2, 3], [0.5, 1, 1.5])),
            ("WLBlock2", WLBlock(2, self.k, self.k, [2, 4], [1, 2])),
            ("Expand0", Expand(self.k, self.k)),
            ("Conv1", Conv(self.k, self.k, 1)),
        ]))
        self.decoder6 = nn.Sequential(OrderedDict([
            ("Conv0", Conv(2, 32, (3, 5))),
            ("BCSP_1", BottleneckCSP(32, 32, n=4, e=2.0)),
            ("Conv1", Conv(32, self.k, 5)),
        ]))
        self.decoder7 = nn.Sequential(OrderedDict([
            ("Conv0", Conv(2, 32, (5, 3))),
            ("BCSP_1", BottleneckCSP(32, 32, n=4, e=2.0)),
            ("Conv1", Conv(32, self.k, 3)),
        ]))
        self.decoder8 = nn.Sequential(OrderedDict([
            ("Focus0", Focus(2, self.k, 5)),
            ("WLBlock1", WLBlock(2, self.k, self.k, [1, 2], [0.5, 1])),
            ("WLBlock2", WLBlock(2, self.k, self.k, [1, 2], [1, 0.5])),
            ("Expand0", Expand(self.k, self.k)),
            ("Conv1", Conv(self.k, self.k, 5)),
        ]))
        if REFINEMENT:
            self.refinemodel = nn.Sequential(OrderedDict([
                ("Conv0", Conv(2, 64, 3)),
                ("WLBlock1", WLBlock(3, 64, 64, [1, 2, 3], [0.5, 1, 1.5])),
                ("WLBlock2", WLBlock(2, 64, 64, [2, 4], [1, 2])),
                ("WLBlock3", WLBlock(2, 64, 64, [2, 4], [1, 2])),
                ("WLBlock4", WLBlock(2, 64, 64, [1, 3], [1, 2])),
                ("Conv1", Conv(64, 2, 3)),
            ]))
        self.decoder_conv = conv3x3(self.k * 8, 2)
        self.sig = nn.Sigmoid()
        self.quantization = quantization
 
    def forward(self, x):
        if self.quantization:
            out = self.dequantize(x)
        else:
            out = x
        out = out.view(-1, int(self.feedback_bits / self.B))
        out_error = self.ende_refinement(out)
        out = out + out_error - 0.5
        deq_data = out
        out = self.dim_verify(out)
 
        out = self.sig(self.fc(out))
        out = out.view(-1, 2, 24, 16)
        out0 = out
        out1 = self.decoder1(out)
        out2 = self.decoder2(out)
        out3 = self.decoder3(out)
        out4 = self.decoder4(out)
        out5 = self.decoder5(out)
        out6 = self.decoder6(out)
        out7 = self.decoder7(out)
        out8 = self.decoder8(out)
        out = torch.cat((out1, out2, out3, out4, out5, out6, out7, out8), dim=1)
        out = self.decoder_conv(out) + out0
        out = self.sig(out)
        if REFINEMENT:
            out = self.sig(self.refinemodel(out)) - 0.5 + out
        if channel_last:
            out = out.permute(0, 2, 3, 1)
        return out, deq_data
 
 
class AutoEncoder(nn.Module):
    def __init__(self, feedback_bits):
        super(AutoEncoder, self).__init__()
        self.encoder = Encoder(feedback_bits)
        self.decoder = Decoder(feedback_bits)
 
    def forward(self, x):
        feature, enq_data = self.encoder(x)
        out, deq_data = self.decoder(feature)
        return out, feature, enq_data, deq_data
 
 
# =======================================================================================================================
# =======================================================================================================================
# NMSE Function Defining
def NMSE(x, x_hat):
    x_real = np.reshape(x[:, :, :, 0], (len(x), -1))
    x_imag = np.reshape(x[:, :, :, 1], (len(x), -1))
    x_hat_real = np.reshape(x_hat[:, :, :, 0], (len(x_hat), -1))
    x_hat_imag = np.reshape(x_hat[:, :, :, 1], (len(x_hat), -1))
    x_C = x_real - 0.5 + 1j * (x_imag - 0.5)
    x_hat_C = x_hat_real - 0.5 + 1j * (x_hat_imag - 0.5)
    power = np.sum(abs(x_C) ** 2, axis=1)
    mse = np.sum(abs(x_C - x_hat_C) ** 2, axis=1)
    nmse = np.mean(mse / power)
    return nmse
 
 
def Score(NMSE):
    score = 1 - NMSE
    return score
 
 
def NMSE_cuda(x, x_hat):
    x_real = x[:, 0, :, :].view(len(x), -1) - 0.5
    x_imag = x[:, 1, :, :].view(len(x), -1) - 0.5
    x_hat_real = x_hat[:, 0, :, :].view(len(x_hat), -1) - 0.5
    x_hat_imag = x_hat[:, 1, :, :].view(len(x_hat), -1) - 0.5
    power = torch.sum(x_real ** 2 + x_imag ** 2, axis=1)
    mse = torch.sum((x_real - x_hat_real) ** 2 + (x_imag - x_hat_imag) ** 2, axis=1)
    nmse = mse / power
    return nmse
 
 
class NMSELoss(nn.Module):
    def __init__(self, reduction='sum'):
        super(NMSELoss, self).__init__()
        self.reduction = reduction
 
    def forward(self, x_hat, x):
        nmse = NMSE_cuda(x, x_hat)
        if self.reduction == 'mean':
            nmse = torch.mean(nmse)
        else:
            nmse = torch.sum(nmse)
        return nmse
 
 
# =======================================================================================================================
# =======================================================================================================================
import random
 
 
# Data Loader Class Defining
class DatasetFolder(Dataset):
    def __init__(self, matData, phase='val'):
        self.matdata = matData
        self.phase = phase
 
    def __getitem__(self, index):
        y = self.matdata[index]
        if self.phase == 'train' and random.random() < -0.5:  # always False: this flip is disabled
            y = y[::-1, :, :].copy()
        if self.phase == 'train' and random.random() < 0.5:
            y = y[:, ::-1, :].copy()
        if self.phase == 'train' and random.random() < 0.5:
            y = 1 - self.matdata[index]  # the data exhibit a quasi-orthogonal relationship
        if self.phase == 'train' and random.random() < 0.5:
            y = y[:, :, ::-1].copy()  # swap real/imag channels; at different times the real and imaginary parts are partly equal
        if self.phase == 'train' and random.random() < 0.5:
            index_ = random.randint(0, self.matdata.shape[0] // 3000 - 1) * 3000 + index % 3000
            p = random.random()
            rows = max(int(24 * p), 1)
            _rows = [i for i in range(24)]
            random.shuffle(_rows)
            _rows = _rows[:rows]
            if random.random() < 0.7:
                y[_rows] = self.matdata[index_][_rows]  # merge rows across sample groups: keeps per-sample traits and reduces reliance on the 24 dimension
            else:
                y = (1 - p * 0.2) * y + (p * 0.2) * self.matdata[index_]  # mild numerical perturbation that keeps per-sample traits
        return y
 
    def __len__(self):
        return self.matdata.shape[0]      

 modelTrain.py

#=======================================================================================================================
#=======================================================================================================================
import numpy as np
import torch
from modelDesign import AutoEncoder,DatasetFolder,NUM_FEEDBACK_BITS,NUM_FEEDBACK_BITS_STARTS,NMSELoss,channel_last #*
import os
import torch.nn as nn
import scipy.io as sio
import random
from torch.cuda.amp import autocast, GradScaler
def NMSE_cuda1(x, x_hat):
    x_real = x[:, :, :, 0].view(len(x),-1) - 0.5
    x_imag = x[:, :, :, 1].view(len(x),-1) - 0.5
    x_hat_real = x_hat[:, :, :, 0].view(len(x_hat), -1) - 0.5
    x_hat_imag = x_hat[:, :, :, 1].view(len(x_hat), -1) - 0.5
    power = torch.sum(x_real**2 + x_imag**2, axis=1)
    mse = torch.sum((x_real-x_hat_real)**2 + (x_imag-x_hat_imag)**2, axis=1)
    nmse = mse/power
    return nmse
    
class NMSELoss1(nn.Module):
    def __init__(self, reduction='sum'):
        super(NMSELoss1, self).__init__()
        self.reduction = reduction
 
    def forward(self, x_hat, x):
        nmse = NMSE_cuda1(x, x_hat)
        if self.reduction == 'mean':
            nmse = torch.mean(nmse) 
        else:
            nmse = torch.sum(nmse)
        return nmse
#=======================================================================================================================
#=======================================================================================================================
# Parameters Setting for Data
CHANNEL_SHAPE_DIM1 = 24
CHANNEL_SHAPE_DIM2 = 16
CHANNEL_SHAPE_DIM3 = 2
# Parameters Setting for Training
BATCH_SIZE = 64
EPOCHS = 1000
LEARNING_RATE = 1e-5
PRINT_FREQ = 100
#NUM_FEEDBACK_BITS =NUM_FEEDBACK_BITS_3
torch.manual_seed(42)
random.seed(42)
#=======================================================================================================================
#=======================================================================================================================
def load_pretrained_weights(model,model_path):
    encoder_pretrained = torch.load(model_path)['state_dict']
    model_dict = model.state_dict()
    #pretrained_weights ={k:v for k,v in encoder_pretrained.items() if (k in model_dict and 'dim_verify' not in k and 'ende_refinement' not in k and 'fc' not in k)}
    pretrained_weights ={k:v for k,v in encoder_pretrained.items() if (k in model_dict )}
    # prune dim_verify layer
    # disabled branch (change the `0` to `1` to prune the dim_verify layer when bit widths differ)
    if 0 and NUM_FEEDBACK_BITS != NUM_FEEDBACK_BITS_STARTS:
        w = encoder_pretrained['dim_verify.weight']
        b = encoder_pretrained['dim_verify.bias']
        if  model_dict['dim_verify.weight'].shape[0] != encoder_pretrained['dim_verify.weight'].shape[0]:
            dim = -1
            bits_num =model_dict['dim_verify.weight'].shape[0]
            long = encoder_pretrained['dim_verify.weight'].shape[0]
        else: 
            dim = 0
            bits_num =model_dict['dim_verify.weight'].shape[1]
            long = encoder_pretrained['dim_verify.weight'].shape[1]
         
        #importance = abs(w).sum(dim)
        #sorted_index = torch.argsort(-1*importance) # descend
        start = (long -bits_num)//2
        end = bits_num + (long - bits_num)//2
        if dim == -1:
            pretrained_weights['dim_verify.weight'] = w[start:end,:]
        else:
            pretrained_weights['dim_verify.weight'] = w[:,start:end]
    model_dict.update(pretrained_weights)
    model.load_state_dict(model_dict)
    return model
# Model Constructing
autoencoderModel = AutoEncoder(NUM_FEEDBACK_BITS)
# model_path = './modelSubmit/encoder.pth.tar'
# autoencoderModel.encoder =load_pretrained_weights(autoencoderModel.encoder,model_path)
# model_path = './modelSubmitTeacher/decoder.pth.tar'
# autoencoderModel.decoder =load_pretrained_weights(autoencoderModel.decoder,model_path)
model_path = './modelSubmit/encoder.pth.tar'   
autoencoderModel.encoder.load_state_dict(torch.load(model_path)['state_dict'])
model_path = './modelSubmit/decoder.pth.tar'
autoencoderModel.decoder.load_state_dict(torch.load(model_path)['state_dict'])
 
 
#=======================================================================================================================
#=======================================================================================================================
# Data Loading
mat = sio.loadmat('channelData/H_4T4R.mat')
data = mat['H_4T4R']
data = data.astype('float32')
data = np.reshape(data, (-1, CHANNEL_SHAPE_DIM1, CHANNEL_SHAPE_DIM2, CHANNEL_SHAPE_DIM3))
if not channel_last:
  data = np.transpose(data, (0, 3, 1, 2))
#random.shuffle(data)
split = int(data.shape[0] * 0.95)
data_train0, data_test = data[:split], data[split:]
random.shuffle(data_train0)
split = int(data_train0.shape[0]*0.95)
data_train, data_val = data_train0[:split],data_train0[split:]
train_dataset = DatasetFolder(data_train0,'train')
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
val_dataset = DatasetFolder(data_val,'val')
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)
test_dataset = DatasetFolder(data_test,'val')
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)
#=======================================================================================================================
#=======================================================================================================================
 
#autoencoderModel = autoencoderModel.cuda()
autoencoderModel = torch.nn.DataParallel(autoencoderModel.cuda())
ctl = NMSELoss1(reduction='mean') if channel_last else NMSELoss(reduction='mean')
criterion = ctl #nn.MSELoss()
criterion_test = ctl
feature_criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(autoencoderModel.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=1, eta_min=1e-9, last_epoch=-1)
#=======================================================================================================================
#=======================================================================================================================
# Model Training and Saving
bestLoss = 0.105
valLoss = 1e-5
for epoch in range(EPOCHS):
    scaler = GradScaler()
    print('lr:',optimizer.param_groups[0]['lr'])
    autoencoderModel.train()
    for i, autoencoderInput in enumerate(train_loader):
        autoencoderInput = autoencoderInput.cuda()
        with autocast():
            autoencoderOutput,_, enq, deq = autoencoderModel(autoencoderInput)
            loss1 = criterion(autoencoderOutput, autoencoderInput)
            loss2 = feature_criterion(enq, deq)
            loss = loss1 + 0*loss2  # loss2 (enq/deq consistency) is weighted 0, i.e. disabled in this run
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        if i % PRINT_FREQ == 0:
            print('Epoch: [{0}][{1}/{2}]\tLoss {loss:.4f}\tLoss_nmse {loss_nmse:.4f}\tLoss_ende {loss_q:.4f}'.format(
                epoch, i, len(train_loader), loss=loss.item(), loss_nmse=loss1.item(), loss_q=loss2.item()))
        # if (i + 1) % (4 * PRINT_FREQ) == 0:
        #     break
 
    # Model Evaluating
    autoencoderModel.eval()
    totalLoss = 0
    hist =0
    with torch.no_grad():
        for i, autoencoderInput in enumerate(val_loader):
            autoencoderInput = autoencoderInput.cuda()
 
            autoencoderOutput, feature, enq, deq  = autoencoderModel(autoencoderInput)
            hist = hist+feature.sum(0)/autoencoderInput.shape[0]
            totalLoss += criterion_test(autoencoderOutput, autoencoderInput).item()*autoencoderInput.shape[0]
        averageLoss = totalLoss / len(val_dataset)  # this loop runs over the validation split
        loss2 = feature_criterion(enq, deq)
        print('==random split test step==')
        print(np.std(hist.cpu().numpy()))
        print(averageLoss,loss2.item())
    valavgloss = averageLoss
    totalLoss = 0
    hist =0
    with torch.no_grad():
        for i, autoencoderInput in enumerate(test_loader):
            autoencoderInput = autoencoderInput.cuda()
            
            autoencoderOutput, feature, enq, deq  = autoencoderModel(autoencoderInput)
            hist = hist+feature.sum(0)/autoencoderInput.shape[0]
            totalLoss += criterion_test(autoencoderOutput, autoencoderInput).item()*autoencoderInput.shape[0]
        averageLoss = totalLoss / len(test_dataset)
        loss2 = feature_criterion(enq, deq)
        print('==last split test step==')
 
        print(np.std(hist.cpu().numpy()))
        print(averageLoss,loss2.item())
        if averageLoss < bestLoss:
            # Model saving
            # Encoder Saving
            torch.save({'state_dict': autoencoderModel.module.encoder.state_dict(), }, './modelSubmit/encoder.pth.tar')
            # Decoder Saving
            torch.save({'state_dict': autoencoderModel.module.decoder.state_dict(), }, './modelSubmit/decoder.pth.tar')
            print("Model saved,avgloss:",averageLoss)
            bestLoss = averageLoss
            valLoss = valavgloss
        print('==show best==')
        print('valloss:', valLoss, 'testloss:',bestLoss)
        if epoch > 0:  # step the LR scheduler after the first epoch
            scheduler.step()
    #break
#=======================================================================================================================
#=======================================================================================================================      
