GeorgeCao-HG/gnn-node-prediction-example

Basic GNN for Node-Level Prediction

This project provides a self-contained, runnable example of a Graph Neural Network (GNN) for a node-level prediction task, built using PyTorch and PyTorch Geometric.

The goal is to demonstrate the end-to-end process of:

  1. Defining a GNN architecture.
  2. Generating a synthetic dataset of graphs.
  3. Training the model on this dataset.
  4. Evaluating the model's performance.

This code is intended as a practical learning tool for understanding the fundamentals of GNNs in a system design context.

Project Structure


.
├── main.py         # Main script containing the GNN model, data generation, training, and evaluation logic.
└── README.md       # This file.

Features

  • GNN Model: A simple model with two GCNConv layers and a linear prediction head, designed to predict an output vector at each node of a graph.
  • Synthetic Data Generation: Includes a function to create a dummy dataset of random graphs, making the project runnable out-of-the-box without needing external data.
  • Training & Evaluation Loop: A standard PyTorch training loop demonstrates how to feed graph data to the model, calculate loss, and update weights.
  • Clear & Commented Code: The code is extensively commented to explain each step of the process, from data preparation to model inference.

Requirements

To run this project, you will need Python 3.7+ and the following libraries:

  • torch
  • torch_geometric

You can install PyTorch Geometric by following the official instructions on their website, which will ensure compatibility with your system's PyTorch and CUDA versions. A typical installation might look like this:

# Example installation - check PyTorch Geometric website for your specific system
pip install torch
pip install torch_geometric

How to Run

  1. Clone the repository or save the main.py file to your local machine.
  2. Install the required libraries as described above.
  3. Run the script from your terminal:
    python main.py

The script will automatically generate the data, build the model, train it for a set number of epochs, and print the training and test loss at each epoch.

Code Explanation

Data

The script generates a synthetic dataset where the task is to predict a 4-dimensional output vector for each node. The input features for each node are random, and the target y values are generated by applying a simple transformation to the input features, giving the model a pattern to learn.

Model Architecture (MyGNNModel)

The model consists of two GCNConv layers followed by a linear prediction head.

  1. The first layer transforms the initial node features into a richer, hidden representation.
  2. The second layer refines these representations further by passing information between nodes a second time.
  3. A final Linear layer acts as a prediction head, mapping the final node embeddings to the desired 4-dimensional output for each node.

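Each GCNConv layer aggregates features from a node's immediate neighbors, so stacking two layers lets a node's prediction depend on nodes up to 2 "hops" away. This makes the architecture a good fit for tasks where the prediction for a node depends on the features of its local neighborhood.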
