新智元报道
编辑：好困 小咸鱼
为什么要用GNN？
TF-GNN的结构
使用示例
import tensorflow as tf
import tensorflow_gnn as tfgnn
# Model hyper-parameters: hidden-state width for each node set, looked up by
# node-set name when the next-state layers are created below.
h_dims = {'user': 256, 'movie': 64, 'genre': 128}


def _convolution_for_edge_set(edge_set_name):
    """Build the convolution layer for one edge set.

    Every edge set gets the same layer type, so the name supplied by the
    builder is not consulted. ``WeightedSumConvolution`` is looked up when
    this function is called, not when it is defined.
    """
    return WeightedSumConvolution()


def _next_state_for_node_set(node_set_name):
    """Build the next-state layer for one node set.

    Wraps a Dense layer of width ``h_dims[node_set_name]`` in
    ``NextStateFromConcat``.
    """
    width = h_dims[node_set_name]
    return tfgnn.keras.layers.NextStateFromConcat(tf.keras.layers.Dense(width))


# Model builder initialization: the factories are called with an edge-set
# name and a node-set name respectively.
gnn = tfgnn.keras.ConvGNNBuilder(
    _convolution_for_edge_set,
    _next_state_for_node_set,
)

# Two rounds of message passing to target node sets, then a readout of the
# 'user' node states feeding a single linear output unit.
model = tf.keras.models.Sequential([
    gnn.Convolve({'genre'}),  # sends messages from movie to genre
    gnn.Convolve({'user'}),  # sends messages from movie and genre to users
    tfgnn.keras.layers.Readout(node_set_name="user"),
    tf.keras.layers.Dense(1),
])
class WeightedSumConvolution(tf.keras.layers.Layer):
    """Weighted sum of source node states, pooled at each edge's target node.

    For one edge set, every edge carries its SOURCE node's state
    (``tfgnn.DEFAULT_STATE_NAME``) scaled by the edge's weight feature. These
    weighted messages are summed per TARGET node.

    Args:
      *args: Forwarded positionally to ``tf.keras.layers.Layer``.
      weight_feature_name: Name of the edge feature holding the per-edge
        weights. Defaults to ``'weight'``, the name this layer always read.
      **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, *args, weight_feature_name='weight', **kwargs):
        super().__init__(*args, **kwargs)
        self.weight_feature_name = weight_feature_name

    def get_config(self):
        """Return the layer config, including ``weight_feature_name``."""
        config = super().get_config()
        config['weight_feature_name'] = self.weight_feature_name
        return config

    def call(self, graph: tfgnn.GraphTensor,
             edge_set_name: tfgnn.EdgeSetName) -> tfgnn.Field:
        """Pool weighted SOURCE-node states onto the TARGET nodes of an edge set.

        Args:
          graph: The input graph tensor.
          edge_set_name: Edge set whose SOURCE -> TARGET messages are pooled.

        Returns:
          The per-TARGET-node sum of weighted SOURCE-node states.
        """
        # Copy each edge's SOURCE node state onto the edge.
        messages = tfgnn.broadcast_node_to_edges(
            graph,
            edge_set_name,
            tfgnn.SOURCE,
            feature_name=tfgnn.DEFAULT_STATE_NAME)
        weights = graph.edge_sets[edge_set_name][self.weight_feature_name]
        # TF multiplication does not promote dtypes: an integer or float64
        # weight feature would fail against float32 states, so align first.
        weights = tf.cast(weights, messages.dtype)
        # Add a trailing axis so a per-edge scalar weight broadcasts across the
        # state dimension (assumes one weight value per edge -- TODO confirm
        # against the graph schema).
        weighted_messages = tf.expand_dims(weights, -1) * messages
        return tfgnn.pool_edges_to_node(
            graph,
            edge_set_name,
            tfgnn.TARGET,
            reduce_type='sum',
            feature_value=weighted_messages)
安装说明
$> git clone https://github.com/tensorflow/gnn.git tensorflow_gnn
$> pip install tensorflow
$> sudo apt-get install graphviz graphviz-dev
$> cd tensorflow_gnn && python3 -m pip install .
参考资料：
https://blog.tensorflow.org/2021/11/introducing-tensorflow-gnn.html
https://github.com/tensorflow/gnn