Pytorch学习笔记DCGAN极简入门教程

 更新时间:2021年09月07日 09:42:26   作者:xz1308579340  
网上GAN的教程太多了,这边也谈一下自己的理解,本文给大家介绍一下GAN的两部分组成,有需要的朋友可以借鉴参考下,希望能够有所帮助

1.图片分类网络

这是一个二分类网络,可以是alxnet ,vgg,resnet任何一个,负责对图片进行二分类,区分图片是真实图片还是生成的图片

2.图片生成网络

输入是一个随机噪声,输出是一张图片,使用的是反卷积层

相信学过深度学习的都能写出这两个网络,当然如果你写不出来,没关系,有人替你写好了

首先是图片分类网络:

简单来说就是cnn+relu+sogmid,可以换成任何一个分类网络,比如bgg,resnet等

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
    def forward(self, input):
        return self.main(input)

重点是生成网络

代码如下,其实就是反卷积+bn+relu

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )
    def forward(self, input):
        return self.main(input)


讲道理,以上两个网络都挺简单。

真正的重点到了,怎么训练

每一个step分为三个步骤:

  • 训练二分类网络
    1.输入真实图片,经过二分类,希望判定为真实图片,更新二分类网络
    2.输入噪声,进过生成网络,生成一张图片,输入二分类网络,希望判定为虚假图片,更新二分类网络
  • 训练生成网络
    3.输入噪声,进过生成网络,生成一张图片,输入二分类网络,希望判定为真实图片,更新生成网络

不多说直接上代码

for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()
        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()
        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()
        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        iters += 1

以上就是Pytorch学习笔记DCGAN极简入门教程的详细内容,更多关于Pytorch学习DCGAN入门教程的资料请关注脚本之家其它相关文章!

相关文章

  • 详解python使用递归、尾递归、循环三种方式实现斐波那契数列

    详解python使用递归、尾递归、循环三种方式实现斐波那契数列

    本篇文章主要介绍了python使用递归、尾递归、循环三种方式实现斐波那契数列,非常具有实用价值,需要的朋友可以参考下
    2018-01-01
  • python经典练习百题之猴子吃桃三种解法

    python经典练习百题之猴子吃桃三种解法

    这篇文章主要给大家介绍了关于python经典练习百题之猴子吃桃三种解法的相关资料, Python猴子吃桃子编程是一个趣味性十足的编程练习,在这个练习中,我们将要使用Python语言来模拟一只猴子吃桃子的过程,需要的朋友可以参考下
    2023-10-10
  • tensorflow之读取jpg图像长和宽实例

    tensorflow之读取jpg图像长和宽实例

    这篇文章主要介绍了tensorflow之读取jpg图像长和宽实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-06-06
  • 关于Python正则表达式 findall函数问题详解

    关于Python正则表达式 findall函数问题详解

    在写正则表达式的时候总会遇到不少的问题,本文讲述了Python正则表达式中 findall()函数和多个表达式元组相遇的时候会出现的问题
    2018-03-03
  • Django 解决由save方法引发的错误

    Django 解决由save方法引发的错误

    这篇文章主要介绍了Django 解决由save方法引发的错误,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-05-05
  • Python和Anaconda和Pycharm安装教程图文详解

    Python和Anaconda和Pycharm安装教程图文详解

    PyCharm是一种PythonIDE,带有一整套可以帮助用户在使用Python语言开发时提高其效率的工具,这篇文章主要介绍了Python和Anaconda和Pycharm安装教程,需要的朋友可以参考下
    2020-02-02
  • 用Python解决x的n次方问题

    用Python解决x的n次方问题

    今天小编就为大家分享一篇用Python解决x的n次方问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-02-02
  • Python pluggy模块的用法示例演示

    Python pluggy模块的用法示例演示

    这篇文章主要介绍了Python pluggy模块的用法,pluggy提供了一个简易便捷的插件系统,可以做到插件与主题功能松耦合,pluggy 是pytest,tox,devpi的核心框架文中通过代码示例演示给大家详细介绍,需要的朋友参考下吧
    2022-05-05
  • Python实现基于多线程、多用户的FTP服务器与客户端功能完整实例

    Python实现基于多线程、多用户的FTP服务器与客户端功能完整实例

    这篇文章主要介绍了Python实现基于多线程、多用户的FTP服务器与客户端功能,结合完整实例形式分析了Python多线程、多用户FTP服务器端与客户端相关实现技巧与注意事项,需要的朋友可以参考下
    2017-08-08
  • Python出现segfault错误解决方法

    Python出现segfault错误解决方法

    这篇文章主要介绍了Python出现segfault错误解决方法,分析了系统日志提示segfault错误的原因与对应的解决方法,需要的朋友可以参考下
    2016-04-04

最新评论