6
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

HuggingFaceバージョンのDollyを試してみる

Last updated at Posted at 2023-03-31

HuggingFaceでDollyの事前学習済みモデルが公開されました。モデルのファインチューニングなしにすぐに利用できます。モデルのレスポンスを試すことができます。

なお、GitHubの方もレスポンスを返すようになっています。

Databricksノートブックで試してみます。こちらの手順に従います。

途中エラーが出る場合には以下を実行してaccelerateをインストールします。

%pip install accelerate
Python
import numpy as np
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer
)

tokenizer = AutoTokenizer.from_pretrained("databricks/dolly-v1-6b", padding_side="left")
model = AutoModelForCausalLM.from_pretrained("databricks/dolly-v1-6b", device_map="auto", trust_remote_code=True)
Python
PROMPT_FORMAT = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
"""

def generate_response(instruction: str, *, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, 
                      do_sample: bool = True, max_new_tokens: int = 256, top_p: float = 0.92, top_k: int = 0, **kwargs) -> str:
    input_ids = tokenizer(PROMPT_FORMAT.format(instruction=instruction), return_tensors="pt").input_ids.to("cuda")

    # each of these is encoded to a single token
    response_key_token_id = tokenizer.encode("### Response:")[0]
    end_key_token_id = tokenizer.encode("### End")[0]

    gen_tokens = model.generate(input_ids, pad_token_id=tokenizer.pad_token_id, eos_token_id=end_key_token_id,
                                do_sample=do_sample, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k, **kwargs)[0].cpu()

    # find where the response begins
    response_positions = np.where(gen_tokens == response_key_token_id)[0]

    if len(response_positions) >= 0:
        response_pos = response_positions[0]
        
        # find where the response ends
        end_pos = None
        end_positions = np.where(gen_tokens == end_key_token_id)[0]
        if len(end_positions) > 0:
            end_pos = end_positions[0]

        return tokenizer.decode(gen_tokens[response_pos + 1 : end_pos]).strip()

    return None

# Sample similar to: "Excited to announce the release of Dolly, a powerful new language model from Databricks! #AI #Databricks"
generate_response("Write a tweet announcing Dolly, a large language model from Databricks.", model=model, tokenizer=tokenizer)

以下のようなレスポンスが返ってきます。

Out[2]: 'Watch Dolly make your language learning more fun! Join over 100 million users today with the largest language model. 🔥 https://t.co/U6zlNfM6nV #Databricks #DollyLanguageModel'

日本語でも返ってきます。中身は若干微妙ですが。でも、すごい。

Python
generate_response("Databricksとは何ですか?", model=model, tokenizer=tokenizer)
Out[3]: 'Databricks とは、人工知能 (AI)、機械学習、リアルタイム、ビッグデータ、そして Hadoop の全てを集約するパブリックプラットフォームを提供します。'
Python
generate_response("人工知能とは何ですか?", model=model, tokenizer=tokenizer)
Out[7]: '人工知能とは、人工知能を活用したコンピューティングの最新技術です。人間の脳のようなものに抽象的な学習アルゴリズムや機械学習を使い、自然言語を理解し、知性を手に入れるということです。'

徐々にできることが増えていっています。

Databricksクイックスタートガイド

Databricksクイックスタートガイド

Databricks無料トライアル

Databricks無料トライアル

6
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?