Publicidad
Publicidad

Más contenido relacionado

Presentaciones para ti(20)

Similar a SSII2022 [OS3-02] Federated Learningの基礎と応用(20)

Publicidad

Más de SSII(20)

Último(20)

Publicidad

SSII2022 [OS3-02] Federated Learningの基礎と応用

  1. Federated Learningの基礎と応⽤ ⻄尾理志 東京⼯業⼤学 ⼯学院情報通信系 准教授 2022年6⽉10⽇ 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 1
  2. TokyoTech TokyoTech Outline •Federated Learningとは? •Federated Learningの原理 •Federated Learningの課題と研究紹介 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 2
  3. TokyoTech TokyoTech Background: The Age of AI スマートフォンやタブレットはデータの宝庫 ・テキスト、画像、健康状態、移動履歴… 機械学習に活⽤したい! 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 3 機械学習による様々なアプリケーションの実現 ・⾃動翻訳、ロボット制御、⾃動診断… ⾼度なモデルの学習には豊富なデータが不可⽋ データの効率的収集が課題
  4. TokyoTech TokyoTech 機械学習のためにデータを収集したいが… データの集約にはリスクが伴う •データの漏洩 •プライバシ情報の流出 ⼼理的障壁 •データを提供するのはなんとなく嫌 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 4 Federated Learning (FL, 連合機械学習)[1] [1] B. McMahan, et al.,“Communication-efficient learning of deep networks from decentralized data,” Proc. AISTATS, pp. 1273–1282, Apr. 2017. Server 盗聴 クラッキング 個⼈の特定 Motivation: データを集約せずに学習に活⽤したい!
  5. TokyoTech TokyoTech Federated Learning 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 5 分散的に取得・保持されたデータを⽤いて、データを集約することなくモデルを 訓練する⼿法 特徴 個々の端末のデータにはノータッチ 個⼈情報や機密情報も学習に活⽤し やすい エッジ データ モデル更新情報 モデルの更新 更新情報の提供 モデル配備 ユーザ データを 持つ端末 サービス を享受 モデルを使ったサービスの展開
  6. TokyoTech TokyoTech アプリケーション 医療・ヘルスケア Client: スマートフォン(ヘルスケアア プリ) 訓練データ: ⾏動ログ、⼼拍データ 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 6 スマート⼯場 Client: ⼯場等の機器管理サーバ等 訓練データ: 機器稼働履歴、振動や⾳ データ
  7. TokyoTech TokyoTech Applications: Emoji prediction from Google [4] 2022/5/17 T5: Part1 8 [4] Ramaswamy, et al., “Federated Learning for Emoji Prediction in a Mobile Keyboard,” arXiv:1906.04329. ML model predicts a Emoji based on the context. The model trained via FL achieved better prediction accuracy (+7%).
  8. TokyoTech TokyoTech Applications: Oxygen needs prediction from NVIDIA [6] [6] https://blogs.nvidia.com/blog/2020/10/05/federated-learning-covid-oxygen-needs/ 2022/5/17 T5: Part1 9 Using NVIDIA Clara Federated Learning Framework, researchers at individual hospitals were able to use a chest X-ray, patient vitals and lab values to train a local model and share only a subset of model weights back with the global model in a privacy- preserving technique called federated learning.
  9. TokyoTech TokyoTech Federated Learningの原理 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 17
  10. TokyoTech TokyoTech System model (これは概ね共通) 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 18 Server Clients (user devices) Data Client • データ保持者であり、モデルの学習に必要な情報の 提供を⾏う。ただし、データの共有は不可。 • 学習に適したデータ(前処理やラベル付け済み)を もつと仮定する • データが少量ならば5〜10回程度のモデル訓練が実⾏ 可能な程度の計算能⼒をもつ Server • 学習の管理と訓練対象のモデル(グローバルモデ ル)の更新を⾏う • Clientと通信が可能 • Clientが兼任することも可能
  11. TokyoTech TokyoTech いろいろな設定のFederated Learning Problem 2022/5/18 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 20 データ: IID or Non-IID 学習タスク: 教師あり学習、半教師あり学習、教師なし学習、強化学習 システム構成: Server-Client型 or 階層型 or 分散型(Server-less) 攻撃者の有無: 学習の妨害(Poisoning)、バックドア、データ盗聴 シナリオ(主にClientの数や性能) •Cross-silo FL: 10-100台程度のサーバなど 同じClientが何度も学習に参加する •Cross-device FL: 1K-1M台のスマートフォンやラップトップなど 各Clientが学習に参加するのは数回程度
  12. TokyoTech TokyoTech いろいろな設定のFederated Learning Problem 2022/5/18 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 21 データ: IID or Non-IID 学習タスク: 教師あり学習、半教師あり学習、教師なし学習、強化学習 システム構成: Server-Client型 or 階層型 or 分散型(Server-less) 攻撃者の有無: 学習の妨害(Poisoning)、バックドア、データ盗聴 シナリオ(主にClientの数や性能) •Cross-silo FL: 10-100台程度のサーバなど 同じClientが何度も学習に参加する •Cross-device FL: 1K-1M台のスマートフォンやラップトップなど 各Clientが学習に参加するのは数回程度 最もよくある 設定で解説
  13. TokyoTech TokyoTech 補⾜:シナリオ(主にClientの数や性能)の違い 1/2 2022/5/18 T5: Part1 22 Cross-silo federated learning Cross-device federated learning Clients: millions of devices such as mobile phones and IoT sensors Clients: small numbers of data silos such as institutions and factories Server Use cases • Keyboard next-word prediction [3] • Emoji prediction [4] • Speaker recognition [5] Client: Millions of smart phone Server Use case Oxygen need prediction [6] Client: 20 hospitals Silo A Silo B Clients
  14. TokyoTech TokyoTech 補⾜:シナリオ(主にClientの数や性能)の違い 2/2 2022/5/18 T5: Part1 23 Cross-silo federated learning Cross-device federated learning Clients: millions of devices such as mobile phones and IoT sensors Clients: small numbers of data silos such as institutions and factories Server • Clients are intermittently available • Only a portion of clients participate in round. • Clients may participate few times or once. Server • Clients are always available. • Most clients participate in every round. • Clients are identified. • Server can know the characteristics of each client and manage their participation in detail. Silo A Silo B Clients
  15. TokyoTech TokyoTech Federated Learningの具体的なアルゴリズム 2022/5/18 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 24 Clients (user devices) Model Data 3. モデルの更新 5. モデルの統合 2. モデルの配布 FedAvg (Federated Averaging) [1] 各Clientが訓練したモデルのパラメタ を収集し、算術平均をとることで⼀つ のモデルに統合し、学習する⽅式 • Server-Client間でやりとりするのは モデルだけ • データはそれを保持するClient⾃⾝ しか参照しない [1] B. McMahan, et al.,“Communication-efficient learning of deep networks from decentralized data,” Proc. AISTATS, pp. 1273–1282, Apr. 2017.
  16. TokyoTech TokyoTech Federated Learningの具体的なアルゴリズム 1. Client selection: サーバはラウンド(⼀連 の更新⼿順)に参加するClientを選択 2. 選択されたClientにグローバルモデルを配布 3. Local update: Clientは⾃⾝の持つデータを 使って、配布されたモデルを更新する。更 新したモデルはローカルモデルと呼ぶ。 4. ローカルモデルのパラメタをサーバに共有 する 5. Model aggregation: 共有されたパラメタを 平均し、グローバルモデルとする 6. 1~5の⼿順を繰り返す 2022/5/18 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 25 [1] B. McMahan, et al.,“Communication-efficient learning of deep networks from decentralized data,” Proc. AISTATS, pp. 1273–1282, Apr. 2017. Clients (user devices) Model Data 3. モデルの更新 5. モデルの統合 2. モデルの配布
  17. TokyoTech TokyoTech モデルの更新と統合(ニューラルネットワークを想定) 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 27 厳密にはミニバッチ確率的勾配降下法 Client update 通常のモデル更新のように確率的 勾配降下法によりモデルを更新 損失関数に対する勾配 ミニバッチ パラメタ(重み) for local epoch 𝑖 from 1 to 𝐸: for batch 𝑏 ∈ 𝐵: 𝑤 ← 𝑤 − 𝜂 ∇𝑙(𝑤; 𝑏) Model aggregation ローカルモデルのパラメタをデー タ数で重み付けし平均 𝑤! " ← Client update 𝑤!#$ ← ∑"∈𝑺! '" ' 𝑤! " グローバルモデル Client kのデータ数 / 総データ数
  18. TokyoTech TokyoTech 性能評価 (Supplementary PDF of [B. McMahan, et al.,“Communication-efficient learning of deep networks from decentralized data,” Proc. AISTATS, pp. 1273‒1282, Apr. 2017.]) 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 28 CIFAR-10 (画像分類タスク) ハイパーパラメータにもよるが、 データを集約した場合と同程度の 精度までモデルを訓練できている
  19. TokyoTech TokyoTech 集中型機械学習 vs. Federated Learning 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 29 Clients (user devices) Model Data モデルの更新 モデルの統合 モデルの配布 Federated Learning 集中型機械学習 Server User devices モデルの訓練 データの集約によるプライバシ 情報や機密情報漏洩の懸念 学習のための情報のみ共有し、データ は端末に保持されるため、漏洩リスク が軽減
  20. TokyoTech TokyoTech 分散機械学習 vs. Federated Learning 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 30 Servers 分散機械学習 モデルの訓練 Database Clients (user devices) Model Data モデルの更新 モデルの統合 モデルの配布 Federated Learning 学習処理の分散化に焦点 データを⼀度集約し任意に分配 クライアントごとに異なる分布 のデータを持ち、学習が困難に
  21. TokyoTech TokyoTech Federated Learningの技術的課題 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 31
  22. TokyoTech TokyoTech Federated Learning 共通の課題 学習⾯ • 学習の収束速度 集中型より遅い • Clientデータの質 - Imbalanced data - Non-iid data - Noisy label により学習の難化 • ハイパーパラメータ チューニング 従来のチューニング⽅法 では⼤きなオーバヘッド 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 32 インフラ⾯ (Cross-device FLで顕著) • 通信トラヒック モデルの共有に⼤きなト ラヒックが繰り返し発⽣ • 計算・通信性能の差 計算速度やメモリ量、伝 送速度や遅延、パケット 損失率 • Clientの参加離脱 断続的な通信環境やユー ザによる電源オフ セキュリティ⾯ (Cross-device FLで顕著) • 学習への攻撃 - Data poisoning - Model poisoning - Backdoor attack - Byzantine attack による学習の破綻やモデ ルの置き換え攻撃 • プライバシ保護 特定のClientのモデルや データを推定する攻撃な どへの対策
  23. TokyoTech TokyoTech Federated Learningの課題 学習⾯ • 学習の収束速度 集中型より遅い • Clientデータの質 - Imbalanced data - Non-iid data - Noisy label により学習の難化 • ハイパーパラメータ チューニング 従来のチューニング⽅法 では⼤きなオーバヘッド 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 33 インフラ⾯ • 通信トラヒック モデルの共有に⼤きなト ラヒックが繰り返し発⽣ • 計算・通信性能の差 計算速度やメモリ量、伝 送速度や遅延、パケット 損失率 • Clientの参加離脱 断続的な通信環境やユー ザによる電源オフ セキュリティ⾯ • 学習への攻撃 - Data poisoning - Model poisoning - Backdoor attack - Byzantine attack による学習の破綻やモデ ルの置き換え攻撃 • プライバシ保護 特定のClientのモデルや データを推定する攻撃に 対する対策、差分プライ バシによる評価
  24. TokyoTech TokyoTech Non-iid (not independent and identically distributed)データ 各Clientの持つデータが従う分布が、全クライアントのデータを集約し た場合の分布と⼀致しない状況 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 35 Class Distribution of data 統合してもモデル性能 があまり向上しない データが⼤きく異なるため Model 2と極端に異なるモデル を獲得 Model 1 Model training Class Client 1 Client 2 Model 2 Aggregated model Distribution of data
  25. TokyoTech TokyoTech パラメタ空間上でのパラメタ更新の図⽰ 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 36 グローバルモデル 𝒘𝒕 ローカルモデル 𝒘",$%& ローカルモデル 𝒘',$%& 統合後の グローバルモデル 𝒘𝒕%𝟏 a) iidにおけるパラメタ更新 b) non-iidにおけるパラメタ更新 全Clientにおいて概ね同じ⽅向にモデル が更新され、グローバルモデルが更新さ れていく Clientごとに更新⽅向が⼤きく異なり、 グローバルモデルがあまり更新されな い状況が発⽣しうる 𝒘",$%& 𝒘',$%& モデルがほとんど 更新されない
  26. TokyoTech TokyoTech 通信トラヒック 学習をある程度収束させるためには 数⼗〜数百ラウンド程度必要 Non-iidではよりラウンド数が必要 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 38 学習に必要なラウンド数:𝑁 10~1000 1ラウンドあたりのClient数:𝐶 5~100 モデルのデータサイズ:𝐷 100KB ~ 1GB 総トラヒック量 = 𝟐𝑫𝑪𝑵 10 MB ~ 200 TB
  27. TokyoTech TokyoTech 課題解決に向けたFLアルゴリズムの改良 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 39
  28. TokyoTech TokyoTech Non-iidデータ 1/2 Local updateの改良 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 40 SCAFFOLD [N1] グローバルモデルの更新⽅向を推定 し、その⽅向にローカルモデルが更 新されるよう補正を加える [N1] Karimireddy, Sai Praneeth, et al. "Scaffold: Stochastic controlled averaging for federated learning." ICML 2020. 引⽤ [N2] Wang, Hao, et al. "Optimizing federated learning on non-iid data with reinforcement learning." IEEE INFOCOM, 2020. Client selectionの改良 FAVOR [N2] モデル性能の向上に寄与するClient セット選択戦略を強化学習により学習 引⽤
  29. TokyoTech TokyoTech Non-iidデータ 2/2 2022/5/18 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 41 Model aggregationの改良 [N4] H. Wu and P. Wang, “Fast-Convergent Federated Learning With Adaptive Weighting,” IEEE Transactions on Cognitive Communications and Networking, vol. 7, no. 4, pp. 1078-1088, Dec. 2021. FedAdp [N4] Clientの学習への貢献(更新時の勾配をもと に算出)をもとにAggregation時の重みを調整 少量のIIDデータの活⽤ [N3] N. Yoshida, T. Nishio, et al., "Hybrid-FL for Wireless Networks: Cooperative Learning Mechanism Using Non-IID Data," IEEE ICC 2020. Hybrid FL [N3] 極少数 (~1%)のClientはデータのアップロードを許可 アップロードデータでサーバ側にIIDデータセットを 構築し、モデル更新に活⽤ 2 4 6 8 10 Mode of r, µ 0.0 0.2 0.4 0.6 0.8 1.0 Accuracy (a) CIFAR-10. Centralized model training FedCS IID non-IID 2 4 6 8 10 Mode of r, µ 0.0 0.2 0.4 0.6 0.8 1.0 Accuracy (a) CIFAR-10. Centralized model training FedCS Hybrid-FL (maxThroughput/minCV) 2 4 6 Mode of r, µ 0.0 0.2 0.4 0.6 0.8 1.0 Accuracy (b) Fashion MNIST.
  30. TokyoTech TokyoTech Non-iidデータ 2/2 少量のIIDデータの活⽤ 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 42 [N3] N. Yoshida, T. Nishio, et al., "Hybrid-FL for Wireless Networks: Cooperative Learning Mechanism Using Non-IID Data," IEEE ICC 2020. Hybrid FL [N3] 極少数 (~1%)のClientはデータのアップロードを許可 アップロードデータでサーバ側にIIDデータセットを構築し、モデル更新に活⽤ 2 4 6 8 10 Mode of r, µ 0.0 0.2 0.4 0.6 0.8 1.0 Accuracy (a) CIFAR-10. Centralized model training FedCS Hybrid-FL (maxThroughput/minCV) 2 4 6 8 10 Mode of r, µ 0.0 0.2 0.4 0.6 0.8 1.0 Accuracy (b) Fashion MNIST. IID non-IID IID non-IID
  31. TokyoTech TokyoTech 通信トラヒック削減 モデルの圧縮 巨⼤なニューラルネットワークに対し、 パラメタの量⼦化、枝刈り、パラメタの少ないモデルに 置き換え等を⾏い、モデルのデータサイズを⼩さくする。 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 43 空中計算 1 3 3+1=4 空中計算 データをアナログ振幅変調し、多数のClientが同時送信するこ とで、電波の重畳現象により受信振幅値から総和を求める⼿法 モデル統合に応⽤することで、送信局の数が⾮常に多い場合に チャネル専有時間が削減される スパース化 モデルを共有しない学習⽅法: Distillation-based Federated Learning [C2] K. Yang et al., "Federated learning via over-the-air computation." IEEE Trans. Wireless Commun. 2020. [C1] F. Haddadpour, et al., "Federated learning with compression: Unified analysis and sharp guarantees." AISTATS 2021.
  32. TokyoTech TokyoTech Distillation based Semi-Supervised Federated Learning (DS-FL) [A] [A] S. Itahara, T. Nishio, et al.,, “Distillation-Based Semi-Supervised Federated Learning for Communication-Efficient Collaborative Training with Non-IID Private Data,” IEEE Trans. Mobile Compt. 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 44 Distillation (蒸留) モデルの出⼒であるロジットを⽤いた学習⽅式 モデルの代わりにサイズの⼩さいロジットを⽤いることでトラヒック を⼤幅に削減 Semi-supervised learning (半教師あり学習) ラベル付きデータに加えて、ラベルなしデータも活⽤する機械学習 従来はモデルの汎化性能向上に⽤いられることが多い 本⽅式ではDistillationをFLに組み込むために活⽤ 通信トラヒックを⼤幅削減可能(FedAvgの1/50)な学習⼿法
  33. TokyoTech TokyoTech DS-FLと従来⼿法(FedAvg)の⽐較 [A] S. Itahara, T. Nishio, et al.,, “Distillation-Based Semi-Supervised Federated Learning for Communication-Efficient Collaborative Training with Non-IID Private Data,” IEEE Trans. Mobile Compt. 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 46 データサイズの⼤きいモデルを何度も 共有するため⼤きなトラヒックが発⽣ Clients (user devices) Model Data モデルの更新 出⼒統合とモデル訓練 Model Logit モデルの出⼒ 提案⼿法 Clients (user devices) Model Data モデルの更新 モデルの統合 モデルの配布 従来のFL モデルの出⼒情報を⽤いて学習するこ とで学習時のトラヒックを⼤幅に削減
  34. TokyoTech TokyoTech DistillationによるNon-iid dataからの学習 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 49 共有データ (ラベルなし) 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 統合したロジット 統合したロジット でラベルなしデー タをラベル付け Clientのデータ セット (ラベル付)
  35. TokyoTech TokyoTech 性能評価:諸元 10クラス画像分類タスク(MNIST)を学習したときの精度とトラヒックを評価 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 50 学習率 0.1 ミニバッチサイズ 100 エポック数 5 モデルの構造 畳み込み2層 + 全結合2層 モデルのパラメタ数 584,458 (2.3MB) 最適化⼿法 SGD 端末数 100 clients 公開データ数 20,000 images 各端末の 端末データ数 400 images 各端末の持つ端末データの クラス数 2 classes
  36. TokyoTech TokyoTech 性能評価結果 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 51 98% less traffic ⼀定精度を達成するまでに必要な 通信トラヒックを1/50以下まで削減 MNISTデータを⽤いたnon-IIDデータ環境でのFederated Learning実験 誤ってラベル付けされたデータが 混⼊した場合も⾼い予測精度を維持
  37. TokyoTech TokyoTech まとめ 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 52 Federated Learning 分散的に取得・保持されたデータを⽤いて、データを集約することなくモデルを訓練 プライバシ情報や機密情報を機械学習に活⽤しやすくなる エッジ データ モデル更新情報 モデルの更新 更新情報の提供 モデル配備 ユーザ データを 持つ端末 サービス を享受 モデルを使ったサービスの展開 技術課題 • 学習の計算リソース/データ効率 • 通信リソース効率 • 攻撃・セキュリティ 競争激しい分野ですが 課題も多く⾯⽩い
  38. TokyoTech TokyoTech その他、参考になりそうな資料 Federated Learning Tutorial @NeurIPS2020 https://sites.google.com/view/fl-tutorial/ 2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 53
Publicidad