バッチ正規化:ディープラーニングの学習を安定させるBatch Normalization

AIを知りたい
先生、ニューラルネットワークを学習させているのですが、損失が全然下がらなかったり、突然発散してしまうことがあります。何か解決策はありますか?

AIエンジニア
それはディープラーニングの典型的な悩みだね。解決策の1つが「バッチ正規化(Batch Normalization、BN)」だよ。バッチ正規化とは、ニューラルネットワークの各レイヤーの入力を、ミニバッチ単位で平均0・分散1に正規化する手法なんだ。2015年にGoogleのIoffeとSzegedyが発表して以来、ほとんどすべてのCNN(畳み込みニューラルネットワーク)に組み込まれている基本技術だよ。

AIを知りたい
なぜ正規化すると学習が安定するんですか?

AIエンジニア
ニューラルネットワークでは、あるレイヤーのパラメータが更新されると、次のレイヤーへの入力分布が変わってしまうんだ。これを「内部共変量シフト(Internal Covariate Shift)」と呼ぶよ。バッチ正規化はこの入力分布の変動を抑えることで、各レイヤーが安定した分布のデータを受け取れるようにする。その結果、より大きな学習率を使えるようになり、学習が高速化し、初期値への依存も低減されるんだ。
バッチ正規化とは。
バッチ正規化(Batch Normalization)は、2015年にIoffeとSzegedyが論文「Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift」で提案した、ディープラーニングの学習を安定・高速化するための正規化手法です。ミニバッチ内のデータについて各特徴量の平均と分散を計算し、平均0・分散1に正規化した後、学習可能なスケール(γ)とシフト(β)パラメータで変換します。これにより、内部共変量シフトが抑制され、学習率を大きくしても学習が安定します。ResNet、EfficientNet、YOLOなど主要なCNNアーキテクチャのほぼ全てに採用されています。一方、Transformerアーキテクチャ(GPT、BERTなど)ではLayer Normalizationが主に使用されており、用途に応じた正規化手法の選択が重要です。
バッチ正規化と他の正規化手法の比較
バッチ正規化以外にもさまざまな正規化手法が開発されており、タスクやアーキテクチャに応じて使い分けます。
| 手法 | 正規化の軸 | 主な用途 | メリット | デメリット |
|---|---|---|---|---|
| Batch Normalization | バッチ方向 | CNN(画像認識) | 学習が安定、高速化、正則化効果 | バッチサイズが小さいと不安定 |
| Layer Normalization | 特徴量方向 | Transformer(NLP) | バッチサイズに依存しない | CNNでは効果が限定的 |
| Instance Normalization | 各サンプル・各チャネル | 画像スタイル変換 | スタイル情報の正規化に適する | 分類タスクではBNに劣る |
| Group Normalization | チャネルグループ方向 | 物体検出、セグメンテーション | 小バッチでも安定 | グループ数の設定が必要 |
| RMSNorm | 特徴量方向(RMS) | Llama, GPT系 | LayerNormより計算効率が良い | 比較的新しく研究が進行中 |

AIを知りたい
TransformerではLayer Normalizationが使われるんですね。なぜBatch Normalizationではダメなんですか?

AIエンジニア
いい質問だね。主な理由は2つあるよ。1つ目は、NLPでは入力テキストの長さがバラバラなので、パディング(空白埋め)が発生し、バッチ方向の統計量が不安定になること。2つ目は、テキスト生成のような自己回帰タスクでは、将来のトークンの情報を使えないため、バッチ全体の統計量を計算するBNが使いにくいことだ。Layer Normalizationは各サンプルの特徴量方向で正規化するから、これらの問題を回避できるんだよ。
バッチ正規化の実装と使用上の注意点
バッチ正規化を効果的に使うためのポイントを整理します。
| 注意点 | 詳細 | 対策 |
|---|---|---|
| バッチサイズの影響 | バッチサイズが小さい(16未満)と統計量が不安定 | Group NormalizationやSync BNを検討 |
| 学習時と推論時の挙動の違い | 学習時はミニバッチ統計量、推論時は移動平均を使用 | model.eval()の切り替えを忘れない |
| 配置位置 | Conv→BN→ReLUの順が一般的 | 活性化関数の前に配置するのが標準 |
| ドロップアウトとの併用 | BNとDropoutの同時使用は性能低下の報告あり | BNを使う場合はDropoutを外すことが多い |
| 転移学習時の注意 | 事前学習のBN統計量と新データの分布が異なる | ファインチューニング時にBN層の統計量を更新 |

AIを知りたい
学習時と推論時で挙動が違うのは知りませんでした。model.eval()を忘れるとどうなるんですか?

AIエンジニア
推論時にmodel.eval()を呼ばないと、BN層がミニバッチの統計量を使い続けるため、入力データのバッチ構成によって出力が変わってしまうという問題が起きるよ。これはPyTorchでよくあるバグの原因だ。必ずmodel.eval()を呼んで、学習中に蓄積した移動平均の統計量を使うようにしよう。バッチ正規化は「設置して終わり」ではなく、学習と推論の切り替えまで意識することが重要だよ。

AIを知りたい
正規化1つとっても奥が深いですね。CNNならBN、TransformerならLN、小バッチならGNと、状況に応じて使い分けるのが大切なんですね!
