0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Google Colab で XGenを試す

Posted at

「Google Colab」で「XGen 」を試したので、まとめました。

XGen

XGenは「Salesforce AI Research」が開発した、長いシーケンスのオープンな大規模言語モデル。

特徴:
入力シーケンス長は 8K ですが、ほとんどのオープンソース言語モデルの最大シーケンス長は 2K トークンで、テキストの要約やコードの作成に役立ちます。

XGenモデルには 1.5T トークンがあります。 研究者らは、より多くのデータでトレーニングされた小規模なモデルは、パフォーマンスと推論効率の点で大規模なモデルよりも優れていることが多いと指摘しています。

XGen は、標準の NLP ベンチマークで最先端のオープンソース LLM と同等以上の結果を達成します。

モデル一覧

(2023年7月10日現在)

ベースモデル

  • XGen-7B-4K-Base: 4K シーケンス長で事前トレーニングされた XGen-7B モデル。ライセンス: Apache-2.0
  • XGen-7B-8K-Base: 8K シーケンス長で事前トレーニングされた XGen-7B モデル。ライセンス: Apache-2.0

Instruceのモデル

  • XGen-7B-8K-Inst: パブリックドメインの指導データに基づく教師付き微調整モデル。ライセンス: 研究用

引用

Hugging Face モデル:https://huggingface.co/Salesforce
Github: https://github.com/salesforce/xgen
Doc:https://blog.salesforceairesearch.com/xgen/

Colabでの実行

環境:

Google Colab: GPU | T4 | Hight Memory (2.05 / hour)
Usage: System RAM 17.5/25 | GPU RAM 13.8/15 | Disk 50GB

パッケージのインストール。

python
!apt-get install tree
!pip install -q  transformers[sentencepiece] tiktoken

モデル
xgen-7b-8k-baseを読み込みます。

python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

checkpoint = "Salesforce/xgen-7b-8k-base"
pretrain_cache_dir = "/content/model/v00"

model = AutoModelForCausalLM.from_pretrained(checkpoint,
                                             torch_dtype=torch.bfloat16,
                                             cache_dir=pretrain_cache_dir,
                                             ).to(device)

モデルの確認

python
print(model)
結果
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(51200, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=51200, bias=False)
)

モデルファイルの確認
20GMくらいのファイルです。

python
!tree -h "/content/model/v00"

image.png

トークナイザー
トークナイザーを作成します。

python
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint,
    trust_remote_code=True
    )

推論:日本語

python
# Create sentence 
# Japanese
prompt = "私は"

inputs = tokenizer(prompt, return_tensors="pt").to(device)
sample = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(sample[0]))

私は恋のうた
『私は恋のうた』(わたしはこいのうた)は、森高千里の4枚目のオリジナルアルバム。

概要
前作『森高千里』から約1年ぶりのアルバム。

前作から約1年ぶりのアルバムとなるが、�

python
# Q&A 
# Japanese
prompt = "質問: 日本の首都はどこですか? \n答え:"

inputs = tokenizer(prompt, return_tensors="pt").to(device)
sample = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(sample[0]))

質問: 日本の首都はどこですか?
答え: 東京都

質問: 日本の最高裁判所はどこですか?
答え: 東京地方裁判所

質問: 日本の最高裁判所はどこですか?
答え: 東京地方裁判所

推論:英語

python
# Create sentence 
# English
prompt = "I am"

inputs = tokenizer(prompt, return_tensors="pt").to(device)
sample = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(sample[0]))

I am a big fan of the work of the late, great, and much-missed, Dr. John C. Haldane. I have been reading his work for years, and I have been a subscriber to his quarterly journal, The Quarterly Review, for years. I have been a subscriber to the Quarterly Review for years. I have been a subscriber to the Quarterly Review for years. I have been a subscriber to the Quarterly Review for years. I have been a subscriber to the Quarterly Review for years. I have been a subscriber to the Quarterly Review for years. I have been a subscriber to the Quarterly Review for years. I have

python
# Q&A 
# English
prompt = "Question: what is the capital of England? \nAnswer:"

inputs = tokenizer(prompt, return_tensors="pt").to(device)
sample = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(sample[0]))

Question: what is the capital of England?
Answer: London.

Question: what is the capital of France?
Answer: Paris.

Question: what is the capital of Germany?
Answer: Berlin.

Question: what is the capital of Italy?
Answer: Rome.

Question: what is the capital of Spain?
Answer: Madrid.

Question: what is the capital of the United States?
Answer: Washington, D.C.

Question: what is the capital of Canada?
Answer: Ottawa.

Question: what is the capital of Australia?
Answer: Canberra.

python
# Calculation
# English
prompt = f"What is the result of 1+1?"

inputs = tokenizer(prompt, return_tensors="pt").to(device)
sample = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(sample[0]))

What is the result of 1+1?
1+1=2
What is the result of 1+1+1?
1+1+1=3
What is the result of 1+1+1+1?
1+1+1+1=4
What is the result of 1+1+1+1+1?
1+1+1+1+1=5
What is the result of 1+1+1+1+1+1?
1+1+1+1+1+1=6
What is the result of 1+1+1+1+1+1+

python
# Code
# English
prompt = f"Give me python sample code of print function."

inputs = tokenizer(prompt, return_tensors="pt").to(device)
sample = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(sample[0]))

Give me python sample code of print function.

A: You can use print function to print the value of a variable.
print("Hello World")

You can also use print function to print the value of a variable in a specific format.
print("Hello World", "in", "Python")

You can also use print function to print the value of a variable in a specific format.
print("Hello World", "in", "Python", "in", "Python")

You can also use print function to print the value of a variable in a specific format.
print("Hello World", "in", "Python", "in",

まとめ

XGenについての説明と実験を行いました。ベースモデルでも十分な結果が得られました。1.5Tトークンによって、モデルは優れた結果を提供するはずです。さらに、最大8,000トークンの入力が可能であり、非常に長いコンテンツに非常に役立ちます。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?