PhD Proposal: Towards Generalized and Scalable Representation Learning on Graphs
Graph Neural Networks (GNNs) have emerged as powerful architectures for learning and analyzing graph representations. However, training GNNs on large-scale datasets often suffers from overfitting, and realistic graph datasets frequently contain many out-of-distribution test nodes, posing significant generalization challenges for prediction problems. Meanwhile, conventional GNNs are hindered by scalability problems when deployed on industrial-scale graph datasets. In this proposal, we investigate algorithms and techniques to address the generalization and scalability issues of GNNs.

In the first work, we leverage data augmentation to improve the generalization of GNNs. Data augmentation helps neural networks generalize by enlarging the training set, but how to effectively augment graph data to enhance GNN performance remains an open question. While most existing graph regularizers manipulate graph topology by adding or removing edges, we instead augment node features. We propose FLAG (Free Large-scale Adversarial Augmentation on Graphs), which iteratively augments node features with gradient-based adversarial perturbations during training. By making the model invariant to small fluctuations in the input, FLAG helps models generalize to out-of-distribution samples and boosts performance at test time.

In the second work, we carefully investigate the out-of-distribution (OOD) problem on graph data. We curate GDS, a benchmark of eight datasets reflecting a diverse range of distribution shifts across graphs. We observe that (1) most domain generalization algorithms fail when applied to distribution shifts on graphs, and (2) combinations of powerful GNN models and augmentation techniques usually achieve the best out-of-distribution performance.

In the third work, we address the scalability of GNNs.
To scale GNNs to large graphs, various neighbor-, layer-, and subgraph-sampling techniques have been proposed to alleviate the neighbor-explosion problem. However, these methods exhibit unstable performance across tasks and datasets, and they do not speed up model inference. We propose VQ-GNN, a universal framework that scales up any convolution-based GNN using Vector Quantization (VQ) without compromising performance. Experiments demonstrate the scalability and competitive performance of our framework on large-scale node classification and link prediction benchmarks.
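To make the vector-quantization idea behind VQ-GNN concrete, the sketch below shows the core operation: each node's message vector is replaced by its nearest entry in a small learned codebook, so only K codewords plus one index per node need to be stored instead of all N messages. This is a minimal illustrative sketch, not the actual VQ-GNN implementation; the function name `vq_assign` and the plain nearest-neighbor assignment are our own simplifications.

```python
import numpy as np

def vq_assign(features, codebook):
    """Assign each node's message vector to its nearest codebook entry.

    Illustrative sketch (not the VQ-GNN codebase): `features` is (N, d),
    `codebook` is (K, d) with K << N. Returns (indices, quantized), where
    quantized[i] = codebook[indices[i]] is the compressed message for node i.
    """
    # Squared Euclidean distance from every feature to every codeword,
    # via broadcasting: result has shape (N, K).
    d2 = ((features[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)
    idx = d2.argmin(axis=1)          # nearest codeword per node
    return idx, codebook[idx]        # indices and quantized messages
```

In a full pipeline the codebook itself would also be updated (e.g., by a k-means-style or exponential-moving-average rule) as training proceeds; the assignment step above is what lets message passing operate on O(K) distinct vectors instead of O(N).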
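For concreteness, the gradient-based adversarial feature augmentation from the first work (FLAG) can be sketched on a toy model. The key idea of "free" adversarial training is that the input perturbation is updated by gradient ascent on the loss while the parameter gradient is accumulated at every inner step, so the adversarial steps add little extra cost. The sketch below uses a hypothetical logistic model with hand-derived gradients purely for illustration; FLAG itself applies this loop to node features inside a GNN, with gradients from autograd.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def flag_step(X, y, w, ascent_steps=3, step_size=1e-2, lr=0.1):
    """One FLAG-style training step on a toy logistic model (sketch only).

    Inner loop: the feature perturbation `delta` climbs the loss surface
    (making inputs adversarial), while the gradient w.r.t. the weights is
    accumulated "for free" at each step. Outer step: one descent update
    on the averaged weight gradient.
    """
    n, _ = X.shape
    delta = np.zeros_like(X)           # adversarial perturbation on features
    grad_w_accum = np.zeros_like(w)
    for _ in range(ascent_steps):
        p = sigmoid((X + delta) @ w)   # predictions on perturbed features
        err = p - y                    # dL/dlogits for binary cross-entropy
        grad_w_accum += (X + delta).T @ err / n
        grad_delta = np.outer(err, w)  # dL/d(input features)
        delta += step_size * np.sign(grad_delta)  # ascent on the perturbation
    return w - lr * grad_w_accum / ascent_steps   # averaged descent step
```

The same loop carries over to GNN training by replacing the analytic gradients with autograd and `X` with the node feature matrix; invariance to these worst-case feature fluctuations is what drives the generalization gains described above.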
Dr. Tom Goldstein
Dr. John Dickerson
Dr. Sanghamitra Dutta