面對大規模 AI 模型迅速成長,企業正在尋找更高效、更具彈性的基礎架構,以支撐日益複雜的運算需求。
Google Cloud 近期宣布在 GKE 上原生整合 Ray 與 Cloud TPU,讓企業能更直接地使用 TPU 的高效能運算能力,同時享受 Ray 帶來的彈性調度與自動化工作流程。這項整合不僅簡化了大型 AI 訓練與推理的部署流程,也讓企業能更有效率、更具成本效益地推動 AI 專案落地。以下將帶你在 5 分鐘內快速掌握這次更新的關鍵亮點與對企業的實際意義。
容器與 Kubernetes
在 GKE 上使用 Ray 為 Cloud TPU 提供更原生的體驗
工程團隊使用 Ray 在包括 GPU 和雲端 TPU 在內的各種硬體上擴展 AI 工作負載。雖然 Ray 提供了核心的擴展能力,但開發人員通常需要管理每個加速器的獨特架構細節。對於雲端 TPU 而言,這包括其特定的網路模型和單程式多數據 (SPMD) 程式設計風格。
作為與 Anyscale 合作的一部分,我們正致力於降低在 Google Kubernetes Engine (GKE) 上使用 TPU 的工程難度。我們的目標是讓 Ray 在 TPU 上的體驗盡可能原生且流暢。
今天,我們將推出幾項關鍵改進措施,以幫助實現這一目標。
Ray TPU 庫用於提高 Ray Core 中的 TPU 感知和擴展性
TPU 具有獨特的架構和一種稱為 SPMD 的特定程式設計風格。大型 AI 作業在 TPU 切片上運行,TPU 切片是由透過稱為晶片間互連 (ICI) 的高速網路連接的晶片集合而成。.jpg)
以前,需要手動配置 Ray 才能使其識別這種特定的硬體拓撲結構。這是一個重要的設定步驟,如果配置不當,作業可能會從不同的、不相連的分區獲取碎片化的資源,從而導致嚴重的效能瓶頸。
這個新函式庫ray.util.tpu抽象化了這些硬體細節。它利用名為 `Multislice` 的特性SlicePlacementGroup以及新的label_selectorAPI,自動將整個位於同一位置的 TPU 切片作為一個原子單元保留。這保證了作業在統一的硬體上運行,避免了因碎片化而導致的效能問題。由於 Ray 之前無法保證這種單切片原子性,因此建立可靠的真正多切片訓練(有意跨越多個不同的切片)是不可能的。這個新的 API 也為 Ray 用戶使用 Multislice 技術透過多個 TPU 切片進行擴充提供了關鍵基礎。
擴展了對 Jax、Ray Train 和 Ray Serve 的支持
我們的開發工作涵蓋了訓練和推理兩個面向。在訓練方面,Ray Train 現在提供對 TPU 上 JAX(透過JaxTrainer)和 PyTorch 的 alpha 支援。
此JaxTrainerAPI 簡化了在多主機 TPU 上執行 JAX 工作負載的過程。現在它可以自動處理複雜的分散式主機初始化。如下面的程式碼範例所示,您只需在一個簡單的ScalingConfig物件中定義硬體需求,例如工作節點數量、拓撲結構和加速器類型。JaxTrainer其餘部分將由 API 自動完成。
這是一項重大改進,因為它解決了關鍵的效能問題:資源分散化。先前,請求「4x4」拓撲(必須在稱為「切片」的單一共置硬體單元上運行)的作業可能會獲得碎片化的資源——例如,來自一個物理切片的八個晶片和來自另一個不相連切片的八個晶片。這種碎片化是一個主要的瓶頸,因為它阻止了工作負載使用僅存在於單一統一切片內的高速ICI互連。
JaxTrainer 如何簡化多主機 TPU 上的訓練範例:
import jax
import jax.numpy as jnp
import optax
import ray.train
from ray.train.v2.jax import JaxTrainer
from ray.train import ScalingConfig
def train_func():
"""This function is run on each distributed worker."""
...
# Define the hardware configuration for your distributed job.
scaling_config = ScalingConfig(
num_workers=4,
use_tpu=True,
topology="4x4",
accelerator_type="TPU-V6E",
placement_strategy="SPREAD"
)
# Define and run the JaxTrainer.
trainer = JaxTrainer(
train_loop_per_worker=train_func,
scaling_config=scaling_config,
)
result = trainer.fit()
print(f"Training finished on TPU v6e 4x4 slice")
基於標籤的調度 API,易於獲取
新的基於標籤的調度 API與GKE自訂計算類別 整合。自訂計算類別是一種定義命名硬體配置的簡單方法。例如,您可以建立一個名為 `TPU-V6E` 的類,cost-optimized該類別指示 GKE 首先嘗試取得 Spot 實例,然後回退到動態工作負載調度器FlexStart 實例,最後才回退到預留實例。新的 Ray API 可讓您直接從 Python 中使用類別。只需一個簡單的 `--require-class` 語句label_selector,您就可以要求諸如“TPU-V6E”之類的硬體或指定目標cost-optimized類,而無需管理單獨的 YAML 檔案。
此label_selector 機制也為 TPU 提供了深層的硬體控制。當 GKE 為某個切片配置 TPU Pod 時,它會將元資料(例如工作節點等級和拓撲結構)注入到每個 Pod 中。 KubeRay(在 GKE 上管理 Ray)隨後會讀取這些 GKE 提供的元數據,並在創建節點時自動將其轉換為 Ray 特有的標籤。這提供了關鍵訊息,例如 TPU 代數(ray.io/accelerator-type)、實體晶片拓撲結構(ray.io/tpu-topology)以及切片內的工作節點等級(ray.io/tpu-worker-id )。
這些節點標籤可讓您使用 Ray label_selector 將 SPMD 工作負載固定到特定的、共置的硬件,例如「4x4」拓撲或特定的工作進程。
在下列範例中,Ray 使用者可以請求 v6e-32 TPU 實例,但指示 GKE 使用自訂計算類,如果 v6e-32 實例不可用,則回退到 v5e-16 實例。類似地,使用者可以先要求競價型或 DWS 資源,如果這些資源不可用,則回退到預留實例。
|
開發者選擇計算資源和節點池。 |
平台管理員設定 Kubernetes |
|
@ray.remote(num_cpu=1, |
apiVersion: cloud.google.com/v1 - tpu: |
TPU 指標和日誌集中在一個地方
現在,您可以在 Ray 控制面板中直接查看關鍵的 TPU 效能指標,例如 TensorCore 使用率、佔空比、高頻寬記憶體 (HBM) 使用率和記憶體頻寬利用率。我們還新增了底層libtpu日誌。這大大加快了偵錯速度,因為您可以立即檢查故障是由程式碼還是 TPU 硬體本身造成的。
立即開始
這些更新共同推動了 TPU 與 Ray 生態系統的無縫融合,使 TPU 成為 Ray 生態系統的重要組成部分。它們讓現有 Ray 應用在 GPU 和 TPU 之間的適應變得更加簡單。以下是了解更多資訊和入門的方法:
-
查閱文件:
-
JAX 工作負載:請參閱新的JAX 入門指南,以了解如何使用 JaxTrainer 並了解有關 JaxTrain 的更多資訊。
-
TPU 指標:在 Ray Dashboard 或 Grafana 中查看 TPU 指標
-
請求 TPU 容量:使用DWS Flex Start for TPUs快速開始使用,該服務可為運行時間少於 7 天的作業提供 TPU 存取權限。
- 相關內容:TPU簡介
從雲端到 AI,蓋亞資訊給你最完整的企業級支持
在面對 AI 需求快速攀升的企業環境中,Ray + Cloud TPU 的原生整合為組織提供了一個更靈活、更高效的基礎架構選擇。無論是模型訓練速度、運算成本控管、資源調度效率,或是未來的多切片(Multi-slice)擴展,本次更新都讓企業更容易構建可持續發展的 AI 能力。
蓋亞資訊作為 Google Cloud 官方代理商與技術合作夥伴,我們協助企業導入 Cloud TPU、GKE、分散式訓練架構(如 Ray),並提供技術諮詢、架構最佳化與雲端服務整合。如果你希望評估 TPU、GKE 或 AI 基礎架構升級,我們非常樂意協助深入分析與規劃。
Google Cloud Platform Premier Partner┃蓋亞資訊
蓋亞資訊是GCP官方認證的菁英合作夥伴,提供全面的混合雲端服務,包括雲端遷移、IT 環境架構設計、安全性、大數據解決方案以及雲端部署。不僅擁有超過百張認證的專業工程師團隊,還提供全年無休的 7X24 維運服務,隨時為企業解決問題。
