3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

AMD GPU(ROCm)でDeepSeek-OCRを動かしてみた【Windows 11環境】

Last updated at Posted at 2025-10-22

AMD GPU(ROCm)でDeepSeek-OCRを動かしてみた【Windows 11環境】

はじめに

2025年10月20日にDeepSeekがリリースした最新のOCRモデル「DeepSeek-OCR」を、AMD Radeon GPU環境で動作させることに成功したので、その手順と結果を共有します。

公式ドキュメントではNVIDIA GPU + CUDAが前提となっており、AMD GPU(特にWindows + ROCm環境)での動作報告はほとんど見かけませんでしたが、いくつかの工夫で動作させることができました。

動作環境

OS: Windows 11 Pro (Build 26200)
GPU: AMD Radeon 8060S Graphics
CPU: AMD Ryzen AI MAX+ 395
Python: 3.12.10
PyTorch: 2.8.0a0+gitfc14c65 (ROCm 6.4.50101)
transformers: 4.46.3

遭遇した問題

問題1: Flash Attention 2への依存

DeepSeek-OCRは公式でFlash Attention 2を要求していますが、これはNVIDIA GPU専用です。

解決策: _attn_implementation='eager'に変更して回避。

問題2: 分散学習モジュールのエラー

ModuleNotFoundError: No module named 'torch._C._distributed_c10d'

PyTorch ROCm開発版に分散学習用のC++拡張が不完全なため発生。

解決策: transformersの分散学習チェックをモンキーパッチで無効化。

動作確認コード

以下のコードでAMD GPU環境での動作に成功しました:

# test_ocr_amd.py
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"

import torch

# ===== モンキーパッチ: 分散学習チェックを無効化 =====
def dummy_is_fsdp_managed_module(module):
    return False

def dummy_is_deepspeed_zero3_enabled():
    return False

import transformers.integrations.fsdp
import transformers.integrations
transformers.integrations.fsdp.is_fsdp_managed_module = dummy_is_fsdp_managed_module
transformers.integrations.is_deepspeed_zero3_enabled = dummy_is_deepspeed_zero3_enabled
# ===== モンキーパッチ終了 =====

from transformers import AutoModel, AutoTokenizer
import glob

print("PyTorchバージョン:", torch.__version__)
print("CUDA利用可能:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("デバイス名:", torch.cuda.get_device_name(0))

model_name = 'deepseek-ai/DeepSeek-OCR'

print("\nトークナイザーをロード中...")
tokenizer = AutoTokenizer.from_pretrained(
    model_name, 
    trust_remote_code=True
)

print("モデルをロード中...")
model = AutoModel.from_pretrained(
    model_name, 
    _attn_implementation='eager',  # Flash Attentionを無効化
    trust_remote_code=True, 
    use_safetensors=True,
    torch_dtype=torch.bfloat16,
    device_map='auto'
)
print("✓ モデルロード成功\n")

model = model.eval()

# OCR実行
image_file = 'test.png'
prompt = "<image>\n<|grounding|>Convert the document to markdown."

print(f"OCR実行中: {image_file}")
res = model.infer(
    tokenizer, 
    prompt=prompt, 
    image_file=image_file, 
    output_path="./output", 
    base_size=640,        # Smallモード
    image_size=640, 
    crop_mode=False,
    save_results=True, 
    test_compress=False
)

# 結果を表示
print("\n" + "="*60)
print("OCR完了!")
print("="*60)

output_files = glob.glob("./output/*.txt")
if output_files:
    print(f"\n保存されたファイル: {len(output_files)}")
    with open(output_files[0], 'r', encoding='utf-8') as f:
        content = f.read()
        print(f"\n【結果プレビュー(最初の500文字)】")
        print(content[:500])
        print(f"\n... (合計 {len(content)} 文字)")

print("\n✓ 成功!")

実行結果

テスト画像

医療ガイドライン文書(DVT/肺塞栓に関する日本語2カラムレイアウト)を使用しました。

OCR結果の評価

✅ 成功した点:

  • 日本語認識: 医療専門用語を含む複雑な日本語を高精度で認識
  • レイアウト認識: 2カラムレイアウトを正確に解析
  • 表の構造化: Wellsスコアの表をHTMLテーブル形式で出力
  • 座標情報: 各テキストブロックの位置座標も取得

出力例:

<|ref|>sub_title<|/ref|><|det|>[[91, 231, 201, 250]]<|/det|>
## 2.1.2 DVTの検査前臨床的標準

<|ref|>text<|/ref|><|det|>[[91, 256, 490, 487]]<|/det|>
問診、診察で得られる個々の所見では、DVTを診断することは困難である。
このため、いくつかの所見の組み合わせでDVTの確率を推定するスコアが考案され...

<|ref|>table<|/ref|><|det|>[[230, 530, 763, 775]]<|/det|>
<table>
<tr><td>活動性がん(6ヵ月以内治療や緩和的治療を含む)</td><td>1</td></tr>
<tr><td>下肢の完全麻痺、不全麻痺あるいは最近のギプス装置による固定</td><td>1</td></tr>
...
</table>

⚠️ 認識精度:

  • 一部の文字に誤認識あり(約95%程度の精度)
  • 複雑な表の罫線は完全には再現されない
  • 全体としては実用レベル一歩手前の品質

処理時間

  • モデルロード: 約11分(初回ダウンロード含む、6.67GB)
  • OCR処理: A41枚テキストが数分程度

まとめ

できたこと

  • ✅ AMD Radeon GPU(ROCm)でDeepSeek-OCRが動作
  • ✅ Windows 11環境での動作確認
  • ✅ 日本語文書の高精度OCR
  • ✅ 複雑なレイアウト(2カラム、表)の認識

注意点

  • ⚠️ 公式環境(NVIDIA GPU + Linux)と比べて若干のハックが必要
  • ⚠️ Flash Attention無効化により、処理速度は公称値より遅い可能性
  • ⚠️ PyTorch ROCm開発版特有の問題に対処が必要

今後の展望

  • より多様な文書でのテスト
  • 処理速度の最適化

参考リンク


2025年10月22日 追記: 本記事の内容はAMD Radeon 8060S + Windows 11環境での動作報告です。他の環境では異なる問題が発生する可能性があります。

ライセンス: 本記事のコードはMITライセンスで公開します。ご自由にお使いください。


タグ: #DeepSeek #OCR #AMD #ROCm #機械学習 #深層学習 #AI #Python #PyTorch

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?