LoginSignup
1
0

AI SDKとDALL·EおよびStable Diffusion 2を使った画像生成

Last updated at Posted at 2024-03-26

Vercel AI SDKを用いて、DALL·E 3DALL·E 2、およびStable Diffusion 2による画像生成を行う方法について詳しく説明します。動作するデモアプリ(Mulai3)はこちらです。「○○についての画像を描いて」のように尋ねてみてください。

もう一つのデモアプリ(Mulai)では、明示的にDALL·E 3またはDALL·E 2Stable Diffusion 2のモデルを選んで、作ってもらいたい画像をそのまま入力してください。

これらのアプリケーションや関数呼び出しの概要についてはこちらの記事もご参照ください。

関数呼び出しを用いた画像生成 (Mulai3)

参照記事でもあったように、AI SDKでは任意の関数呼び出しをオブジェクトへの引数として定義でき、サンプルアプリケーションでは下記のようなものを宣言しています。descriptionpromptの条件に合致したとAIモデル(GPT-3.5/GPT-4など)が判断した時に、renderに定義されている関数が呼び出されます。

ai-action.ts
        generate_images: {
          description: 'Generate images based on the given prompt',
          parameters: z.object({
            prompt: z.string().describe('the image description to be generated'),
          }),
          render: async function* ({prompt}:{prompt:string}) {
            try {
              const models:ChatModel[] = [
                'dall-e-2', 'dall-e-3', 'stable-diffusion-2',
              ].map((value) => getModelByValue(value) as ChatModel)

              yield (
                <Card className="m-1 p-3">
                  <CardContent className="flex flex-row flex-wrap gap-3 justify-center">
                    {models.map((model) => {
                      const generatingTitle = `${model.label} generating an image: ${prompt}`;
                      return (<div key={model.modelValue} title={generatingTitle} className='size-64 border animate-pulse grid place-content-center place-items-center gap-3'>
                        <div className="rounded-3xl bg-slate-200 size-24 mx-auto"></div>
                        <div className="rounded w-32 h-4 bg-slate-200 text-center font-bold">{model.label}</div>
                        <div className="rounded w-32 max-h-16 p-1 bg-slate-200 overflow-hidden">{/* FIXME i18n */}Generating: {prompt}</div>
                      </div>)
                    })}
                  </CardContent>
                </Card>
              )
              
              const results = await Promise.all(
                models.map((model) => generateImages(prompt, model)))
              const images = results.flat()

              aiState.done({
                ...aiState.get(),
                messages: [
                  ...aiState.get().messages,
                  {
                    role: "function",
                    name: "generate_images",
                    content: `image prompt: ${prompt}`,
                  },
                ]
              });

              return (
                <Card className="m-1 p-3">
                  <CardContent className="flex flex-row flex-wrap  gap-3 justify-center">
                    {images.map((image) => {
                      const title = getModelByValue(image.model)!.label + ': ' + (image.revised_prompt ?? prompt)
                      return (<Image key={image.url} src={image.url!} title={title} alt={title} width={256} height={256} className='size-64 border' />)
                    })}
                  </CardContent>
                </Card>
              )
            } catch (e:any) {
              console.log('got error', e, prompt);
              aiState.done({
                ...aiState.get(),
                messages: [
                  ...aiState.get().messages,
                  {
                    role: "function",
                    name: "generate_images",
                    content: e.toString(),
                  },
                ]
              });
              return <span>{e.toString()}</span>                
            }
          }
        } as any,

このアプリケーションで実施していることをもう少し詳しく解説します。renderには唯一の引数promptのみが渡されます。このpromptは、ユーザーが入力した文字列がそのまま渡されるとは限らず、AIモデルによって、入力した内容が加工されて引き渡されることがあります。

このアプリケーションでは、dall-e-2dall-e-3およびstable-diffusion-2の3つのモデルを用いて画像生成を試みます。いずれも生成には時間が掛かるため、yieldにより、生成中に表示すべきコンテンツを定義しています。基本的には生成後の表示と似たような枠を書き、画像の代わりに、Tailwind CSSのroundedクラスを使って、仮の表示を行っています。

続いて、modelNamesそれぞれに対してgenerateImages関数を呼び出して画像生成を行っています。generateImages関数の説明は後から行うとして、Promise.allを呼び出し、両方の画像が生成し終わったタイミングで、生成完了となるようにしています。生成された画像データはimages配列に入れています。

AI SDKではaiStateにてメッセージ履歴を管理しますが文字コンテンツにしか対応していないため、プロンプト文字列を格納しておきます。

最後に出力部分です。imageオブジェクトからmodelurlおよびrevised_promptを使用しています。revised_promptdall-e-3が引き渡されたpromptとやや異なるプロンプトを利用して画像生成する時に用いられます。dall-e-2にはこの仕組みはないようです。

それでは続いて、実際にdall-estable diffusionのAPIを呼び出して画像生成している部分の関数を説明します。こちらはOpenAIとStable DiffusionがホストされているHuggingFaceとで、処理を呼び分けています。

ai-action.ts
import { OpenAI } from "openai"; 
import { HfInference } from '@huggingface/inference';

const openai = new OpenAI({
  apiKey: process.env.OPENAI_API_KEY,
});
const Hf = new HfInference(process.env.HUGGINGFACE_API_KEY);

type GenerateImageResponse = {
  url: string,
  model: string,
  revised_prompt?: string,
}

async function generateImages(prompt:string, model:ChatModel):Promise<GenerateImageResponse[]> {
  if (model.provider === 'openai-image')
    return generateOpenaiImages(prompt, model)
  else if (model.provider === 'huggingface-image')
    return generateHuggingFaceImage(prompt, model)
  else
    throw new Error(`unexpected provider: ${model.provider}`)
}

async function generateOpenaiImages(prompt:string, model:ChatModel):Promise<GenerateImageResponse[]> {
  const modelValue = model.modelValue
  const baseParams:ImageGenerateParams = { prompt: prompt, response_format: 'url' }
  const e2Params:ImageGenerateParams = { ...baseParams, model: 'dall-e-2', size: '256x256' }
  const e3Params:ImageGenerateParams = { ...baseParams, model: "dall-e-3", size: '1024x1024' }
  const params = modelValue == 'dall-e-3' ? e3Params : e2Params

  const responseImage = await openai.images.generate(params);
  const data = responseImage.data.map((image) => ({url: image.url as string, model: modelValue, revised_prompt: image.revised_prompt}))
  return data
}

async function generateHuggingFaceImage(prompt:string, model:ChatModel):Promise<GenerateImageResponse[]> {
  const blob:Blob = await Hf.textToImage({
      inputs: prompt
  })

  const arrayBuffer = await blob.arrayBuffer();
  const base64 = abToBase64(arrayBuffer)
  const url = `data:${blob.type};base64,` + base64

  return [{url, model: model.modelValue}]
}

function abToBase64(arrayBuffer:ArrayBuffer) {
  const base64String = Buffer.from(arrayBuffer).toString('base64');
  return base64String;
}

generateImagesでは、まず引数に渡されたモデルのプロバイダーに応じてgenerateOpenaiImagesgenerateHuggingFaceImageのいずれかを呼び出します。

OpenAIのdall-eのモデルでは、基本的にopenai.images.generateを呼び出せば良く、引数としてpromptresponse_formatmodelsizeを渡せばOKです。モデルによってサポートする画像サイズが異なるので注意してください。表示時の参考用に、戻り値にmodelを含めています。

その他引数はOpenAIのドキュメントにて説明されています。必要に応じてこれらの引数を受け付ける関数も作成可能でしょう。

Hugging Face経由で呼び出すstable diffusionモデルでは、Hf.textToimageという関数を使います。引数やモデル値は少々違うため、調整しています。Hf.textToImageでは生成したURLを戻すことはできず、常にBlobデータが戻ってくるため、これをBase64に変換してdata URI形式にしたあと、Markdownに変換します。

APIの直接呼び出しによる画像生成 (Mulai)

Mulaiでは、チャットAPIを呼び出すのと類似したインターフェースにて画像生成を行っています。messagesにはこれまでのやり取りの配列が渡されているので、最後のメッセージを取得してpromptとします。

api/chat/route.ts
const openaiImageStream:ChatStreamFunction = async({model, messages}) => {
    const prompt = messages[messages.length - 1].content

    const params:ImageGenerateParams = {
        prompt,
        model: model.sdkModelValue,
        n: 1,
        response_format: 'url',
    }
    const response = await openai.images.generate(params)

    const responseMarkdown = response.data.map((datum) => 
        datum.url ? imageMarkdown(datum.url as string, prompt) : ''
    ).join('\n')

    const stream = stringToReadableStream(responseMarkdown)
    return stream    
}

こちらもopenai.images.generateを呼び出して画像生成しています。nは一度に生成する画像の数ですが、無料枠では1分間に5つまでしか画像生成できないことに注意してください。

チャットと同じインターフェースであり、このアプリケーションでは<img>タグには対応していないため、戻り値をMarkdown形式に加工しています。

function imageMarkdown(url:string, prompt:string = 'Image') {
    // [] => (), " => '
    const escapedPrompt = prompt.replaceAll(/\[/g, "(").replaceAll(/\]/g, ")").replaceAll(/"/g, "'")
    const responseMarkdown = `![${escapedPrompt}](${url} "${escapedPrompt}")`
    return responseMarkdown
}

呼び出しは下記のように行っています。chatStreamFactoryの戻り値はopenaiImageStreamです。

const responseStreamGenerator = chatStreamFactory(modelData)
const stream = await responseStreamGenerator({model:modelData, messages: m})
return new StreamingTextResponse(stream)

もう一つ、HuggingFace上のStable Diffusion 2による画像生成にも対応しています。dall-eのケースと同様に引数からpromptを取得し、Hf.textToImageを呼び出しています。関数インターフェースはmodelを受け取るようになっていますが、Hf.textToImageは常にStable Diffusion 2を用いて画像生成を行います。

route.ts
import { HfInference } from '@huggingface/inference';

const Hf = new HfInference(process.env.HUGGINGFACE_API_KEY);


const huggingFaceImageStream:ChatStreamFunction = async ({model, messages}) => {
    const prompt = messages[messages.length - 1].content

    const blob:Blob = await Hf.textToImage({
        inputs: prompt
    })

    const arrayBuffer = await blob.arrayBuffer();
    const base64 = abToBase64(arrayBuffer)
    const url = `data:${blob.type};base64,` + base64

    const responseMarkdown = imageMarkdown(url, prompt)

    return stringToReadableStream(responseMarkdown)
}

Hf.textToImageの戻り値をdata URI形式にしたあと、Markdownに変換します。imageMarkdownは上記と同じ処理です。

以上により、Stable Diffusionを使った画像生成も可能になります。関数呼び出し形式でも、同様の関数を使用することにより、Stable Diffusionの画像も比較的容易に生成できるようになると思います。

今回デモに使ったMulai3/Mulaiの最新のソースコードは下記よりご参照ください。

今回解説に使ったサンプルアプリケーションの概要についてはこちらの記事もご参照ください。


今回の記事はいかがだったでしょうか。こちらのサンプルアプリケーションに関する記事は、タグMulaiをご利用ください。

1
0
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0