Practical Implementation: Building a Simple GAN with PyTorch

Advanced

🧪 Implementing a Basic GAN

Implementing a basic Generative Adversarial Network (GAN) gives a hands-on understanding of its two core components — the generator and the discriminator — and of their adversarial training dynamics:


⚙️ Setup and Hyperparameters

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Hyperparameters
latent_dim = 100    # length of the noise vector the generator consumes
lr = 2e-4           # Adam learning rate, shared by both optimizers below
batch_size = 64     # images per DataLoader batch

🧱 Generator Network

# Generator Network
class Generator(nn.Module):
    """MLP generator mapping latent noise vectors to 1x28x28 images.

    Three fully connected layers (z_dim -> 128 -> 256 -> 784) with ReLU
    activations in between; the final Tanh bounds every pixel to [-1, 1].
    """

    def __init__(self, z_dim: "int | None" = None):
        """Build the generator.

        Args:
            z_dim: Length of the input noise vector. When omitted, the
                module-level ``latent_dim`` hyperparameter is used, so
                ``Generator()`` behaves exactly as before.
        """
        super().__init__()
        if z_dim is None:
            # Looked up at construction time, matching the original behaviour.
            z_dim = latent_dim
        self.z_dim = z_dim  # exposed so callers can sample matching noise
        self.model = nn.Sequential(
            nn.Linear(z_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, 256),
            nn.ReLU(True),
            nn.Linear(256, 28 * 28),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map noise ``z`` of shape (N, z_dim) to images of shape (N, 1, 28, 28)."""
        # -1 lets view() infer the batch dimension from the flat 784-wide output.
        return self.model(z).view(-1, 1, 28, 28)

🔍 Discriminator Network

# Discriminator Network
class Discriminator(nn.Module):
    """MLP discriminator scoring images as real (near 1) or fake (near 0).

    The input is flattened to 784 features, passed through one 256-unit
    hidden layer with LeakyReLU(0.2), and squashed to a single value in
    (0, 1) by a Sigmoid.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Flatten(),               # e.g. (N, 1, 28, 28) -> (N, 784)
            nn.Linear(28 * 28, 256),
            nn.LeakyReLU(0.2),          # nonzero slope for negative inputs
            nn.Linear(256, 1),
            nn.Sigmoid(),               # probability that the input is real
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return an (N, 1) tensor of real-vs-fake probabilities for ``img``."""
        scores = self.model(img)
        return scores

🚀 Model Initialization and Optimizers

# Initialize models (generator first, then discriminator)
G, D = Generator(), Discriminator()

# Optimizers: each network owns a separate Adam instance, so a step on one
# only updates that network's own parameters.
optimG, optimD = (optim.Adam(net.parameters(), lr=lr) for net in (G, D))

📥 Data Loading

# Data loader
# ToTensor() yields pixels in [0, 1], but the generator's Tanh emits [-1, 1].
# Normalize((0.5,), (0.5,)) maps x -> (x - 0.5) / 0.5, putting real images in
# the same [-1, 1] range as generated ones; otherwise the two live in
# different value ranges, which the discriminator can trivially exploit.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

🔁 Training Loop (Simplified)

# Training Loop (simplified)
criterion = nn.BCELoss()  # one loss module, reused for every step

for epoch in range(5):
    for real_imgs, _ in loader:
        # Size noise and labels from the actual batch, not `batch_size`:
        # MNIST's 60,000 training images are not a multiple of 64, so the
        # final batch is smaller, and BCELoss raises when input and target
        # sizes differ.
        n = real_imgs.size(0)

        # --- Train Discriminator: push D(real) -> 1 and D(G(z)) -> 0 ---
        z = torch.randn(n, latent_dim)
        fake_imgs = G(z)
        real_labels = torch.ones(n, 1)
        fake_labels = torch.zeros(n, 1)

        # Discriminator loss; detach() stops this loss from backpropagating into G.
        real_loss = criterion(D(real_imgs), real_labels)
        fake_loss = criterion(D(fake_imgs.detach()), fake_labels)
        d_loss = real_loss + fake_loss
        optimD.zero_grad()
        d_loss.backward()
        optimD.step()

        # --- Train Generator: label fakes as real (non-saturating loss) ---
        gen_loss = criterion(D(fake_imgs), real_labels)
        optimG.zero_grad()
        gen_loss.backward()
        optimG.step()