被稱為“圖神經網絡”(GNN)的深度學習技術在圖域中運行。 這些網絡最近在各種領域中得到了應用,包括計算機視覺、推薦系統和組合優化等等。
此外,這些網絡可用於表示複雜系統,包括社交網絡、蛋白質-蛋白質相互作用網絡、知識圖譜等多個研究領域。
與圖片等其他類型的數據相比,非歐幾里得空間是圖形數據運行的地方。 為了對節點進行分類、預測鏈接和聚類數據,使用了圖形分析。
在本文中,我們將檢查 Graph 神經網絡 詳細介紹其類型,並提供使用 PyTorch 的實際示例。
那麼,什麼是圖?
圖是一種由節點和頂點組成的數據結構。 各個節點之間的連接由頂點決定。 如果在節點中指示了方向,則稱該圖為有向圖; 否則,它是無向的。
圖的一個很好的應用是對不同個體之間的關係進行建模 社交網絡. 在處理複雜的情況時,例如鍊接和交換,圖表非常有用。
它們被推薦系統、語義分析、社交網絡分析和模式識別所採用
. 創建基於圖形的解決方案是一個全新的領域,它提供了對複雜和相互關聯的數據的深刻理解。
圖神經網絡
圖神經網絡是可以對圖數據格式進行操作的專用神經網絡類型。 圖嵌入和卷積神經網絡 (CNN) 對它們有重大影響。
圖神經網絡用於包括預測節點、邊和圖的任務。
- CNN 用於對圖像進行分類。 類似地,為了預測一個類,GNN 被應用於表示圖結構的像素網格。
- 使用遞歸神經網絡的文本分類。 GNN 也用於圖形架構,其中短語中的每個單詞都是一個節點。
為了預測節點、邊或完整圖,使用神經網絡來創建 GNN。 例如,節點級別的預測可以解決垃圾郵件檢測等問題。
鏈接預測是推薦系統中的典型案例,可能是邊緣預測問題的一個例子。
圖神經網絡類型
存在多種神經網絡類型,其中大多數都存在卷積神經網絡。 我們將在這一部分了解最著名的 GNN。
圖卷積網絡(GCN)
它們可與經典的 CNN 相媲美。 它通過查看附近的節點來獲取特徵。 GNN 使用激活函數在聚合節點向量並將輸出發送到密集層之後添加非線性。
它本質上由圖卷積、線性層和非學習器激活函數組成。 GCN 有兩個主要品種:光譜卷積網絡和空間卷積網絡。
圖自動編碼器網絡
它使用編碼器來學習如何表示圖,並使用解碼器來嘗試重建輸入圖。 有一個瓶頸層連接編碼器和解碼器。
由於自動編碼器在處理類平衡方面做得非常出色,因此它們經常用於鏈接預測。
遞歸圖神經網絡 (RGNN)
在多關係網絡中,單個節點有許多關係,它學習最佳擴散模式並可以管理圖。 為了增加平滑度並減少過度參數化,正則化器用於這種形式的圖神經網絡。
為了獲得更好的結果,RGNN 需要更少的處理能力。 它們用於文本生成、語音識別、機器翻譯、圖片描述、視頻標記和文本摘要。
門控神經圖網絡 (GGNN)
在涉及長期依賴任務時,它們的性能優於 RGNN。 通過在長期依賴關係上包含節點、邊和時間門,門控圖神經網絡增強了循環圖神經網絡。
門的功能類似於門控循環單元 (GRU),因為它們用於在各個階段回憶和忘記數據。
使用 Pytorch 實現圖神經網絡
我們將關注的具體問題是一個常見的節點分類問題。 我們有一個相當大的社交網絡,叫做 musae-github,它是從開放 API 編譯的,供 GitHub 開發人員使用。
邊表示節點之間的相互追隨者關係,代表在至少 10 個存儲庫中加註星標的開發人員(平台用戶)(注意,相互一詞表示無向關係)。
根據節點的位置、加星標的存儲庫、雇主和電子郵件地址,檢索節點特徵。 預測 GitHub 用戶是 Web 開發人員還是 機器學習開發者 是我們的任務。
每個用戶的職位是此定位功能的基礎。
安裝 PyTorch
首先,我們首先需要安裝 火炬. 您可以根據您的機器配置它 点击這裡. 這是我的:
導入模塊
現在,我們導入必要的模塊
導入和探索數據
下一步是讀取數據並繪製標籤文件中的前五行和後五行。
四列中只有兩列——節點的 id(即用戶)和 ml_target,如果用戶是機器學習社區的成員,則為 1,否則為 0——在這種情況下與我們相關。
鑑於只有兩個類,我們現在可以確定我們的任務是一個二元分類問題。
由於嚴重的類不平衡,分類器可以只假設哪個類占多數,而不是評估代表性不足的類,這使得類平衡成為另一個需要考慮的關鍵因素。
繪製直方圖(頻率分佈)揭示了一些不平衡,因為機器學習(標籤=1)的類比其他類少。
特徵編碼
節點的特徵告訴我們與每個節點相關的特徵。 通過實現我們的數據編碼方法,我們可以立即對這些特徵進行編碼。
我們想利用這種方法封裝一小部分網絡(比如 60 個節點)以供顯示。 代碼在此處列出。
設計和顯示圖表
我們將使用火炬幾何。 數據來構建我們的圖表。
為了對具有不同(可選)屬性的單個圖進行建模,使用作為簡單 Python 對象的數據。 通過利用這個類和以下屬性——所有這些都是火炬張量——我們將創建我們的圖形對象。
將分配給編碼節點特徵的值 x 的形式是[節點數,特徵數]。
y 的形狀是[節點數],它將應用於節點標籤。
邊索引:為了描述無向圖,我們需要擴展原始邊索引,以允許存在兩條不同的有向邊,它們連接相同的兩個節點但指向相反的方向。
例如,在節點 100 和 200 之間需要一對邊,一個指向節點 200 到 100,另一個指向 100 到 200。如果提供了邊索引,那麼這就是無向圖的表示方式。 [2,2*number of original edges] 將是張量形式。
我們創建了繪製圖形方法來顯示圖形。 第一步是將我們的同構網絡轉換為 NetworkX 圖,然後可以使用 NetworkX.draw 進行繪製。
製作我們的 GNN 模型並進行訓練
我們首先通過使用 light=False 執行編碼數據來對整個數據集進行編碼,然後使用 light=False 調用構造圖來構建整個圖。 我們不會嘗試繪製這個大圖,因為我假設您使用的是資源有限的本地計算機。
掩碼是使用數字 0 和 1 標識哪些節點屬於每個特定掩碼的二進制向量,可用於通知訓練階段在訓練期間應包括哪些節點,並告訴推理階段哪些節點是測試數據。 火炬幾何變換。
可以使用 AddTrainValTestMask 類的訓練掩碼、驗證掩碼和測試掩碼屬性添加節點級拆分,這些屬性可用於獲取圖表並讓我們能夠指定我們希望如何構建掩碼。
我們只使用 10% 的數據進行訓練,使用 60% 的數據作為測試集,同時使用 30% 作為驗證集。
現在,我們將堆疊兩個 GCNConv 層,第一個層的輸出特徵計數等於我們圖中作為輸入特徵的特徵數量。
在第二層,它包含與我們的類數量相等的輸出節點,我們應用一個 relu 激活函數並提供潛在特徵。
邊索引和邊權重是 GCNConv 在前向函數中可以接受的眾多選項 x 中的兩個,但在我們的情況下,我們只需要前兩個變量。
儘管我們的模型將能夠預測圖中每個節點的類別,但我們仍然需要根據階段分別確定每個集合的準確度和損失。
例如,在訓練期間,我們只想利用訓練集來確定準確度和訓練損失,因此這就是我們的掩碼派上用場的地方。
為了計算適當的損失和準確率,我們將定義掩碼損失和掩碼準確率的函數。
訓練模型
現在我們已經定義了使用 Torch 的訓練目的。 Adam 是一位優化大師。
我們將進行一定數量的 epoch 的訓練,同時關注驗證的準確性。
我們還繪製了不同時期的訓練損失和準確性。
圖神經網絡的缺點
使用 GNN 有一些缺點。 何時使用 GNNa 以及如何提高我們的機器學習模型的性能,在我們對它們有了更好的了解之後,我們都會清楚地了解它們。
- 雖然 GNN 是淺層網絡,通常具有三層,但大多數神經網絡可以深入以提高性能。 由於這個限制,我們無法在大數據集上處於領先地位。
- 在圖上訓練模型更加困難,因為它們的結構動力學是動態的。
- 由於這些網絡的高計算成本,為生產擴展模型提出了挑戰。 如果您的圖形結構龐大且複雜,那麼將 GNN 用於生產將具有挑戰性。
結論
在過去的幾年裡,GNN 已經發展成為解決圖領域機器學習問題的強大而有效的工具。 本文給出了圖神經網絡的基本概述。
之後,您可以開始創建將用於訓練和測試模型的數據集。 要了解它的功能和能力,您還可以走得更遠,使用不同類型的數據集對其進行訓練。
編碼愉快!
發表評論