LoginSignup
7
4

More than 1 year has passed since last update.

【Swift】iOSでStableDiffusionを使ってみた

Last updated at Posted at 2022-12-08

はじめに

AppleからStableDiffusionのCoreML変換バージョンが公開されました。
StableDiffusionに興味があったので使ってみました。

しかし、READMEに記載されている方法だと貧弱メモリのMacBookではできなかったので、貧弱MacBookでもできる方法を紹介します。

サンプルアプリ

今回のサンプルアプリはこちらです。

以下の設定ができます。

  • Prompt
  • Image Count
  • Step Count
  • Seed
    スクリーンショット 2022-12-08 22.03.05.png

ちなみに「StepCount」を50にすると生成までに1時間くらいかかります笑

この例では「StepCount」は5で実行しています。
完成まで10分かかりました。

0 1 2 3 4 5
simulator_screenshot_2E0395A8-23DA-437E-9539-7A98A83F372F.png simulator_screenshot_4FE43A8A-8B2F-43DC-B9DE-6F3B7FEBF611.png simulator_screenshot_78CF14C5-7A9B-4E50-8140-ECE54F21107C.png simulator_screenshot_E8D74D40-3147-4580-A586-FCD09CA6B287.png simulator_screenshot_7277CE6C-3091-42FB-812C-F73126A148CA.png simulator_screenshot_BCE30048-FFDC-43AE-BFD6-E1CB7B80354D.png

完成した画像!!
スクリーンショット 2022-12-08 21.56.48.png

準備

アクセストークンを生成

Hugging Faceのアカウントを作成します
スクリーンショット 2022-12-06 18.37.10.png
メールアドレスで認証を行います。
スクリーンショット 2022-12-06 18.39.43.png

こちらからユーザーアクセストークンを発行します。
スクリーンショット 2022-12-06 18.41.31.png

「Name」と「Role」を指定してトークンを生成します。
スクリーンショット 2022-12-06 18.46.23.png

Hugging Face APIのセットアップ

先ほどの仮想環境内かつ、ml-stable-diffusionのディレクトリ内で以下のコマンドを実行します。

ターミナル
huggingface-cli login

スクリーンショット 2022-12-06 18.49.19.png
上記の画像のようにトークンの入力が求められるので先ほど生成したトークンを入力します。

セットアップが成功すると以下のようになります。
スクリーンショット 2022-12-06 18.50.33.png

コンパイル済みのモデルをダウンロード

リポジトリをクローンします。

ターミナル
cd デスクトップのパス
git clone https://github.com/apple/ml-stable-diffusion.git
cd デスクトップのパス/ml-stable-diffusion

自分のMacBookではコンパイルできないため、appleにコンパイルされたモデルをダウンロードします。
今の所、この3つがあります。

  • apple/coreml-stable-diffusion-v1-4
  • apple/coreml-stable-diffusion-v1-5
  • apple/coreml-stable-diffusion-2-base
from huggingface_hub import snapshot_download
from huggingface_hub.file_download import repo_folder_name
from pathlib import Path
import shutil

# ダウンロードしたいモデルバージョンを指定
repo_id = "apple/coreml-stable-diffusion-v1-4"
#repo_id = "apple/coreml-stable-diffusion-v1-5"
#repo_id = "apple/coreml-stable-diffusion-2-base"
variant = "original/compiled"

def download_model(repo_id, variant, output_dir):
    destination = Path(output_dir) / (repo_id.split("/")[-1] + "_" + variant.replace("/", "_"))
    if destination.exists():
        raise Exception(f"Model already exists at {destination}")

    downloaded = snapshot_download(repo_id, allow_patterns=f"{variant}/*", cache_dir=output_dir)
    downloaded_bundle = Path(downloaded) / variant
    shutil.copytree(downloaded_bundle, destination)

    cache_folder = Path(output_dir) / repo_folder_name(repo_id=repo_id, repo_type="model")
    shutil.rmtree(cache_folder)
    return destination

model_path = download_model(repo_id, variant, output_dir="./models")
print(f"Model downloaded at {model_path}")

先ほど作成したファイルを実行します。

ターミナル
python3 作成したPythonファイルのパス

ダウンロードが成功したらcoreml-stable-diffusion-v1-4_original_compiledという名前のフォルダが作成されます。
スクリーンショット 2022-12-08 12.34.51.png

実装

ライブラリをインポート

SPMでライブラリをインポートします。
Branchのmainを設定して「Add Package」を選択します。
スクリーンショット 2022-12-08 12.42.19.png

StableDiffusionを選択して「Add Package」を選択します。
スクリーンショット 2022-12-08 12.43.18.png

モデルのインポート

ダウンロードしたモデルをcoreml-stable-diffusion-v1-4_original_compiledごと、プロジェクトにドラッグします。
以下の設定で「Finish」を選択します。
スクリーンショット 2022-12-08 12.38.47.png
このような形になります。
スクリーンショット 2022-12-08 12.40.05.png

コード

View

ContentView
import SwiftUI

struct ContentView: View {
    @StateObject private var viewModel = ViewModel()
    var body: some View {
        ScrollView(.vertical, showsIndicators: false) {
            Text(viewModel.status.message)
                .foregroundColor(.primary)
                .font(.system(size: 25, weight: .bold))
            if viewModel.status == .generateStart {
                ProgressView(value: .init(Double(viewModel.step)), total: .init(Double(viewModel.stepCount))) {
                    Text("生成中")
                } currentValueLabel: {
                    Text("(\(viewModel.step)/\(Int(viewModel.stepCount)))")
                }
                .progressViewStyle(CircularProgressViewStyle())
            }
            if let image = viewModel.image {
                slideImageView(image: image)
            } else {
                Text("画像を生成してください")
                    .frame(maxWidth: .infinity, minHeight: 300)
                    .background(Color.gray.opacity(0.3))
                    .cornerRadius(10)
            }
            GroupBox("Prompt") {
                TextField("プロンプトを入力してください", text: $viewModel.prompt, axis: .vertical)
                    .textFieldStyle(.roundedBorder)
                    .disabled(viewModel.status == .generateStart)
            }
            GroupBox("Image Count") {
                Stepper(value: $viewModel.imageCount, in: 1...4) {
                    Text("\(viewModel.imageCount)")
                }
                .disabled(viewModel.status == .generateStart)
            }
            GroupBox("Step Count") {
                Text(Int(viewModel.stepCount).description)
                    .frame(maxWidth: .infinity, alignment: .leading)
                Slider(value: $viewModel.stepCount, in: 1...50, step: 1)
                    .disabled(viewModel.status == .generateStart)
            }
            GroupBox("Seed") {
                Text(Int(viewModel.seed).description)
                    .frame(maxWidth: .infinity, alignment: .leading)
                Slider(value: $viewModel.seed, in: 0...1000, step: 1)
                    .disabled(viewModel.status == .generateStart)
            }
            Button {
                Task.init {
                    await viewModel.generateImage()
                }
            } label: {
                Text("生成")
                    .foregroundColor(.white)
                    .font(.system(size: 25, weight: .bold))
                    .frame(maxWidth: .infinity, minHeight: 70)
            }
            .frame(maxWidth: .infinity, minHeight: 70)
            .background(viewModel.status != .loadFinish || viewModel.status == .generateStart || viewModel.prompt == "" ? Color.secondary : Color.blue)
            .cornerRadius(10)
            .disabled(viewModel.status != .loadFinish || viewModel.status == .generateStart || viewModel.prompt == "")
        }
        .padding(.all)
        .task {
            Task.init {
                await viewModel.loadModels()
            }
        }
    }

    private func slideImageView(image: [CGImage]) -> some View {
        TabView {
            ForEach(0..<image.count, id: \.self) { index in
                Image(image[index], scale: 1.0, label: Text("生成画像"))
                    .resizable()
                    .scaledToFit()
            }
        }
        .tabViewStyle(.page)
        .indexViewStyle(PageIndexViewStyle(backgroundDisplayMode: .interactive))
        .frame(maxWidth: .infinity, minHeight: 400)
        .background(Color(uiColor: .secondarySystemFill))
        .cornerRadius(10)
    }
}

ViewModel

ViewModel
import Foundation
import StableDiffusion
import CoreGraphics

final class ViewModel: ObservableObject {
    @Published var pipeline: StableDiffusionPipeline?
    @Published var image: [CGImage]?
    @Published var prompt: String = "cat"
    @Published var imageCount: Int = 1
    @Published var stepCount: Double = 30
    @Published var seed: Double = 500
    @Published var step: Int = 0
    @Published var status: StableDiffusionStatus = .ready

    func loadModels() async {
        guard let resourceURL = Bundle.main.resourceURL else { return }
        do {
            Task.detached { @MainActor in
                self.status = .loadStart
            }
            let pipeline = try StableDiffusionPipeline(resourcesAt: resourceURL)
            Task.detached { @MainActor in
                self.pipeline = pipeline
                self.status = .loadFinish
            }
        } catch {
            Task.detached { @MainActor in
                self.status = .error
            }
        }
    }

    func generateImage() async {
        do {
            Task.detached { @MainActor in
                self.image = nil
                self.status = .generateStart
            }
            let image = try self.pipeline?.generateImages(
                prompt: self.prompt,
                imageCount: self.imageCount,
                stepCount: Int(self.stepCount),
                seed: Int(self.seed),
                disableSafety: true
            ) { data in
                Task.detached { @MainActor in
                    self.image = data.currentImages.map({ cgImage in
                        guard let cgImage else { fatalError() }
                        return cgImage
                    })
                    self.step = data.step
                }
                return true
            }.map({ cgImage in
                guard let cgImage else { fatalError() }
                return cgImage
            })
            Task.detached { @MainActor in
                self.image = image
                self.status = .generateFinish
            }
        } catch {
            Task.detached { @MainActor in
                self.status = .error
            }
        }
    }
}

Model

StableDiffusionStatus
import Foundation

enum StableDiffusionStatus {
    case ready
    case loadStart
    case loadFinish
    case generateStart
    case generateFinish
    case error

    var message: String {
        switch self {
        case .ready:
            return "準備中"
        case .loadStart:
            return "モデルの読み込みを開始しました"
        case .loadFinish:
            return "モデルの読み込みが終了しました"
        case .generateStart:
            return "画像生成を開始しました"
        case .generateFinish:
            return "画像生成が終了しました"
        case .error:
            return "エラーが発生しました"
        }
    }
}

おわり

一応ちゃんと動いてますが、非同期処理のところがおかしいので、誰か教えてください笑

もし直していただける優しい方がいましたらPRください🙏

7
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
4