0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

漫画家さんが、線画をAIで描く場合に想定される、便利な仕組み

Posted at

Edgeの抽出。線画生成に特化した機械学習モデルについ(基礎)

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np

# U-Net architecture for lineart generation
class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        
        # Encoder (downsampling)
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)
        
        # Decoder (upsampling)
        self.dec4 = self.upconv_block(512, 256)
        self.dec3 = self.upconv_block(512, 128)
        self.dec2 = self.upconv_block(256, 64)
        self.dec1 = self.upconv_block(128, 64)
        
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)
        
    def conv_block(self, in_c, out_c):
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
    
    def upconv_block(self, in_c, out_c):
        return nn.Sequential(
            nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        
        # Decoder
        d4 = self.dec4(e4)
        d3 = self.dec3(torch.cat([d4, e3], dim=1))
        d2 = self.dec2(torch.cat([d3, e2], dim=1))
        d1 = self.dec1(torch.cat([d2, e1], dim=1))
        
        return torch.sigmoid(self.final(d1))

# Custom dataset for manga images and linearts
class MangaLineartDataset(Dataset):
    def __init__(self, image_paths, lineart_paths, transform=None):
        self.image_paths = image_paths
        self.lineart_paths = lineart_paths
        self.transform = transform
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        lineart = Image.open(self.lineart_paths[idx]).convert('L')
        
        if self.transform:
            image = self.transform(image)
            lineart = self.transform(lineart)
        
        return image, lineart

# Training function
def train_lineart_model(model, train_loader, num_epochs, device):
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters())
    
    for epoch in range(num_epochs):
        model.train()
        for batch_idx, (images, linearts) in enumerate(train_loader):
            images, linearts = images.to(device), linearts.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, linearts)
            loss.backward()
            optimizer.step()
            
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')

# Generate lineart from an image
def generate_lineart(model, image_path, device):
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)
    
    model.eval()
    with torch.no_grad():
        lineart = model(image)
    
    return transforms.ToPILImage()(lineart.squeeze(0).cpu())

# Main execution
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Initialize model
    model = UNet(in_channels=3, out_channels=1).to(device)
    
    # Prepare dataset and dataloader
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    
    # Assume you have lists of image and lineart paths
    image_paths = ["path/to/image1.jpg", "path/to/image2.jpg", ...]
    lineart_paths = ["path/to/lineart1.png", "path/to/lineart2.png", ...]
    
    dataset = MangaLineartDataset(image_paths, lineart_paths, transform=transform)
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
    
    # Train the model
    train_lineart_model(model, train_loader, num_epochs=50, device=device)
    
    # Save the trained model
    torch.save(model.state_dict(), "lineart_model.pth")
    
    # Generate lineart for a new image
    new_image_path = "path/to/new_image.jpg"
    generated_lineart = generate_lineart(model, new_image_path, device)
    generated_lineart.save("generated_lineart.png")

この実装は、線画生成に特化した機械学習モデルの一例です。主な特徴と説明は以下の通りです:

  1. モデルアーキテクチャ:

    • U-Net構造を採用しています。これは画像セグメンテーションタスクで広く使用されており、線画生成にも適しています。
    • エンコーダ部分で画像の特徴を抽出し、デコーダ部分で線画を再構築します。
    • スキップ接続により、高解像度の情報を保持しつつ、深い特徴抽出が可能です。
  2. データセット:

    • カスタムデータセットクラス MangaLineartDataset を定義しています。
    • 元の漫画画像とそれに対応する線画のペアを学習データとして使用します。
  3. 学習プロセス:

    • Binary Cross Entropy損失関数を使用しています。これは2値(白黒)の線画生成に適しています。
    • Adamオプティマイザを使用して、モデルのパラメータを最適化します。
  4. 線画生成:

    • 学習済みモデルに新しい画像を入力することで、線画を生成します。

このモデルを使用する際の注意点と改善のヒント:

  1. データセットの準備:

    • 高品質な漫画画像と対応する線画のペアを大量に収集する必要があります。
    • データ拡張(回転、反転、コントラスト調整など)を行うことで、モデルの汎化性能を向上させることができます。
  2. モデルの改良:

    • ResNetやDenseNetなどの最新のアーキテクチャを取り入れることで、性能を向上させることができます。
    • Attention機構を追加することで、線の重要な特徴により注目させることができます。
  3. 損失関数の工夫:

    • L1損失やPerceptual損失を組み合わせることで、より鮮明で正確な線画を生成できる可能性があります。
  4. 後処理:

    • 生成された線画に対して、閾値処理やモルフォロジー演算を適用することで、よりクリーンな結果を得ることができます。
  5. スタイル制御:

    • 条件付きGANの手法を取り入れることで、異なる線画スタイル(例:太い線、細い線、スケッチ風など)を制御できるようになります。

このモデルを実際に使用する場合、大量の学習データと計算リソースが必要になります。また、漫画家の個性や著作権に配慮しつつ、独自のデータセットを構築することが重要です。

さらに詳しい実装を考える: 「漫画家の個性的な線画のタッチ」

このタスクには、スタイル転送と生成モデルの組み合わせが適していると考えられます。
具体的には、条件付きGANを使用したアプローチを提案できます。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import os

# Generator Network
class Generator(nn.Module):
    def __init__(self, input_channels, output_channels, num_filters=64):
        super(Generator, self).__init__()
        self.down1 = self.conv_block(input_channels, num_filters, norm=False)
        self.down2 = self.conv_block(num_filters, num_filters * 2)
        self.down3 = self.conv_block(num_filters * 2, num_filters * 4)
        self.down4 = self.conv_block(num_filters * 4, num_filters * 8)

        self.up1 = self.upconv_block(num_filters * 8, num_filters * 4)
        self.up2 = self.upconv_block(num_filters * 8, num_filters * 2)
        self.up3 = self.upconv_block(num_filters * 4, num_filters)
        self.up4 = self.upconv_block(num_filters * 2, output_channels, norm=False)

    def conv_block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, norm=True):
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=not norm)]
        if norm:
            layers.append(nn.InstanceNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2, True))
        return nn.Sequential(*layers)

    def upconv_block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, norm=True):
        layers = [nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=not norm)]
        if norm:
            layers.append(nn.InstanceNorm2d(out_channels))
        layers.append(nn.ReLU(True))
        return nn.Sequential(*layers)

    def forward(self, x, style):
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)

        # Inject style
        d4 = d4 * style.view(-1, 1, 1, 1)

        u1 = self.up1(d4)
        u2 = self.up2(torch.cat([u1, d3], dim=1))
        u3 = self.up3(torch.cat([u2, d2], dim=1))
        u4 = self.up4(torch.cat([u3, d1], dim=1))

        return torch.tanh(u4)

# Discriminator Network
class Discriminator(nn.Module):
    def __init__(self, input_channels, num_filters=64):
        super(Discriminator, self).__init__()
        self.conv1 = self.conv_block(input_channels, num_filters, norm=False)
        self.conv2 = self.conv_block(num_filters, num_filters * 2)
        self.conv3 = self.conv_block(num_filters * 2, num_filters * 4)
        self.conv4 = self.conv_block(num_filters * 4, num_filters * 8, stride=1)
        self.conv5 = nn.Conv2d(num_filters * 8, 1, kernel_size=4, stride=1, padding=1)

    def conv_block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, norm=True):
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=not norm)]
        if norm:
            layers.append(nn.InstanceNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2, True))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return x

# Custom Dataset
class MangaArtistDataset(Dataset):
    def __init__(self, content_dir, style_dir, transform=None):
        self.content_paths = [os.path.join(content_dir, f) for f in os.listdir(content_dir)]
        self.style_paths = [os.path.join(style_dir, f) for f in os.listdir(style_dir)]
        self.transform = transform

    def __len__(self):
        return len(self.content_paths)

    def __getitem__(self, idx):
        content_img = Image.open(self.content_paths[idx]).convert('L')
        style_img = Image.open(np.random.choice(self.style_paths)).convert('L')

        if self.transform:
            content_img = self.transform(content_img)
            style_img = self.transform(style_img)

        return content_img, style_img

# Training function
def train(generator, discriminator, dataloader, num_epochs, device):
    criterion_gan = nn.MSELoss()
    criterion_pixel = nn.L1Loss()

    optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    for epoch in range(num_epochs):
        for i, (content_imgs, style_imgs) in enumerate(dataloader):
            content_imgs = content_imgs.to(device)
            style_imgs = style_imgs.to(device)

            # Train Generator
            optimizer_g.zero_grad()

            # Generate styled images
            style_vector = torch.mean(style_imgs.view(style_imgs.size(0), -1), dim=1)
            gen_imgs = generator(content_imgs, style_vector)

            # Adversarial loss
            pred_fake = discriminator(gen_imgs)
            loss_gan = criterion_gan(pred_fake, torch.ones_like(pred_fake))

            # Pixel-wise loss
            loss_pixel = criterion_pixel(gen_imgs, style_imgs)

            # Total loss
            loss_g = loss_gan + 100 * loss_pixel

            loss_g.backward()
            optimizer_g.step()

            # Train Discriminator
            optimizer_d.zero_grad()

            # Real loss
            pred_real = discriminator(style_imgs)
            loss_real = criterion_gan(pred_real, torch.ones_like(pred_real))

            # Fake loss
            pred_fake = discriminator(gen_imgs.detach())
            loss_fake = criterion_gan(pred_fake, torch.zeros_like(pred_fake))

            # Total loss
            loss_d = 0.5 * (loss_real + loss_fake)

            loss_d.backward()
            optimizer_d.step()

            if i % 100 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], "
                      f"Loss D: {loss_d.item():.4f}, Loss G: {loss_g.item():.4f}")

    return generator

# Main execution
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize models
    generator = Generator(input_channels=1, output_channels=1).to(device)
    discriminator = Discriminator(input_channels=1).to(device)

    # Prepare dataset and dataloader
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    dataset = MangaArtistDataset("path/to/content_images", "path/to/style_images", transform=transform)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Train the model
    trained_generator = train(generator, discriminator, dataloader, num_epochs=100, device=device)

    # Save the trained model
    torch.save(trained_generator.state_dict(), "manga_artist_style_transfer.pth")

# Function to generate styled lineart
def generate_styled_lineart(generator, content_image_path, style_image_path, device):
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    content_img = Image.open(content_image_path).convert('L')
    style_img = Image.open(style_image_path).convert('L')

    content_tensor = transform(content_img).unsqueeze(0).to(device)
    style_tensor = transform(style_img).unsqueeze(0).to(device)

    style_vector = torch.mean(style_tensor.view(style_tensor.size(0), -1), dim=1)

    generator.eval()
    with torch.no_grad():
        styled_lineart = generator(content_tensor, style_vector)

    styled_lineart = styled_lineart.squeeze().cpu().numpy()
    styled_lineart = (styled_lineart + 1) / 2.0 * 255.0
    styled_lineart = styled_lineart.astype(np.uint8)

    return Image.fromarray(styled_lineart)

# Example usage
# styled_lineart = generate_styled_lineart(trained_generator, "path/to/content.jpg", "path/to/style.jpg", device)
# styled_lineart.save("styled_lineart.png")

このコードは、条件付きGANを使用して特定の漫画家のスタイルを学習し、そのスタイルで新しい線画を生成するモデルを実装しています。主な特徴と説明は以下の通りです:

  1. アーキテクチャ:

    • Generator: U-Net風の構造を持ち、スタイル情報を中間層に注入します。
    • Discriminator: PatchGANスタイルの判別器で、生成された画像が本物のスタイルに近いかを判定します。
  2. スタイル表現:

    • スタイル画像を平均化してベクトルとして表現し、Generatorに注入します。これにより、スタイルの全体的な特徴を捉えつつ、柔軟な生成が可能になります。
  3. 損失関数:

    • Adversarial Loss: 生成された画像が本物らしく見えるようにします。
    • Pixel-wise Loss: 生成された画像がスタイル画像に近づくようにします。
  4. データセット:

    • コンテンツ画像(一般的な線画)とスタイル画像(特定の漫画家の線画)のペアを使用します。
  5. 学習プロセス:

    • GeneratorとDiscriminatorを交互に学習させ、バランスを取りながら性能を向上させます。
  6. 推論:

    • 学習済みのGeneratorを使用して、新しいコンテンツ画像に対して特定のスタイルを適用します。

今後の展望に向けての考慮

このアプローチの利点と注意点を整理

利点:

  1. 柔軟性:異なる漫画家のスタイルを学習し、様々なスタイルで線画を生成できます。
  2. 細部の保持:U-Net構造により、元の画像の細かい特徴を保持しつつスタイルを転送できます。
  3. 一貫性:一度学習すれば、同じスタイルを複数の画像に適用できます。

注意点:

  1. データ収集:特定の漫画家のスタイルを十分に学習するには、大量の高品質なサンプルが必要です。
  2. 著作権:他の漫画家のスタイルを模倣する際は、法的・倫理的な配慮が必要です。
  3. 計算リソース:GANの学習には高性能なGPUと長い学習時間が必要です。
  4. ハイパーパラメータ調整:最適な結果を得るには、慎重なパラメータ調整が必要です。

改善の余地(本番プロジェクトに向けた作業支援のエリア):

  1. スタイル表現の強化:スタイルを単純な平均ベクトルではなく、より複雑な特徴として捉える方法を検討できます。
  2. マルチスケール処理:異なる解像度でのスタイル転送を組み合わせることで、より細かい制御が可能になります。
  3. 注意機構:Attentionメカニズムを導入することで、スタイルの適用をより細かく制御できる可能性があります。

参考サイト

基本のプロセスの理解
https://comic.smiles55.jp/guide/20374/

【無料】AI漫画が作成できるサイト・アプリおすすめ5選【2024年最新版】
https://www.perfectcorp.com/ja/consumer/blog/generative-AI/app-to-create-AI-manga

生成AIで漫画ができる!? おすすめのツールと具体的なやり方を解説
https://team-henshin.com/media/ai/ai-manga/

AI漫画制作特化のキャンパスツール「Anifusion」使い方/料金解説!日本語対応!
https://ai-henoheno-mohero.com/anifusion-start/

Clip Studio EX
https://note.com/sirasira/n/ne5219319a645

動画編集・イラスト未経験の私が、生成AIツールを駆使したら簡単にAI漫画を作れてしまった件
https://weel.co.jp/media/innovator/ai-manga/

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?