In this lab, you will train a GAN to generate new Pokémon sprites. The provided dataset consists of 15,467 sprites of all Pokémon from generation 1 to generation 8, each of size 64 × 64 × 3.
# Mount Google Drive: the dataset zip is read from it and generated images
# and checkpoints are written back to it.
from google.colab import drive
drive.mount(mountpoint='/content/drive')
Mounted at /content/drive
Download the file http://www.donlapark.cmustat.com/229352/pokemon.zip and upload it to the top level of your Google Drive (My Drive), so that its path is /content/drive/MyDrive/pokemon.zip.
!unzip /content/drive/MyDrive/pokemon.zip;
!mkdir /content/drive/MyDrive/new_pokemon; # Folder to save images of new pokemon
!mkdir /content/drive/MyDrive/GAN_weights # Folder to save images of new pokemon
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as tt
# Training parameters (TODO-0).
BATCH_SIZE = 128  # images per optimisation step
EPOCHS = 50  # full passes over the dataset
noise_dim = 100  # length of the latent vector z fed to the generator
num_examples_to_generate = 16  # samples drawn from fixed_noise and plotted as a 4x4 grid each epoch
IMAGE_SIZE = 64  # sprites (and generator outputs) are IMAGE_SIZE x IMAGE_SIZE
# Per-channel (mean, std) for tt.Normalize: maps ToTensor's [0, 1] range to [-1, 1],
# the same range as the generator's tanh output.
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
# Preprocessing + augmentation, applied to each sprite as it is loaded.
Transform = tt.Compose([
    # Guarantee IMAGE_SIZE x IMAGE_SIZE inputs, which the 64x64 architectures assume
    # (a no-op for sprites that are already 64x64).
    tt.Resize(IMAGE_SIZE),
    tt.CenterCrop(IMAGE_SIZE),
    tt.RandomRotation(degrees=15, fill=(255, 255, 255)),  # fill exposed corners with white
    tt.ColorJitter(hue=0.5),  # random hue shift
    tt.ToTensor(),
    tt.Normalize(*stats),  # normalize pixels to [-1, 1]
])

# ImageFolder expects data_dir to contain one sub-folder per class; the images
# live in a single sub-folder (pokemon/pokemon), and class labels are ignored
# during GAN training. The class list is derived from the sub-folder names, so
# no manual `dataset.classes` assignment is needed.
data_dir = "pokemon"
dataset = ImageFolder(data_dir, transform=Transform)

dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=use_cuda,  # faster host-to-GPU copies when CUDA is used
)
# Generator transforms a Gaussian vector into a 64x64x3 image
class Generator(nn.Module):
    """DCGAN-style generator mapping a latent vector to a 3 x 64 x 64 image.

    A bias-free linear layer projects the latent vector to 512 x 4 x 4, then
    four stride-2 transposed convolutions double the spatial size each time
    (4 -> 8 -> 16 -> 32 -> 64). The final tanh keeps pixels in [-1, 1], the
    same range the training images are normalized to.

    Args:
        latent_dim: length of the input noise vector. Defaults to the
            notebook-level ``noise_dim`` so ``Generator()`` keeps working.
    """

    def __init__(self, latent_dim=None):
        super(Generator, self).__init__()
        if latent_dim is None:
            latent_dim = noise_dim
        self.linear = nn.Linear(in_features=latent_dim,
                                out_features=4 * 4 * 512,
                                bias=False)
        # Each ConvTranspose2d(kernel=4, stride=2, padding=1) doubles H and W.
        # See https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
        self.main = nn.Sequential(
            # 512 x 4 x 4 -> 256 x 8 x 8
            nn.ConvTranspose2d(in_channels=512,
                               out_channels=256,
                               kernel_size=4,
                               stride=2,
                               padding=1,
                               bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # 256 x 8 x 8 -> 128 x 16 x 16
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            # 128 x 16 x 16 -> 64 x 32 x 32
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            # 64 x 32 x 32 -> 3 x 64 x 64
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
            # Output shape = 3 x 64 x 64
        )

    def forward(self, input):
        """Map a (batch, latent_dim) noise tensor to a (batch, 3, 64, 64) image batch."""
        x = self.linear(input)
        x = x.view(-1, 512, 4, 4)  # reshape the projection into a 512 x 4 x 4 feature map
        return self.main(x)
# Instantiate the generator and move its parameters to the selected device.
netG = Generator()
netG = netG.to(device)
# To resume training, restore a saved checkpoint (here: the one from epoch 20):
# checkpointG = torch.load('/content/drive/MyDrive/GAN_weights/netG_0020', map_location=device)
# netG.load_state_dict(checkpointG['model_state_dict'])
# Discriminator: strided convolutions with LeakyReLU(0.2) activations.
class Discriminator(nn.Module):
    """DCGAN-style discriminator scoring a 3 x 64 x 64 image as real or fake.

    Four stride-2 convolutions halve the spatial size each time
    (64 -> 32 -> 16 -> 8 -> 4) while widening the channels. A final 4x4
    convolution collapses the 512 x 4 x 4 map to a single logit, which the
    sigmoid turns into a probability (the input expected by nn.BCELoss).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # input is 3 x 64 x 64
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            # 64 x 32 x 32 -> 128 x 16 x 16
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            # 128 x 16 x 16 -> 256 x 8 x 8
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            # 256 x 8 x 8 -> 512 x 4 x 4
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            # 512 x 4 x 4 -> 1 x 1 x 1
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Return a (batch, 1, 1, 1) tensor of probabilities that each image is real."""
        return self.main(input)
# Instantiate the discriminator and move its parameters to the selected device.
netD = Discriminator()
netD = netD.to(device)
# To resume training, restore a saved checkpoint (here: the one from epoch 20):
# checkpointD = torch.load('/content/drive/MyDrive/GAN_weights/netD_0020', map_location=device)
# netD.load_state_dict(checkpointD['model_state_dict'])
# Binary cross-entropy between the discriminator's sigmoid output and a 0/1 target.
criterion = nn.BCELoss()

# Fixed batch of latent vectors, reused every epoch so snapshots of the
# generator's output are directly comparable over training.
fixed_noise = torch.randn(num_examples_to_generate, noise_dim, device=device)

# Target values: 1 for real images, 0 for generated ones.
real_label = 1.
fake_label = 0.

# Adam optimizers for both networks: lr = 0.0002, betas = (0.5, 0.999) (TODO-4).
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))

# To resume training, also restore the optimizer states saved with the checkpoints:
#optimizerG.load_state_dict(checkpointG['optimizer_state_dict'])
#optimizerD.load_state_dict(checkpointD['optimizer_state_dict'])
# Training loop: one discriminator update followed by one generator update per
# batch, both using the BCE `criterion` defined above. After every epoch, a
# sample grid and both checkpoints are saved to Google Drive.
iters = 0
print("Starting Training Loop...")
for epoch in range(EPOCHS):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        real_img = data[0].to(device)  # data[1] holds ImageFolder class labels, unused here
        real_size = real_img.shape[0]  # the last batch may be smaller than BATCH_SIZE
        label = torch.full((real_size,), real_label, dtype=torch.float, device=device)
        output = netD(real_img).view(-1)
        errD_real = criterion(output, label)  # TODO-5.1
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        noise = torch.randn(real_size, noise_dim, device=device)
        fake_img = netG(noise)
        label.fill_(fake_label)
        # detach(): the discriminator loss must not backpropagate into the generator.
        output = netD(fake_img.detach()).view(-1)
        errD_fake = criterion(output, label)  # TODO-5.2
        # Gradients accumulate (sum) with those from the real batch.
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake  # TODO-5.3 (used for logging; gradients already computed)
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake_img).view(-1)
        errG = criterion(output, label)  # TODO-5.4
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, EPOCHS, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        iters += 1

    # Check how the generator is doing by saving G's output on fixed_noise.
    with torch.no_grad():
        predictions = netG(fixed_noise).cpu().numpy()
    # Move the colour channel from dim 1 to dim 3: (N, C, H, W) -> (N, H, W, C).
    predictions = np.moveaxis(predictions, 1, 3)
    fig = plt.figure(figsize=(4, 4))
    for k in range(predictions.shape[0]):
        plt.subplot(4, 4, k + 1)
        plt.imshow((predictions[k] + 1) / 2)  # tanh range [-1, 1] -> [0, 1] for imshow
        plt.axis('off')
    plt.savefig('/content/drive/MyDrive/new_pokemon/image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
    # Close the figure explicitly; otherwise one open figure accumulates per epoch.
    plt.close(fig)

    torch.save({
        'model_state_dict': netG.state_dict(),
        'optimizer_state_dict': optimizerG.state_dict()
    }, '/content/drive/MyDrive/GAN_weights/netG_{:04d}'.format(epoch))
    torch.save({
        'model_state_dict': netD.state_dict(),
        'optimizer_state_dict': optimizerD.state_dict()
    }, '/content/drive/MyDrive/GAN_weights/netD_{:04d}'.format(epoch))
# Install the denoising_diffusion_pytorch package used by the cells below (IPython shell escape).
# NOTE(review): the package version is unpinned; later releases may change the
# GaussianDiffusion/Trainer keyword arguments used below — consider pinning a known-good version.
!python -m pip install denoising_diffusion_pytorch --quiet
|████████████████████████████████| 148 kB 8.6 MB/s eta 0:00:01
import gc
# Release the GPU memory used by the DCGAN before training the diffusion model.
# Every Python reference must go first:
#  - the optimizers hold the networks' parameters (plus Adam's moment buffers),
#    so deleting only netD/netG frees almost nothing;
#  - tensors left over from the last training iteration carry autograd history
#    (grad_fn) that may also keep parameters alive.
del netD, netG, optimizerD, optimizerG, dataloader
del real_img, fake_img, noise, output, label
del errD_real, errD_fake, errD, errG, predictions
gc.collect()
# empty_cache() can only return blocks whose tensors are already freed, so it
# must run after the deletions and the garbage collection, not before.
torch.cuda.empty_cache()
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

# U-Net backbone for the diffusion model, placed on the same device as the GAN.
model = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
).to(device)

diffusion = GaussianDiffusion(
    model,
    image_size=IMAGE_SIZE,  # 64, the sprite resolution
    timesteps=1000,  # number of diffusion steps
    loss_type='l1',  # 'l1' or 'l2'
    # NOTE(review): newer releases of denoising_diffusion_pytorch may not accept
    # `loss_type`; if this raises a TypeError, pin an older package version.
).to(device)

trainer = Trainer(
    diffusion,
    'pokemon/pokemon',  # the extracted sub-folder containing the sprite images
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=100000,  # total training steps
    gradient_accumulate_every=2,  # gradient accumulation steps
    ema_decay=0.995,  # exponential moving average decay
    amp=False,  # mixed precision disabled
)

# To resume, load a previously saved trainer checkpoint, e.g. the one named '2':
# trainer.load('2')
# The next line appears to be a workaround for restoring Adam state on some
# torch versions (an error about `capturable`) — confirm before relying on it.
# trainer.opt.param_groups[0]['capturable'] = True
trainer.train()
0%| | 0/100000 [00:00<?, ?it/s]
sampling loop time step: 0%| | 0/1000 [00:00<?, ?it/s]
sampling loop time step: 0%| | 0/1000 [00:00<?, ?it/s]
sampling loop time step: 0%| | 0/1000 [00:00<?, ?it/s]
sampling loop time step: 0%| | 0/1000 [00:00<?, ?it/s]
# Draw a few images from the diffusion model and display each one.
sampled_images = diffusion.sample(batch_size = 4)
imgs = sampled_images.cpu().numpy()
# Samples are channel-first (N, C, H, W) like the GAN outputs; imshow needs
# (H, W, C). This must be an axis move, not reshape(64, 64, 3): reshape keeps
# the memory order and scrambles pixels across channels.
imgs = np.moveaxis(imgs, 1, 3)
for i in range(imgs.shape[0]):
    plt.figure()
    # sample() is expected to return values in [0, 1]; clipping guards against
    # small numerical overshoot, which imshow would otherwise warn about.
    plt.imshow(np.clip(imgs[i], 0.0, 1.0))
    plt.axis('off')