Practical Implementation: Building a Simple GAN with PyTorch
Implementing a Basic GAN
Implementing a basic Generative Adversarial Network (GAN) gives hands-on insight into its core components and training dynamics. The walkthrough below trains a small fully connected GAN on MNIST:
Setup and Hyperparameters
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Hyperparameters
latent_dim = 100   # size of the random noise vector fed to the generator
lr = 0.0002        # learning rate for both optimizers
batch_size = 64    # number of images per training batch
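The example stays on the CPU for simplicity. If a GPU is available, a common optional first step (not wired into the snippets below) is to pick a device and later move the models and batches onto it with .to(device):

# Optional: select a GPU when one is available; the rest of the example stays on CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')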
Generator Network
# Generator: maps a latent vector z to a 1x28x28 image
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, 256),
            nn.ReLU(True),
            nn.Linear(256, 28 * 28),
            nn.Tanh()  # outputs in [-1, 1], matching the normalized MNIST images
        )

    def forward(self, z):
        return self.model(z).view(-1, 1, 28, 28)
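As a quick sanity check (purely illustrative, not part of the training script), you can pass a batch of random latent vectors through an untrained generator and confirm the output shape matches MNIST images:

# A batch of 16 latent vectors should map to 16 single-channel 28x28 images
z_check = torch.randn(16, latent_dim)
print(Generator()(z_check).shape)  # expected: torch.Size([16, 1, 28, 28])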
Discriminator Network
# Discriminator: maps a 1x28x28 image to the probability that it is real
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # probability in (0, 1), paired with BCELoss below
        )

    def forward(self, img):
        return self.model(img)
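The same kind of check works end to end (again just illustrative): generated images pushed through the discriminator should yield one probability per image:

# Each generated image should receive a single real/fake probability
scores = Discriminator()(Generator()(torch.randn(16, latent_dim)))
print(scores.shape)  # expected: torch.Size([16, 1])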
Model Initialization and Optimizers
# Initialize models
G = Generator()
D = Discriminator()
# Optimizers
optimG = optim.Adam(G.parameters(), lr=lr)
optimD = optim.Adam(D.parameters(), lr=lr)
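The plain Adam defaults are adequate for this toy run. Many GAN implementations, following the DCGAN paper, lower the first momentum term to 0.5 for more stable training; if you want that variant (an optional tweak, not required here):

# Optional: beta1 = 0.5 is a common Adam setting for GAN training (used in DCGAN)
optimG = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimD = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))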
Data Loading
# Data loader: scale MNIST images to [-1, 1] so real images match the generator's Tanh output
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
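To verify the pipeline before training, you can pull a single batch and check its shape and value range (an illustrative check only):

# One batch of real images: shape [64, 1, 28, 28], values roughly in [-1, 1]
imgs, _ = next(iter(loader))
print(imgs.shape, imgs.min().item(), imgs.max().item())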
Training Loop (Simplified)
# Training loop (simplified): alternate one discriminator update and one generator update per batch
criterion = nn.BCELoss()
for epoch in range(5):
    for real_imgs, _ in loader:
        bs = real_imgs.size(0)  # the last batch can be smaller than batch_size

        # Train Discriminator: real images should score 1, generated images 0
        z = torch.randn(bs, latent_dim)
        fake_imgs = G(z)
        real_labels = torch.ones(bs, 1)
        fake_labels = torch.zeros(bs, 1)

        real_loss = criterion(D(real_imgs), real_labels)
        fake_loss = criterion(D(fake_imgs.detach()), fake_labels)  # detach so this update does not touch G
        d_loss = real_loss + fake_loss
        optimD.zero_grad()
        d_loss.backward()
        optimD.step()

        # Train Generator: push D's score on the fakes toward "real"
        gen_loss = criterion(D(fake_imgs), real_labels)
        optimG.zero_grad()
        gen_loss.backward()
        optimG.step()
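After (or during) training, it is worth looking at what the generator produces. A minimal sketch, assuming you are happy to write a PNG grid of samples to the working directory:

# Sample a grid of generated digits and save it as an image
from torchvision.utils import save_image
with torch.no_grad():
    samples = G(torch.randn(64, latent_dim))
save_image(samples, 'gan_samples.png', nrow=8, normalize=True)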