
The DeepSeek R1-0528 Challenge (We Couldn't Train It!)

Posted at 2025-10-06

Introduction

Recently, we took on training DeepSeek R1-0528 in the Matsuo Lab LLM competition (Phase 1).

The conclusion up front: we could not train it.

  • Model loading eventually succeeded (FP8 → BF16 → FP4)
  • When we then tried to train, our libraries could not recognize the MoE layers, so we worked around it with hard-coding
  • But we ran out of time and had to give up
  • In the end, we fell back to SFT training of Qwen3-235B-A22B

In this article, we sort out one question:

"Why couldn't we train it?"

We lay out the causes and the solutions.


Agenda

  1. What is DeepSeek R1-0528?
  2. What makes DeepSeek R1-0528 special
  3. The process we went through
  4. Why we failed
  5. Summary

What is DeepSeek R1-0528?

Overview

  • A large language model developed by DeepSeek

Model architecture

  • MoE (Mixture-of-Experts) architecture
  • Total parameters: 671B
  • In the inference configuration, roughly 332B (tensor type: F8) is in operation

Features

  • Supports long contexts of up to 128K tokens
  • Reinforcement learning with GRPO (see the advantage formula below)
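For reference, the core idea of GRPO (from the DeepSeekMath paper) is to score each sampled answer relative to its own group of samples, with no value network; the group-relative advantage for answer i with reward r_i among G samples is:

    \hat{A}_i = \frac{r_i - \operatorname{mean}(r_1, \dots, r_G)}{\operatorname{std}(r_1, \dots, r_G)}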

Derivatives

  • DeepSeek-R1-0528-Qwen3-8B

Differences from Qwen3-8B

| Item | DeepSeek R1-0528 | DeepSeek R1-0528-Qwen3-8B |
| --- | --- | --- |
| Architecture | MoE (671B) | Transformer (8B) |
| Base model | DeepSeek-V3-Base | Qwen3-8B Base (by Alibaba) |
| Training method | Centered on GRPO; trained through a complex pipeline combining RL and SFT | SFT; trained on R1's reasoning traces as teacher data |
| Fine-tuning difficulty | Extremely hard; the huge size and MoE structure make further training prone to upsetting the performance balance | Relatively easy; the standard architecture allows re-tuning with LoRA or additional SFT/GRPO |
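To give a sense of what "relatively easy" means for the 8B derivative, a minimal LoRA setup might look like the sketch below. The target modules mirror the ones in our training log further down; the rank and alpha values are illustrative assumptions, not tuned settings.

    # Minimal LoRA sketch for the 8B derivative (r/alpha are illustrative).
    import torch
    from transformers import AutoModelForCausalLM
    from peft import LoraConfig, get_peft_model

    model = AutoModelForCausalLM.from_pretrained(
        "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
        torch_dtype=torch.bfloat16,
    )
    lora = LoraConfig(
        r=16,
        lora_alpha=32,
        task_type="CAUSAL_LM",
        # Same projection modules we targeted on the 671B model (see log below)
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
    )
    model = get_peft_model(model, lora)
    model.print_trainable_parameters()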

What Makes DeepSeek R1-0528 Special

  1. The published model is in FP8 format for inference → converting it to a trainable dtype (such as BF16) is mandatory
  2. Its 671B size is huge → it does not fit on a single GPU, so many GPUs plus distributed parallelism are mandatory
  3. MoE architecture → the expert layers use custom naming, so a framework that understands them is required (a quick way to check this is shown below)
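As a quick check of point 3, you can inspect the MoE fields of the model config without loading any weights. A sketch, assuming the field names of the DeepseekV3Config that appears in our log further down:

    # Inspect the MoE structure from the config alone (no weights loaded).
    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained(
        "deepseek-ai/DeepSeek-R1-0528",
        trust_remote_code=True,  # pulls in the custom configuration_deepseek.py
    )
    print(cfg.model_type)        # deepseek_v3
    print(cfg.n_routed_experts)  # 256 routed experts (matches our log)
    print(cfg.n_shared_experts)  # 1 shared expert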

How Hard It Is to Get DeepSeek R1-0528 from Loading to Training!

If you take the official DeepSeek R1-0528 model, load it, and try to train, loading simply does not go through...

Why...?

Because the official DeepSeek R1-0528 model ships in FP8 for inference, it cannot be trained as-is.

So it is unusable until you first cast it from FP8 back to BF16!

The procedure:

  1. Set up the environment with the requirements.txt under inference/ in the official repository
  2. Run inference/fp8_cast_bf16.py

Only with that done do you finally reach the starting line!
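Concretely, the two steps boil down to something like the following (both paths are placeholders; the script and its flags come from the inference/ directory of the official DeepSeek-V3 repository, which R1 shares):

    # Step 1: set up the environment for the official conversion script
    pip install -r inference/requirements.txt

    # Step 2: cast the FP8 checkpoint to BF16 (paths are placeholders)
    python inference/fp8_cast_bf16.py \
        --input-fp8-hf-path /path/to/DeepSeek-R1-0528 \
        --output-bf16-hf-path /path/to/DeepSeek-R1-0528-bf16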


Why We Failed

But even from the starting line, no matter how long we ground away at it, nothing worked...

Switching libraries, wrapping things (with monkey patches), still no luck... why?!

Looking into it afterwards, our critical problem was having used FSDP.

  • In certain cases FSDP does not understand MoE
    • It ignores the experts and shards the weights indiscriminately
    • This destroys MoE's efficient communication → bottlenecks appear
  • auto_wrap_policy is hard to configure and kept throwing errors (a sketch of what we were attempting follows this list)
  • FSDP's communication collides with MoE's token-routing communication

These conditions seem to be what sank us...
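For reference, this is the flavor of auto_wrap_policy we kept fighting with. A minimal sketch: the class name DeepseekV3DecoderLayer is an assumption based on the DeepseekV3* names in the warnings below, and with trust_remote_code models it cannot be imported from a fixed path, which is part of what made this painful.

    # Sketch: build an FSDP auto_wrap_policy around the remote-code layer class.
    import functools
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

    def build_deepseek_wrap_policy(model):
        # The decoder-layer class lives in a dynamically loaded module
        # (transformers_modules...), so look it up by name on the live model.
        layer_classes = {
            m.__class__
            for m in model.modules()
            if m.__class__.__name__ == "DeepseekV3DecoderLayer"
        }
        if not layer_classes:
            raise RuntimeError("DeepseekV3DecoderLayer not found in model")
        return functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=layer_classes,
        )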

Error log (the furthest point we reached)
    [__main__][INFO] WandB initialized successfully
    [__main__][INFO] Loading tokenizer...
    [__main__][INFO] Chat template found in tokenizer
    [__main__][INFO] Tokenizer special tokens:
    [__main__][INFO]   BOS token: <|begin▁of▁sentence|>
    [__main__][INFO]   EOS token: <|end▁of▁sentence|>
    [__main__][INFO]   PAD token: <|end▁of▁sentence|>
    [__main__][INFO]   UNK token: None
    [__main__][INFO] === 4bit量子化済みモデルを使用 ===
    [__main__][INFO] モデルは既に4bit量子化されているため、追加の量子化設定は行いません
    [__main__][INFO] Loading model configuration...
    [__main__][INFO] Found existing quantization config in model config
    [__main__][INFO] Keeping existing quantization config for pre-quantized model
    [__main__][INFO] Quantization method: bitsandbytes
    [__main__][INFO] === Model Configuration Analysis ===
    [__main__][INFO] Model config type: <class 'transformers_modules.quantized_deepseek_671b_4bit.configuration_deepseek.DeepseekV3Config'>
    [__main__][INFO] Model architectures: ['DeepseekV3ForCausalLM']
    [__main__][INFO] Model type: deepseek_v3
    [__main__][INFO] Hidden size: 7168
    [__main__][INFO] Num attention heads: 128
    [__main__][INFO] Num hidden layers: 61
    [__main__][INFO] Vocab size: 129280
    [__main__][INFO] MoE routed experts: 256
    [__main__][INFO] MoE shared experts: 1
    [__main__][INFO] 4bit量子化済みモデルを読み込み(FSDP用にdevice_map=None)
    [__main__][INFO] Loading DeepSeek model for FSDP...
    [__main__][INFO] Model kwargs (without config): {'torch_dtype': torch.bfloat16, 'trust_remote_code': True, 'low_cpu_mem_usage': True, 'use_safetensors': True, 'local_files_only': False, 'device_map': None}
     Loading checkpoint shards: 100%|██████████| 70/70 [36:38<00:00, 31.41s/it]
    [__main__][INFO] ✅ Model loaded successfully!
    [__main__][INFO] Model device: cpu
    [__main__][INFO] Model dtype: torch.bfloat16
    [__main__][INFO] === Model Structure Analysis for FSDP ===
    [__main__][INFO] Model type: DeepseekV3ForCausalLM
    [__main__][INFO] Model class: <class 'transformers_modules.quantized_deepseek_671b_4bit.modeling_deepseek.DeepseekV3ForCausalLM'>
    [__main__][INFO] Detected 406 transformer layers (sampled)
    [__main__][INFO] Detected 406 MoE experts (sampled)
    [__main__][INFO] Model structure analysis completed (fast mode)
    [__main__][INFO] === LoRA Configuration ===
    [__main__][INFO] LoRA target modules: ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']
    [__main__][INFO] === LoRA Parameter Type Unification for FSDP ===
    [__main__][INFO] Converted 89576 parameters to torch.bfloat16
     trainable params: 6,620,610,560 || all params: 677,647,029,760 || trainable%: 0.9770
    [__main__][INFO] LoRA applied successfully!
    [__main__][INFO] Loading dataset...
    [__main__][INFO] Loading dataset from: team-suzuki/SFT_000_DeepSeek-R1-0528 (source: huggingface)
    [__main__][INFO] Loaded 1828 samples
    [__main__][INFO] Sample keys: ['text']
    [__main__][INFO] Text field 'text' found
    [__main__][INFO] Sample text type: <class 'str'>
    [__main__][INFO] Sample text length: 2521
    [__main__][INFO] Dataset text field normalized to string format
    [__main__][INFO] Initializing SFTTrainer...
    [__main__][INFO] Applied FSDP auto wrap policy patch for DeepSeek-R1
     No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
    [__main__][INFO] ✅ SFTTrainer initialized successfully!
    [__main__][INFO] === Starting training with 4bit quantization ===
    [__main__][INFO] 🚀 4bit量子化済みモデルでトレーニング開始
    [__main__][INFO] 💾 メモリ使用量: 既に最適化済み
    [__main__][INFO] GPU 0: 0.00GB allocated, 0.00GB cached
    [__main__][INFO] GPU 1: 0.00GB allocated, 0.00GB cached
    [__main__][INFO] GPU 2: 0.00GB allocated, 0.00GB cached
    [__main__][INFO] GPU 3: 0.00GB allocated, 0.00GB cached
    [__main__][INFO] GPU 4: 0.00GB allocated, 0.00GB cached
    [__main__][INFO] GPU 5: 0.00GB allocated, 0.00GB cached
    [__main__][INFO] GPU 6: 0.00GB allocated, 0.00GB cached
    [__main__][INFO] GPU 7: 0.00GB allocated, 0.00GB cached
    [__main__][INFO] 🎯 トレーニング実行中...
    [__main__][INFO] 📈 最大ステップ数: 500
    [__main__][INFO] 📊 ログ出力間隔: 2 ステップ
    [__main__][INFO] 💾 保存間隔: 25 ステップ
    [__main__][WARNING] Using dummy FSDP auto wrap policy for DeepSeek-R1 compatibility
     /home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/accelerate/accelerator.py:1805: UserWarning: Upcasted low precision parameters in DeepseekV3Model because mixed precision turned on in FSDP. Affects: layers.0.self_attn.q_a_layernorm.weight, layers.0.self_attn.kv_a_layernorm.weight, layers.0.self_attn.o_proj.lora_A.default.weight, layers.0.self_attn.o_proj.lora_B.default.weight, layers.0.input_layernorm.weight, layers.0.post_attention_layernorm.weight, ... (the same six parameters repeated for each subsequent layer; list truncated)
       warnings.warn(
     /home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/accelerate/accelerator.py:1805: UserWarning: Upcasted low precision parameters in DeepseekV3MLP because mixed precision turned on in FSDP. Affects: gate_proj.lora_A.default.weight, gate_proj.lora_B.default.weight, up_proj.lora_A.default.weight, up_proj.lora_B.default.weight, down_proj.lora_A.default.weight, down_proj.lora_B.default.weight.
       warnings.warn(
     /home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/accelerate/accelerator.py:1805: UserWarning: Upcasted low precision parameters in DeepseekV3MoE because mixed precision turned on in FSDP. Affects: experts.0.gate_proj.lora_A.default.weight, experts.0.gate_proj.lora_B.default.weight, experts.0.up_proj.lora_A.default.weight, experts.0.up_proj.lora_B.default.weight, experts.0.down_proj.lora_A.default.weight, experts.0.down_proj.lora_B.default.weight, ... (the same six parameters repeated for each subsequent expert; list truncated)
       warnings.warn(
     /home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/accelerate/accelerator.py:1811: UserWarning: FSDP upcast of low precision parameters may affect the precision of model checkpoints.
       warnings.warn(
       0%|          | 0/500 [00:00<?, ?it/s]Traceback (most recent call last):
       File "/home/Competition2025/P05/P05U016/team_suzuki/train/hara_train_fsdp/fsdp_code_ori/deepseekr1_fsdp_4bit.py", line 982, in <module>
         main()
       File "/home/Competition2025/P05/P05U016/team_suzuki/train/hara_train_fsdp/fsdp_code_ori/deepseekr1_fsdp_4bit.py", line 949, in main
         trainer.train()
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/transformers/trainer.py", line 2237, in train
         return inner_training_loop(
                ^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/transformers/trainer.py", line 2578, in _inner_training_loop
         tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/trl/trainer/sft_trainer.py", line 914, in training_step
         return super().training_step(*args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/transformers/trainer.py", line 3792, in training_step
         loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/trl/trainer/sft_trainer.py", line 868, in compute_loss
         (loss, outputs) = super().compute_loss(
                           ^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/transformers/trainer.py", line 3879, in compute_loss
         outputs = model(**inputs)
                   ^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
         return self._call_impl(*args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
         return forward_call(*args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 818, in forward
         return model_forward(*args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 806, in __call__
         return convert_to_fp32(self.model_forward(*args, **kwargs))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
         return func(*args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 864, in forward
         output = self._fsdp_wrapped_module(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
         return self._call_impl(*args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
         return forward_call(*args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 818, in forward
         return model_forward(*args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/accelerate/utils/operations.py", line 806, in __call__
         return convert_to_fp32(self.model_forward(*args, **kwargs))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
         return func(*args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/peft/peft_model.py", line 1850, in forward
         return self.base_model(
                ^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
         return self._call_impl(*args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
         return forward_call(*args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/peft/tuners/tuners_utils.py", line 222, in forward
         return self.model.forward(*args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/huggingface_cache/modules/transformers_modules/quantized_deepseek_671b_4bit/modeling_deepseek.py", line 1601, in forward
         outputs = self.model(
                   ^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
         return self._call_impl(*args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
         return forward_call(*args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 864, in forward
         output = self._fsdp_wrapped_module(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
         return self._call_impl(*args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
         return forward_call(*args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/huggingface_cache/modules/transformers_modules/quantized_deepseek_671b_4bit/modeling_deepseek.py", line 1470, in forward
         layer_outputs = decoder_layer(
                         ^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
         return self._call_impl(*args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
         return forward_call(*args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/huggingface_cache/modules/transformers_modules/quantized_deepseek_671b_4bit/modeling_deepseek.py", line 1202, in forward
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                               ^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
         return self._call_impl(*args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
         return forward_call(*args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
       File "/home/Competition2025/P05/P05U016/deepseek-r1-fsdp/huggingface_cache/modules/transformers_modules/quantized_deepseek_671b_4bit/modeling_deepseek.py", line 820, in forward
         raise ValueError(
     ValueError: Attention weights should be of size (1, 128, 1024, 2048), but is torch.Size([1, 128, 1024, 1024])
     [rank0]: (the identical traceback is repeated with a [rank0]: prefix; omitted)

Communication in Distributed Training

  • Accelerate
    • A wrapper that calls backends such as DeepSpeed
    • Does not support pipeline parallelism
  • DeepSpeed
    • Beyond ZeRO memory reduction, it optimizes the All-to-All communication specific to MoE
    • The de facto standard

👉 So the fix, we suspect, is to build everything around DeepSpeed...!!!
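If we were redoing it, a starting point might look like the sketch below: a ZeRO-3 config passed through the Hugging Face Trainer integration. The values are illustrative assumptions, not settings we validated in the competition.

    # Sketch: a minimal ZeRO-3 DeepSpeed config as a Python dict, usable as
    # TrainingArguments(deepspeed=ds_config). "auto" defers to Trainer values.
    ds_config = {
        "train_micro_batch_size_per_gpu": "auto",
        "gradient_accumulation_steps": "auto",
        "bf16": {"enabled": True},
        "zero_optimization": {
            "stage": 3,                # shard params, grads, optimizer states
            "overlap_comm": True,      # overlap collectives with compute
            "stage3_prefetch_bucket_size": "auto",
            "stage3_param_persistence_threshold": "auto",
        },
    }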


Summary

  • DeepSeek R1-0528 is unusual: huge size plus an MoE structure
  • The DeepSeek layers use custom naming
  • The decisive failure factor was using FSDP
  • Approach going forward
    • Implement on a DeepSpeed foundation, with EP (expert parallelism) and PP (pipeline parallelism) in the design

"Think about what makes the model special, and choose the right tools to train it!"

That is the biggest lesson from this attempt.


Project Credits

This project was carried out as part of the development of foundation models under "Safety Verification and Demonstration toward the Social Implementation of Japanese Medical-Domain LLMs," a program of the New Energy and Industrial Technology Development Organization (NEDO).
