天天看點

JoJoGAN 實踐

JoJoGAN: One Shot Face Stylization. 隻用一張人臉圖檔,就能學習其風格,然後遷移到其他圖檔。訓練時長隻用 1~2 min 即可。

  • code
  • paper

效果:

主流程:

JoJoGAN 實踐

本文分享了個人在本地環境(非 colab)實踐 JoJoGAN 的整個過程。你也可以依照本文上手訓練自己喜歡的風格。

準備環境

安裝:

  • Anaconda
  • PyTorch
conda create -n torch python=3.9 -y
conda activate torch

conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch -y
           

檢查:

$ python - <<EOF
import torch, torchvision
print(torch.__version__, torch.cuda.is_available())
EOF
1.10.1 True
           

準備代碼

git clone https://github.com/mchong6/JoJoGAN.git
cd JoJoGAN

pip install tqdm gdown matplotlib scipy opencv-python dlib lpips wandb

# Ninja is required to load C++ extensions
wget https://github.com/ninja-build/ninja/releases/download/v1.10.2/ninja-linux.zip
sudo unzip ninja-linux.zip -d /usr/local/bin/
sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force
           

然後,将本文提供的幾個

*.py

放進

JoJoGAN

目錄,從這裡擷取: https://github.com/ikuokuo/start-deep-learning/tree/master/practice/JoJoGAN 。

  • download_models.py

    : 擷取模型
  • generate_faces.py

    : 生成人臉
  • stylize.py

    : 風格化
  • train.py

    : 訓練

之後,于訓練流程一節,會結合代碼,講述下 JoJoGAN 的工作流程。其他些

*.py

隻提下用法,實作就不多說了。

擷取模型

python download_models.py

擷取模型,如下:

models/
├── arcane_caitlyn_preserve_color.pt
├── arcane_caitlyn.pt
├── arcane_jinx_preserve_color.pt
├── arcane_jinx.pt
├── arcane_multi_preserve_color.pt
├── arcane_multi.pt
├── art.pt
├── disney_preserve_color.pt
├── disney.pt
├── dlibshape_predictor_68_face_landmarks.dat
├── e4e_ffhq_encode.pt
├── jojo_preserve_color.pt
├── jojo.pt
├── jojo_yasuho_preserve_color.pt
├── jojo_yasuho.pt
├── restyle_psp_ffhq_encode.pt
├── stylegan2-ffhq-config-f.pt
├── supergirl_preserve_color.pt
└── supergirl.pt
           

生成人臉

用 StyleGAN2 預訓練模型随機生成人臉,用于測試:

python generate_faces.py -n 5 -s 2000 -o input
           

使用預訓練風格

JoJoGAN 給了 8 個預訓練模型,可以一并體驗,與文首的效果圖一樣:

# 預覽 JoJoGAN 所有預訓練模型 風格化某圖檔(test_input/iu.jpeg)的效果
python stylize.py -i test_input/iu.jpeg -s all --save-all --show-all

# 使用 JoJoGAN 所有預訓練模型 風格化所有生成的測試人臉(input/*)
find ./input -type f -print0 | xargs -0 -i python stylize.py -i {} -s all --save-all
           

訓練自己的風格

首先,準備一張風格圖:

JoJoGAN 實踐

之後,開始訓練:

python train.py -n yinshi -i style_images/yinshi.jpeg --alpha 1.0 --num_iter 500 --latent_dim 512 --use_wandb --log_interval 50
           

--use_wandb

時,可檢視訓練日志:

JoJoGAN 實踐

最後,測試效果:

python stylize.py -i input/girl.jpeg --save-all --show-all --test_style yinshi --test_ckpt output/yinshi.pt --test_ref output/yinshi/style_images_aligned/yinshi.png
           
JoJoGAN 實踐

訓練工作流程

準備風格圖檔,轉為訓練資料

将風格圖檔裡的人臉裁減對齊:

# dlib 預測人臉特征點,再裁減對齊
from util import align_face
style_aligned = align_face(img_path)
           

将風格圖檔 GAN Inversion 逆映射回預訓練模型的隐向量空間(Latent Space):

name, _ = os.path.splitext(os.path.basename(img_path))
style_code_path = os.path.join(latent_dir, f'{name}.pt')

# e4e FFHQ encoder (pSp) > GAN inversion,得到 latent
from e4e_projection import projection
latent = projection(style_aligned, style_code_path, device)
           

載入 StyleGAN2 模型,訓練微調

載入預訓練模型:

latent_dim = 512

# 加載預訓練模型
original_generator = Generator(1024, latent_dim, 8, 2).to(device)
ckpt = torch.load("models/stylegan2-ffhq-config-f.pt", map_location=lambda storage, loc: storage)
original_generator.load_state_dict(ckpt["g_ema"], strict=False)

# 準備微調的模型
generator = deepcopy(original_generator)
           

訓練可調參數:

# 控制風格強度 [0, 1]
alpha = 1.0
alpha = 1-alpha

# 是否保留原圖像色彩
preserve_color = True

# 訓練疊代次數(最好 500,Adam 學習率是基于 500 次疊代調優的)
num_iter = 500

# 風格圖檔 targets 及 latents
targets = ..
latents = ..
           

進行訓練,拟合隐空間。最後儲存:

# 準備 LPIPS 計算 loss
lpips_fn = lpips.LPIPS(net='vgg').to(device)

# 準備優化器
g_optim = torch.optim.Adam(generator.parameters(), lr=2e-3, betas=(0, 0.99))

# 哪些層用于交換,用于生成風格化圖檔
if preserve_color:
    id_swap = [7,9,11,15,16,17]
else:
    id_swap = list(range(7, generator.n_latent))

# 訓練疊代
for idx in tqdm(range(num_iter)):
    # 交換層混合風格,并加噪聲
    mean_w = generator.get_latent(torch.randn([latents.size(0), latent_dim])
        .to(device)).unsqueeze(1).repeat(1, generator.n_latent, 1)
    in_latent = latents.clone()
    in_latent[:, id_swap] = alpha*latents[:, id_swap] + (1-alpha)*mean_w[:, id_swap]

    # 以 latent 風格化圖檔,與目标風格對比
    img = generator(in_latent, input_is_latent=True)
    loss = lpips_fn(F.interpolate(img, size=(256,256), mode='area'),
        F.interpolate(targets, size=(256,256), mode='area')).mean()

    # 優化
    g_optim.zero_grad()
    loss.backward()
    g_optim.step()

# 儲存權重,完成
torch.save({"g": generator.state_dict()}, save_path)
           

結語

JoJoGAN 實踐下來效果不錯。使用本文給到的代碼,更容易上手訓練自己喜歡的風格,值得試試。

繼續閱讀