昇腾社区首页
中文
注册
基于香橙派AIpro+MindSpore实现GAN图像生成

基于香橙派AIpro+MindSpore实现GAN图像生成

香橙派AIpro

发表于 2025/04/24

基于香橙派AIpro+MindSpore实现GAN图像生成

01引言

    生成式对抗网络(Generative Adversarial Networks,GAN)是一种生成式机器学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。本案例介绍在OrangePi AIpro开发板上基于MindSpore框架实现GAN图像生成。

02模型简介

      最初,GAN由Ian J. Goodfellow于2014年发明,并在论文Generative Adversarial Nets中首次进行了描述,其主要由两个不同的模型共同组成——生成器(Generative Model)和判别器(Discriminative Model):

(1)生成器的任务是生成看起来像训练图像的“假”图像;

(2)判别器需要判断从生成器输出的图像是真实的训练图像还是虚假的图像。

      GAN通过设计生成模型和判别模型这两个模块,使其互相博弈学习产生了相当好的输出。GAN模型的核心在于提出了通过对抗过程来估计生成模型这一全新框架。在这个框架中,将会同时训练两个模型——捕捉数据分布的生成模型 G和估计样本是否来自训练数据的判别模型 D。

      在训练过程中,生成器会不断尝试通过生成更好的假图像来骗过判别器,而判别器在这过程中也会逐步提升判别能力。这种博弈的平衡点是,当生成器生成的假图像和训练数据图像的分布完全一致时,判别器拥有50%的真假判断置信度。

      用x代表图像数据,用D(x)表示判别器网络给出图像判定为真实图像的概率。在判别过程中,D(x)需要处理作为二进制文件的大小为

1×28×28的图像数据。当x来自训练数据时,D(x)数值应该趋近于1;而当x来自生成器时,D(x)数值应该趋近于0。因此D(x)也可以被认为是传统的二分类器。

    用z代表标准正态分布中提取出的隐码(隐向量),用G(z):表示将隐码(隐向量)z映射到数据空间的生成器函数。函数G(z)的目标是将服从高斯分布的随机噪声z通过生成网络变换为近似于真实分布P[data(x)]的数据分布,我们希望找到θ使得P[G(x;θ)]和P[data(x)]尽可能的接近,其中θ代表网络参数。

     D(G(z))表示生成器G生成的假图像被判定为真实图像的概率,如Generative Adversarial Nets中所述,D和G在进行一场博弈,D想要最大程度的正确分类真图像与假图像,也就是参数logD(x);而G试图欺骗D来最小化假图像被识别到的概率,也就是参数log(1−D(G(z)))。因此GAN的损失函数为:

      从理论上讲,此博弈游戏的平衡点是P[G(x;θ)]=P[data(x)],此时判别器会随机猜测输入是真图像还是假图像。下面我们简要说明生成器和判别器的博弈过程:

(1)在训练刚开始的时候,生成器和判别器的质量都比较差,生成器会随机生成一个数据分布。

(2)判别器通过求取梯度和损失函数对网络进行优化,将靠近真实数据分布的数据判定为1,将靠近生成器生成出来数据分布的数据判定0。

(3)生成器通过优化,生成出更加贴近真实数据分布的数据。

(4)生成器所生成的数据和真实数据达到相同的分布,此时判别器的输出为1/2。


在上图中,蓝色虚线表示判别器,黑色虚线表示真实数据分布,绿色实线表示生成器生成的虚假数据分布,zz 表示隐码,xx 表示生成的虚假图像 G(z)。该图片来源于Generative Adversarial Nets。详细的训练方法介绍见原论文。

03环境准备

开发者拿到香橙派开发板后,首先需要进行硬件资源确认,镜像烧录及CANN和MindSpore版本的升级,才可运行该案例,具体如下:

硬件: 香橙派AIpro 16G 8-12T开发板
镜像: 香橙派官网ubuntu镜像
CANN:8.0.RC3.alpha002
MindSpore: 2.4.10

1、镜像烧录

运行该案例需要烧录香橙派官网ubuntu镜像,烧录流程参考昇思MindSpore官网--香橙派开发专区--环境搭建指南--镜像烧录章节。

2、CANN升级

CANN升级参考昇思MindSpore官网--香橙派开发专区--环境搭建指南--CANN升级章节。

3、MindSpore升级

MindSpore升级参考昇思MindSpore官网--香橙派开发专区--环境搭建指南--MindSpore升级章节。

04实验过程

1、设置运行环境

import mindspore
mindspore.set_context(max_device_memory="2GB", mode=mindspore.GRAPH_MODE, device_target="Ascend",  \
jit_config={"jit_level":"O2"}, ascend_config={"precision_mode":"allow_mix_precision"})

参数解释:

  • max_device_memory="2GB" : 设置设备可用的最大内存为2GB。

  • mode=mindspore.GRAPH_MODE : 表示在GRAPH_MODE模式中运行。

  • device_target="Ascend" : 表示待运行的目标设备为Ascend。

  • jit_config={"jit_level":"O2"} : 编译优化级别开启极致性能优化,使用下沉的执行方式。

  • ascend_config={"precision_mode":"allow_mix_precision"} : 自动混合精度,自动将部分算子的精度降低到float16或bfloat16。

2、数据集准备与处理

本案例将使用MNIST手写数字数据集来训练一个生成式对抗网络,使用该网络模拟生成手写数字图片。

1.数据集简介

MNIST手写数字数据集是NIST数据集的子集,共有70000张手写数字图片,包含60000张训练样本和10000张测试样本,数字图片为二进制文件,图片大小为28*28,单通道。图片已经预先进行了尺寸归一化和中心化处理。

2.数据集下载

使用download接口下载数据集,并将下载后的数据集自动解压到当前目录下。数据下载之前需要使用pip install download安装download包。

安装download包

!pip install download

下载数据集并解压

# 数据下载
from download import download
        
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
download(url, ".", kind="zip", replace=True)    

解压后的数据集目录结构如下:

./MNIST_Data/
├─ train
│ ├─ train-images-idx3-ubyte
│ └─ train-labels-idx1-ubyte
└─ test
├─ t10k-images-idx3-ubyte
└─ t10k-labels-idx1-ubyte

3.数据加载

使用MindSpore自己的MnistDatase接口,读取和解析MNIST数据集的源文件构建数据集。然后对数据进行一些前处理,包含数据转换、数据增强、批量处理。

import numpy as np
import mindspore.dataset as ds
        
batch_size = 128
latent_size = 100  # 隐码的长度
        
train_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/train')
test_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/test')
        
def data_load(dataset):
    dataset1 = ds.GeneratorDataset(dataset, ["image", "label"], shuffle=True, python_multiprocessing=False)
    # 数据增强
    mnist_ds = dataset1.map(
           operations=lambda x: (x.astype("float32"), np.random.normal(size=latent_size).astype("float32")),
           output_columns=["image", "latent_code"])
    mnist_ds = mnist_ds.project(["image", "latent_code"])
        
    # 批量操作
    mnist_ds = mnist_ds.batch(batch_size, True)
        
    return mnist_ds
        
mnist_ds = data_load(train_dataset)
        
iter_size = mnist_ds.get_dataset_size()
print('Iter size: %d' % iter_size)
        

4.数据集可视化

通过create_dict_iterator函数将数据转换成字典迭代器,然后使用matplotlib模块可视化部分训练数据。

import matplotlib.pyplot as plt
        
data_iter = next(mnist_ds.create_dict_iterator(output_numpy=True))
figure = plt.figure(figsize=(3, 3))
cols, rows = 5, 5
for idx in range(1, cols * rows + 1):
      image = data_iter['image'][idx]
      figure.add_subplot(rows, cols, idx)
      plt.axis("off")
      plt.imshow(image.squeeze(), cmap="gray")
plt.show()

3、模型构建

本案例实现中所搭建的 GAN 模型结构与原论文中提出的 GAN 结构大致相同,但由于所用数据集 MNIST 为单通道小尺寸图片,可识别参数少,便于训练,我们在判别器和生成器中采用全连接网络架构和 ReLU 激活函数即可达到令人满意的效果,且省略了原论文中用于减少参数的 Dropout 策略和可学习激活函数 Maxout。

1.生成器

生成器 Generator 的功能是将隐码映射到数据空间。由于数据是图像,这一过程也会创建与真实图像大小相同的灰度图像(或 RGB 彩色图像)。在本案例演示中,该功能通过五层 Dense 全连接层来完成的,每层都与 BatchNorm1d 批归一化层和 ReLU 激活层配对,输出数据会经过 Tanh 函数,使其返回 [-1,1] 的数据范围内。注意实例化生成器之后需要修改参数的名称,不然静态图模式下会报错。

from mindspore import nn
import mindspore.ops as ops
        
        
img_size = 28  # 训练图像长(宽)
        
class Generator(nn.Cell):
    def __init__(self, latent_size, auto_prefix=True):
        super(Generator, self).__init__(auto_prefix=auto_prefix)
        self.model = nn.SequentialCell()
        # [N, 100] -> [N, 128]
        # 输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维
        self.model.append(nn.Dense(latent_size, 128))
        self.model.append(nn.ReLU())
        # [N, 128] -> [N, 256]
        self.model.append(nn.Dense(128, 256))
        self.model.append(nn.BatchNorm1d(256))
        self.model.append(nn.ReLU())
        # [N, 256] -> [N, 512]
        self.model.append(nn.Dense(256, 512))
        self.model.append(nn.BatchNorm1d(512))
        self.model.append(nn.ReLU())
        # [N, 512] -> [N, 1024]
        self.model.append(nn.Dense(512, 1024))
        self.model.append(nn.BatchNorm1d(1024))
        self.model.append(nn.ReLU())
        # [N, 1024] -> [N, 784]
        # 经过线性变换将其变成784维
        self.model.append(nn.Dense(1024, img_size * img_size))
        # 经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间
        self.model.append(nn.Tanh())
        
    def construct(self, x):
          img = self.model(x)
          return ops.reshape(img, (-1, 1, 28, 28))
        
net_g = Generator(latent_size)
net_g.update_parameters_name('generator')

2.判别器

判别器 Discriminator 是一个二分类网络模型,输出判定该图像为真实图的概率。主要通过一系列的 Dense 层和 LeakyReLU 层对其进行处理,最后通过 Sigmoid 激活函数,使其返回 [0, 1] 的数据范围内,得到最终概率。注意实例化判别器之后需要修改参数的名称,不然静态图模式下会报错。

 # 判别器
class Discriminator(nn.Cell):
      def __init__(self, auto_prefix=True):
           super().__init__(auto_prefix=auto_prefix)
           self.model = nn.SequentialCell()
           # [N, 784] -> [N, 512]
           self.model.append(nn.Dense(img_size * img_size, 512))  # 输入特征数为784,输出为512
           self.model.append(nn.LeakyReLU())  # 默认斜率为0.2的非线性映射激活函数
           # [N, 512] -> [N, 256]
           self.model.append(nn.Dense(512, 256))  # 进行一个线性映射
           self.model.append(nn.LeakyReLU())
           # [N, 256] -> [N, 1]
           self.model.append(nn.Dense(256, 1))
           self.model.append(nn.Sigmoid())  # 二分类激活函数,将实数映射到[0,1]
        
      def construct(self, x):
          x_flat = ops.reshape(x, (-1, img_size * img_size))
          return self.model(x_flat)
        
net_d = Discriminator()
net_d.update_parameters_name('discriminator')

4、权重加载

下载已有的生成器网络模型参数文件加载到生成器网络中

import mindspore as ms
        
# 下载权重
gan_url = "https://modelers.cn/coderepo/web/v1/file/MindSpore-Lab/cluoud_obs/main/media/examples/mindspore-courses/orange-pi-online-infer/09-GAN/Generator199.ckpt"
path = "./Generator199.ckpt"
        
ckpt_path = download(gan_url, path, replace=True)
parameter = ms.load_checkpoint(ckpt_path)
ms.load_param_into_net(net_g, parameter)

5、模型推理

通过加载生成器网络模型参数文件来生成图像

from tqdm import tqdm
        
# 模型生成结果
test_data = Tensor(np.random.normal(0, 1, (25, 100)).astype(np.float32))
images = net_g(test_data).transpose(0, 2, 3, 1).asnumpy()
# 结果展示
fig = plt.figure(figsize=(3, 3), dpi=120)
for i in tqdm(range(25), desc="Generating images"):
    fig.add_subplot(5, 5, i + 1)
    plt.axis("off")
    plt.imshow(images[i].squeeze(), cmap="gray")
plt.show()

本案例已同步上线 GitHub 仓,更多案例开发亦可参考该仓库。

05课程视频

本实验配套视频教程,供参考学习。


06实验总结

本实验在香橙派AIpro开发板,基于MindSpore实现图像生成。通过本课程的学习,您可以了解GAN基础知识、MindSpore应用以及如何使用香橙派AIpro开发板完成GAN图像生成

本页内容