How to Colorize Pokémon Line Art with a GAN
阿丽66
Posted on 2023-01-14 11:50:51
In a previous demo, we used a conditional GAN to generate images of handwritten digits. So what else can a neural network do besides generating digit images?
In this case study, we use a neural network to colorize Pokémon line art.
Step 1: Import the required libraries
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
import numpy as np
import pandas as pd
import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
Training the colorization model needs a fair amount of GPU memory. To make sure it runs smoothly on an RTX 2070, we cap GPU memory usage at 90% and enable eager execution with that configuration, avoiding out-of-memory errors.
config = tf.compat.v1.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.9
# Pass the config when enabling eager execution; a tf.compat.v1.Session created
# after eager execution is enabled would raise an error, and its config would
# not apply to eager ops anyway.
tf.compat.v1.enable_eager_execution(config=config)
Define the constants we will use.
BUFFER_SIZE = 400
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
PATH = 'dataset/'
OUTPUT_CHANNELS = 3
LAMBDA = 100
EPOCHS = 10
Step 2: Define helper functions
The image-loading function reads an image file through TensorFlow's I/O API and returns it as tensors for later use.
def load(image_file):
    image = tf.io.read_file(image_file)
    image = tf.image.decode_jpeg(image)

    # Each file holds two pictures side by side; split it down the middle:
    # the left half is the line art (input), the right half the colored target.
    w = tf.shape(image)[1]
    w = w // 2
    input_image = image[:, :w, :]
    real_image = image[:, w:, :]

    input_image = tf.cast(input_image, tf.float32)
    real_image = tf.cast(real_image, tf.float32)
    return input_image, real_image
A helper to convert tensors to numpy arrays
During training we visualize some results and intermediate images. TensorFlow tensors cannot be passed to matplotlib directly, so we need a small function that converts a tensor into a numpy array.
def tensor_to_array(tensor1):
    return tensor1.numpy()
Step 3: Visualize the data
Let's first look at what the training data looks like.
Each image is split into two parts: the left half is the line art, which we use as the input, and the right half is the colored picture, which we use as the training target.
We load one image with the load function defined above:
input, real = load(PATH+'train/114.jpg')
plt.figure()
plt.imshow(tensor_to_array(input)/255.0)
plt.figure()
plt.imshow(tensor_to_array(real)/255.0)
Step 4: Data augmentation
Because we do not have much training data, we use data augmentation to increase the number of samples, so that even a small dataset can give decent results.
We apply the following augmentation scheme (the functions are defined below):
- resizing, which scales the input images to a specified size
- random cropping
- normalization
- random left-right flipping
def resize(input_image, real_image, height, width):
    input_image = tf.image.resize(input_image, [height, width], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    real_image = tf.image.resize(real_image, [height, width], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    return input_image, real_image
def random_crop(input_image, real_image):
    # Stack the two images so the same random crop is applied to both.
    stacked_image = tf.stack([input_image, real_image], axis=0)
    cropped_image = tf.image.random_crop(stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
    return cropped_image[0], cropped_image[1]

Normalization scales pixel values from [0, 255] to [-1, 1], matching the tanh activation of the generator's output layer:

def normalize(input_image, real_image):
    input_image = (input_image / 127.5) - 1
    real_image = (real_image / 127.5) - 1
    return input_image, real_image
We wrap the augmentation steps above into a single function, in which the left-right flip happens at random:
@tf.function()
def random_jitter(input_image, real_image):
    # Enlarge to 286x286, then randomly crop back to 256x256.
    input_image, real_image = resize(input_image, real_image, 286, 286)
    input_image, real_image = random_crop(input_image, real_image)

    # Flip both images left-right half of the time.
    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_left_right(input_image)
        real_image = tf.image.flip_left_right(real_image)

    return input_image, real_image
The augmentation in action:
plt.figure(figsize=(6, 6))
for i in range(4):
    input_image, real_image = random_jitter(input, real)
    plt.subplot(2, 2, i+1)
    plt.imshow(tensor_to_array(input_image)/255.0)
    plt.axis('off')
plt.show()
Step 5: Prepare the training data
Define the loading functions for the training and test data:
def load_image_train(image_file):
    input_image, real_image = load(image_file)
    input_image, real_image = random_jitter(input_image, real_image)
    input_image, real_image = normalize(input_image, real_image)
    return input_image, real_image

def load_image_test(image_file):
    input_image, real_image = load(image_file)
    input_image, real_image = resize(input_image, real_image, IMG_HEIGHT, IMG_WIDTH)
    input_image, real_image = normalize(input_image, real_image)
    return input_image, real_image
Using TensorFlow's Dataset API, we load the training and test images and define the dataset objects:
train_dataset = tf.data.Dataset.list_files(PATH+'train/*.jpg')
train_dataset = train_dataset.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.cache().shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)

test_dataset = tf.data.Dataset.list_files(PATH+'test/*.jpg')
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(BATCH_SIZE)
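Before moving on, it helps to pull a single batch and confirm the shapes (a quick sanity check, not part of the original post):

for example_input, example_target in train_dataset.take(1):
    # Both halves come back as (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3).
    print(example_input.shape, example_target.shape)  # (1, 256, 256, 3) for both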
Step 6: Define the model
We train the colorization with a GAN. Compared with the conditional GAN that generated handwritten digit images in the previous demo, this GAN model is considerably more complex.
Let's first look at the overall structure of the generator and the discriminator.
The generator
The generator uses the basic U-Net architecture. Each block in the encoding stage is Conv -> BatchNorm -> LeakyReLU; each block in the decoding stage is transposed Conv -> BatchNorm -> (Dropout in the first three blocks) -> ReLU. The output of each encoder block is also connected to the corresponding decoder block; see U-Net's skip connections for details.
Define the encoder block:
def downsample(filters, size, apply_batchnorm=True):
    initializer = tf.random_normal_initializer(0., 0.02)

    result = tf.keras.Sequential()
    result.add(tf.keras.layers.Conv2D(filters, size, strides=2, padding='same', kernel_initializer=initializer, use_bias=False))
    if apply_batchnorm:
        result.add(tf.keras.layers.BatchNormalization())
    result.add(tf.keras.layers.LeakyReLU())
    return result

down_model = downsample(3, 4)
Define the decoder block:
def upsample(filters, size, apply_dropout=False):
    initializer = tf.random_normal_initializer(0., 0.02)

    result = tf.keras.Sequential()
    result.add(tf.keras.layers.Conv2DTranspose(filters, size, strides=2, padding='same', kernel_initializer=initializer, use_bias=False))
    result.add(tf.keras.layers.BatchNormalization())
    if apply_dropout:
        result.add(tf.keras.layers.Dropout(0.5))
    result.add(tf.keras.layers.ReLU())
    return result

up_model = upsample(3, 4)
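To get a feel for what these blocks do to tensor shapes, we can push a dummy image through the two sample blocks defined above (a minimal sketch; the numbers assume a 256x256 input):

sample = tf.random.normal([1, 256, 256, 3])
down_result = down_model(sample)    # stride-2 conv halves H and W: (1, 128, 128, 3)
up_result = up_model(down_result)   # stride-2 transposed conv doubles them back: (1, 256, 256, 3)
print(down_result.shape, up_result.shape)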
Define the generator model:
def Generator():
    down_stack = [
        downsample(64, 4, apply_batchnorm=False),  # (bs, 128, 128, 64)
        downsample(128, 4),  # (bs, 64, 64, 128)
        downsample(256, 4),  # (bs, 32, 32, 256)
        downsample(512, 4),  # (bs, 16, 16, 512)
        downsample(512, 4),  # (bs, 8, 8, 512)
        downsample(512, 4),  # (bs, 4, 4, 512)
        downsample(512, 4),  # (bs, 2, 2, 512)
        downsample(512, 4),  # (bs, 1, 1, 512)
    ]

    up_stack = [
        upsample(512, 4, apply_dropout=True),  # (bs, 2, 2, 1024)
        upsample(512, 4, apply_dropout=True),  # (bs, 4, 4, 1024)
        upsample(512, 4, apply_dropout=True),  # (bs, 8, 8, 1024)
        upsample(512, 4),  # (bs, 16, 16, 1024)
        upsample(256, 4),  # (bs, 32, 32, 512)
        upsample(128, 4),  # (bs, 64, 64, 256)
        upsample(64, 4),  # (bs, 128, 128, 128)
    ]

    initializer = tf.random_normal_initializer(0., 0.02)
    last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                           strides=2,
                                           padding='same',
                                           kernel_initializer=initializer,
                                           activation='tanh')  # (bs, 256, 256, 3)

    concat = tf.keras.layers.Concatenate()
    inputs = tf.keras.layers.Input(shape=[None, None, 3])
    x = inputs

    # Encoder: apply each downsampling block and keep the intermediate
    # outputs for the skip connections.
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)
    skips = reversed(skips[:-1])

    # Decoder: upsample and concatenate with the matching encoder output.
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = concat([x, skip])

    x = last(x)
    return tf.keras.Model(inputs=inputs, outputs=x)

generator = Generator()
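As a quick check (a sketch with a hypothetical random input, not from the original post), the generator should map a 256x256 line-art image to a 256x256 three-channel image:

sample_line_art = tf.random.normal([1, 256, 256, 3])
sample_colored = generator(sample_line_art, training=False)
print(sample_colored.shape)  # (1, 256, 256, 3), values in [-1, 1] from the tanh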
The discriminator
For the discriminator we use a PatchGAN, also known as a Markovian discriminator. Traditional CNN-based classifiers usually attach a fully connected layer at the end and output the verdict from it. A PatchGAN is different: it is built entirely from convolutional layers, and its final output is an NxN matrix whose mean is taken as the real/fake score. Intuitively, each element of that output matrix corresponds to a receptive field in the original image, i.e. a patch of the input, which is why this kind of GAN is called a PatchGAN.
Each block in the PatchGAN consists of Conv -> BatchNorm -> LeakyReLU.
In our model, the final output has shape (batch size, 30, 30, 1), where the 1 is the number of channels.
Each element of the 30x30 output corresponds to a 70x70 region of the original image. The detailed structure can be found in the pix2pix paper.
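We can verify the 70x70 figure with standard receptive-field arithmetic over the discriminator's five convolution layers (a back-of-the-envelope sketch; the (kernel, stride) list below mirrors the layers defined next):

# Walk backwards from one output unit: rf = rf * stride + (kernel - stride).
layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]  # (kernel, stride), first to last
rf = 1
for kernel, stride in reversed(layers):
    rf = rf * stride + (kernel - stride)
print(rf)  # 70: each of the 30x30 outputs sees a 70x70 patch of the input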
def Discriminator():
    initializer = tf.random_normal_initializer(0., 0.02)

    inp = tf.keras.layers.Input(shape=[None, None, 3], name='input_image')
    tar = tf.keras.layers.Input(shape=[None, None, 3], name='target_image')

    # Concatenate the line art with the (real or generated) colored image.
    x = tf.keras.layers.concatenate([inp, tar])  # (batch size, 256, 256, channels*2)

    down1 = downsample(64, 4, False)(x)  # (batch size, 128, 128, 64)
    down2 = downsample(128, 4)(down1)    # (batch size, 64, 64, 128)
    down3 = downsample(256, 4)(down2)    # (batch size, 32, 32, 256)

    zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)  # (batch size, 34, 34, 256)
    conv = tf.keras.layers.Conv2D(512, 4, strides=1, kernel_initializer=initializer, use_bias=False)(zero_pad1)  # (batch size, 31, 31, 512)
    batchnorm1 = tf.keras.layers.BatchNormalization()(conv)
    leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)

    zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (batch size, 33, 33, 512)
    last = tf.keras.layers.Conv2D(1, 4, strides=1, kernel_initializer=initializer)(zero_pad2)  # (batch size, 30, 30, 1)

    return tf.keras.Model(inputs=[inp, tar], outputs=last)

discriminator = Discriminator()
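And a quick shape check (a sketch with hypothetical random tensors): the PatchGAN output should be a 30x30 grid of patch scores.

sample_input = tf.random.normal([1, 256, 256, 3])
sample_target = tf.random.normal([1, 256, 256, 3])
patch_scores = discriminator([sample_input, sample_target], training=False)
print(patch_scores.shape)  # (1, 30, 30, 1)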
Step 7: Define the loss functions and optimizers
Both losses are built on binary cross-entropy, computed from the discriminator's raw logits:

loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

The discriminator loss scores real images against an all-ones label map and generated images against an all-zeros one:
def discriminator_loss(disc_real_output, disc_generated_output):
    real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
    total_disc_loss = real_loss + generated_loss
    return total_disc_loss
The generator loss combines the adversarial term with an L1 distance to the target image, weighted by LAMBDA:

def generator_loss(disc_generated_output, gen_output, target):
    gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    # The L1 term keeps the generated image close to the target colors.
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    total_gen_loss = gan_loss + (LAMBDA * l1_loss)
    return total_gen_loss
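A small sanity check on the losses (an illustrative sketch with made-up logits, not part of the original post): a discriminator that confidently scores real patches as real and generated ones as fake incurs a small loss, while the generator is penalized both for being caught and, far more heavily through LAMBDA, for straying from the target colors.

real_logits = tf.ones([1, 30, 30, 1]) * 3.0    # confident "real" scores
fake_logits = tf.ones([1, 30, 30, 1]) * -3.0   # confident "fake" scores
print(tensor_to_array(discriminator_loss(real_logits, fake_logits)))  # ~0.1

# A perfect reconstruction zeroes the L1 term, leaving only the adversarial part.
target = tf.zeros([1, 256, 256, 3])
print(tensor_to_array(generator_loss(fake_logits, target, target)))   # ~3.05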
Both networks are trained with the Adam optimizer:

generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
Step 8: Define checkpointing
Since training takes a long time, we save the intermediate training state so it can be loaded later to resume training.
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)
If results from an earlier run were saved, we load them and use the restored model to generate outputs for our test data.
def generate_images(model, test_input, tar):
    prediction = model(test_input, training=True)
    plt.figure(figsize=(15, 15))

    display_list = [test_input[0], tar[0], prediction[0]]
    title = ['Input', 'Target', 'Predicted']

    for i in range(3):
        plt.subplot(1, 3, i+1)
        plt.title(title[i])
        # Images are normalized to [-1, 1]; map back to [0, 1] for display.
        plt.imshow(tensor_to_array(display_list[i]) * 0.5 + 0.5)
        plt.axis('off')
    plt.show()
ckpt_manager = tf.train.CheckpointManager(checkpoint, "./", max_to_keep=2)
if ckpt_manager.latest_checkpoint:
    checkpoint.restore(ckpt_manager.latest_checkpoint)

for inp, tar in test_dataset.take(20):
    generate_images(generator, inp, tar)
Step 9: Training
During training, we output the prediction for the first test image after every epoch, so you can watch how each epoch changes the result and share in the fun.
Every 20 epochs we save a checkpoint.
@tf.function
def train_step(input_image, target):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)

        disc_real_output = discriminator([input_image, target], training=True)
        disc_generated_output = discriminator([input_image, gen_output], training=True)

        gen_loss = generator_loss(disc_generated_output, gen_output, target)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    generator_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))
def fit(train_ds, epochs, test_ds):
    for epoch in range(epochs):
        start = time.time()

        for input_image, target in train_ds:
            train_step(input_image, target)

        clear_output(wait=True)
        # Show the current prediction for one test image after each epoch.
        for example_input, example_target in test_ds.take(1):
            generate_images(generator, example_input, example_target)

        # Save a checkpoint every 20 epochs.
        if (epoch + 1) % 20 == 0:
            ckpt_save_path = ckpt_manager.save()
            print('Saved checkpoint for epoch {} at {}\n'.format(epoch+1, ckpt_save_path))

        print('Time taken for epoch {}: {:.2f} sec\n'.format(epoch + 1, time.time()-start))

fit(train_dataset, EPOCHS, test_dataset)
Epoch 8, for example, took 51.33 seconds to train.
Step 10: Colorize the test data and look at the results
for input, target in test_dataset.take(20):
    generate_images(generator, input, target)
矩池云 has now published a "口袋妖怪上色" (Pokémon colorization) image; interested readers can try it out through the "Jupyter 教程 Demo" image on the 矩池云 website.