2
3

テキストから3Dのモデル生成を楽しめる Ooen AI の Point- E AIモデル。

Last updated at Posted at 2024-07-24

Gradioのインターフェースでプロンプトを選択できるようにしました。(英文テキストも入力可能。) 選択されたプロンプトに基づいてポイントクラウドが生成され、その結果がプロットとして表示されます。

このコードは、以下の手順を実行します:

Point-Eリポジトリをクローンし、必要なパッケージをインストールします。
モジュールをインポートします。
デバイス(CUDAまたはCPU)を設定します。
ベースモデルとアップサンプルモデルを作成し、チェックポイントをダウンロードします。
サンプラーを作成し、指定されたプロンプトに基づいてポイントクラウドを生成します。
生成されたポイントクラウドをプロットし、表示します。
このコードを Google corab GPU で実行すると、指定されたテキストプロンプトに基づいて3Dポイントクラウドが生成され、表示されます。

image.png

image.png

image.png

スクリーンショット 2024-07-25 062310.png

!git clone https://github.com/openai/point-e.git
%cd point-e
!pip install .
!pip install gradio

import torch
from tqdm.auto import tqdm
import gradio as gr

from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
from point_e.diffusion.sampler import PointCloudSampler
from point_e.models.download import load_checkpoint
from point_e.models.configs import MODEL_CONFIGS, model_from_config
from point_e.util.plotting import plot_point_cloud

# デバイスを設定
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ベースモデルを作成
print('ベースモデルを作成中...')
base_name = 'base40M-textvec'
base_model = model_from_config(MODEL_CONFIGS[base_name], device)
base_model.eval()
base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])

# アップサンプルモデルを作成
print('アップサンプルモデルを作成中...')
upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)
upsampler_model.eval()
upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])

# ベースモデルのチェックポイントをダウンロード
print('ベースモデルのチェックポイントをダウンロード中...')
base_model.load_state_dict(load_checkpoint(base_name, device))

# アップサンプルモデルのチェックポイントをダウンロード
print('アップサンプルモデルのチェックポイントをダウンロード中...')
upsampler_model.load_state_dict(load_checkpoint('upsample', device))

# ポイントクラウドサンプラーを作成
sampler = PointCloudSampler(
    device=device,
    models=[base_model, upsampler_model],
    diffusions=[base_diffusion, upsampler_diffusion],
    num_points=[1024, 4096 - 1024],
    aux_channels=['R', 'G', 'B'],
    guidance_scale=[3.0, 0.0],
    model_kwargs_key_filter=('texts', ''), # アップサンプラーには条件付けしない
)

# サンプルプロンプト
sample_prompts = [
    'a red motorcycle',
    'a blue sports car',
    'a green spaceship',
    'a yellow submarine',
    'a pink robot',
    'a white cat',
    'a black dog',
    'a purple dragon',
    'a silver plane',
    'a golden castle'
]

# ポイントクラウドを生成する関数
def generate_point_cloud(prompt):
    # モデルからサンプルを生成
    samples = None
    for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(texts=[prompt]))):
        samples = x
    pc = sampler.output_to_point_clouds(samples)[0]
    fig = plot_point_cloud(pc, grid_size=3, fixed_bounds=((-0.75, -0.75, -0.75),(0.75, 0.75, 0.75)))
    return fig

# Gradioインターフェースを作成
interface = gr.Interface(
    fn=generate_point_cloud,
    inputs=gr.components.Dropdown(choices=sample_prompts, label="プロンプトを選択"),
    outputs=gr.components.Plot(label="生成されたポイントクラウド")
)

# インターフェースを起動
interface.launch()

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3