smart.model

class smart.model.SMART(hidden_dims, device, Conv_Encoder=<class 'smart.layer.SAGEConv_Encoder'>, Conv_Decoder=<class 'smart.layer.SAGEConv_Decoder'>)

Bases: Module

SMART: A modular multi-modal graph representation learning model.

This model uses an encoder-decoder architecture for each modality, projects the learned embeddings into a shared latent space, and reconstructs input features for self-supervised training.

Parameters:
  • hidden_dims (list of int) – Input dimension of each modality, followed by the shared latent dimension as the last element. Example: [in_dim_mod1, in_dim_mod2, …, latent_dim].

  • device (torch.device) – Device on which the model's modules are placed.

  • Conv_Encoder (class) – Encoder architecture (default: SAGEConv_Encoder).

  • Conv_Decoder (class) – Decoder architecture (default: SAGEConv_Decoder).
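
The layout of hidden_dims described above can be sketched as follows. This is a minimal illustration, not code from smart.model: split_hidden_dims is a hypothetical helper, and the dimensions used (3000, 50, 64) are made-up example values.

```python
from typing import List, Tuple


def split_hidden_dims(hidden_dims: List[int]) -> Tuple[List[int], int]:
    """Split hidden_dims into (per-modality input dims, latent dim).

    Hypothetical helper, not part of smart.model. It only assumes the
    documented layout [in_dim_mod1, in_dim_mod2, ..., latent_dim].
    """
    if len(hidden_dims) < 2:
        raise ValueError("need at least one modality dimension plus latent_dim")
    if any(d <= 0 for d in hidden_dims):
        raise ValueError("all dimensions must be positive")
    # Every entry except the last is a modality input dimension;
    # the last entry is the shared latent dimension.
    return list(hidden_dims[:-1]), hidden_dims[-1]


# Illustrative values: two modalities with 3000 and 50 input features,
# sharing a 64-dimensional latent space.
in_dims, latent_dim = split_hidden_dims([3000, 50, 64])
```

A list built this way would then be passed as the first argument, for example SMART([3000, 50, 64], device), with one entry per modality before the final latent dimension.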

forward(features, edge_indexs)

Forward pass of the SMART model.

Parameters:
  • features (list of torch.Tensor) – Node features, one tensor per modality. Each tensor has shape [num_nodes, in_dim_mod], where in_dim_mod is that modality's input dimension.

  • edge_indexs (list of torch.LongTensor) – Graph connectivity, one tensor per modality. Each tensor has shape [2, num_edges].

Returns:
  • z (torch.Tensor) – Shared latent representation of shape [num_nodes, latent_dim].

  • x_rec (list of torch.Tensor) – Reconstructed features for each modality.
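
The shape contract of forward can be checked before calling the model. The sketch below works on plain shape tuples so it runs without torch. expected_forward_shapes is a hypothetical helper, not part of smart.model, and it makes two assumptions that the documentation above does not state explicitly: all modalities share the same num_nodes, and each reconstructed tensor in x_rec has the same shape as its input.

```python
from typing import List, Tuple

Shape = Tuple[int, int]


def expected_forward_shapes(
    hidden_dims: List[int],
    feature_shapes: List[Shape],
    edge_index_shapes: List[Shape],
) -> Tuple[Shape, List[Shape]]:
    """Validate input shapes and return (z shape, x_rec shapes).

    Hypothetical helper for illustration only.
    """
    in_dims, latent_dim = hidden_dims[:-1], hidden_dims[-1]
    if not (len(feature_shapes) == len(edge_index_shapes) == len(in_dims)):
        raise ValueError("expected one feature tensor and one edge index per modality")

    num_nodes = feature_shapes[0][0]
    for (n, d), in_dim in zip(feature_shapes, in_dims):
        # Assumption: every modality describes the same set of nodes.
        if n != num_nodes:
            raise ValueError("modalities disagree on num_nodes")
        if d != in_dim:
            raise ValueError(f"feature dim {d} does not match hidden_dims entry {in_dim}")

    for rows, _num_edges in edge_index_shapes:
        # Each edge index is documented as [2, num_edges].
        if rows != 2:
            raise ValueError("edge index must have shape [2, num_edges]")

    z_shape = (num_nodes, latent_dim)
    # Assumption: reconstructions mirror the input feature shapes.
    x_rec_shapes = list(feature_shapes)
    return z_shape, x_rec_shapes


# Illustrative values: 100 nodes, two modalities, 64-dimensional latent space.
z_shape, x_rec_shapes = expected_forward_shapes(
    [3000, 50, 64],
    feature_shapes=[(100, 3000), (100, 50)],
    edge_index_shapes=[(2, 600), (2, 480)],
)
```

With real tensors, the same check could be driven by tuple(t.shape) for each entry of features and edge_indexs.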

training: bool