smart.train

smart.train.laplacian_regularization(x, edge_index)

Compute Laplacian regularization loss.

This loss encourages smoothness by penalizing differences between the embeddings of adjacent nodes.

Parameters:
  • x (torch.Tensor) – Node embeddings of shape [num_nodes, emb_dim].

  • edge_index (torch.LongTensor) – Graph connectivity in COO format with shape [2, num_edges].

Returns:

Scalar Laplacian regularization loss.

Return type:

torch.Tensor
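A minimal PyTorch sketch of this kind of penalty follows. It assumes the loss is the mean squared Euclidean distance between the embeddings at the two ends of each edge. `laplacian_regularization_sketch` is a hypothetical name, and the library's actual normalization (sum vs. mean, degree weighting) may differ.

```python
import torch


def laplacian_regularization_sketch(x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
    """Mean squared Euclidean distance between embeddings of connected nodes.

    Assumption: the penalty is averaged over the listed edges; the library
    may instead sum, normalize by node count, or use a degree-normalized Laplacian.
    """
    src, dst = edge_index          # each has shape [num_edges]
    diff = x[src] - x[dst]         # [num_edges, emb_dim]
    return diff.pow(2).sum(dim=1).mean()


# Toy path graph 0 - 1 - 2, with both directions of each edge listed (COO format).
x = torch.tensor([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
loss = laplacian_regularization_sketch(x, edge_index)
print(loss.item())  # 1.0
```

When edge_index lists every undirected edge in both directions, the summed (unaveraged) penalty equals 2 * tr(X^T L X) with L = D - A, which is where the name "Laplacian regularization" comes from.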

smart.train.train_SMART(features, edges, triplet_samples_list, weights=[1, 1], emb_dim=64, n_epochs=500, lr=0.0001, weight_decay=1e-05, device=device(type='cpu'), window_size=20, slope=0.0001, Conv_Encoder=<class 'smart.layer.SAGEConv_Encoder'>, Conv_Decoder=<class 'smart.layer.SAGEConv_Decoder'>, margin=0.5, return_loss=False, laplacian_alpha=0)

Train the SMART model with reconstruction, triplet, and optional Laplacian loss.

Parameters:
  • features (list of torch.Tensor) – Node feature matrices for each modality. Each element has shape [num_nodes, in_dim].

  • edges (list of torch.LongTensor) – Graph connectivity for each modality. Each element has shape [2, num_edges].

  • triplet_samples_list (list of tuple) – Each tuple contains (anchors, positives, negatives) indices for triplet loss.

  • weights (list of float, default=[1, 1, 1, 1]) –

    Loss weights in the following order (matching the implementation): [reconstruction_loss_modality1, reconstruction_loss_modality2, triplet_loss_modality1, triplet_loss_modality2].

    • reconstruction_loss_modalityX: weight for reconstruction (MSE) loss of modality X

    • triplet_loss_modalityX: weight for triplet loss of modality X

    Notes

    Let M be the number of modalities (len(features)). weights should have length 2*M:

    weights[0:M] -> reconstruction weights for each modality
    weights[M:2*M] -> triplet weights for each modality

  • emb_dim (int, default=64) – Dimension of shared latent embedding.

  • n_epochs (int, default=500) – Number of training epochs.

  • lr (float, default=0.0001) – Learning rate for Adam optimizer.

  • weight_decay (float, default=1e-5) – Weight decay for optimizer.

  • device (torch.device, optional) – Device to train on (default: GPU if available).

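  • window_size (int, default=20) – Number of most recent epochs used to estimate the slope of the loss curve for early stopping.

  • slope (float, default=0.0001) – Early-stopping threshold: training stops once the absolute slope of the loss over the last window_size epochs falls below this value.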

  • Conv_Encoder (class, default=SAGEConv_Encoder) – Graph encoder class.

  • Conv_Decoder (class, default=SAGEConv_Decoder) – Graph decoder class.

  • margin (float, default=0.5) – Margin for triplet loss.

  • return_loss (bool, default=False) – Whether to return the loss history along with the trained model.

  • laplacian_alpha (float, default=0) – Weight for Laplacian regularization term. Disabled if set to 0.
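The weights layout described above can be illustrated with a small, framework-free sketch. `combine_losses` is a hypothetical helper; it assumes the total loss is a plain weighted sum of the per-modality reconstruction and triplet losses, plus laplacian_alpha times a single Laplacian term. The documentation does not spell out how these terms are combined.

```python
from typing import Sequence


def combine_losses(
    rec_losses: Sequence[float],
    tri_losses: Sequence[float],
    weights: Sequence[float],
    lap_loss: float = 0.0,
    laplacian_alpha: float = 0.0,
) -> float:
    """Hypothetical helper: weighted sum following the documented weights layout."""
    m = len(rec_losses)  # number of modalities, M
    if len(tri_losses) != m or len(weights) != 2 * m:
        raise ValueError(f"expected {m} triplet losses and 2*M = {2 * m} weights")
    rec_w, tri_w = weights[:m], weights[m:]  # weights[0:M], weights[M:2*M]
    total = sum(w * l for w, l in zip(rec_w, rec_losses))
    total += sum(w * l for w, l in zip(tri_w, tri_losses))
    if laplacian_alpha != 0:
        # Assumption: the Laplacian penalty enters as one scaled term.
        total += laplacian_alpha * lap_loss
    return total


# Two modalities, so weights has length 4.
total = combine_losses([0.5, 0.25], [1.0, 2.0], weights=[1, 1, 0.5, 0.5])
print(total)  # 2.25
```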

Returns:

  • model (SMART) – Trained SMART model.

  • loss_list (list of tuple, optional) – Only returned if return_loss=True. Contains (total_loss, tri_loss, rec_loss) per epoch.
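The early-stopping rule controlled by window_size and slope can be sketched as follows. This assumes an ordinary least-squares line fitted to the most recent window_size loss values; `should_stop` is hypothetical, and train_SMART may estimate the slope differently.

```python
from typing import Sequence


def should_stop(loss_history: Sequence[float], window_size: int = 20, slope: float = 1e-4) -> bool:
    """Hypothetical check: fit a least-squares line to the last `window_size`
    losses and stop once the absolute fitted slope drops below `slope`."""
    if window_size < 2:
        raise ValueError("window_size must be at least 2 to fit a slope")
    if len(loss_history) < window_size:
        return False  # not enough epochs yet to estimate a trend
    ys = list(loss_history[-window_size:])
    n = len(ys)
    x_mean = (n - 1) / 2  # epochs indexed 0..n-1 within the window
    y_mean = sum(ys) / n
    cov = sum((i - x_mean) * (y - y_mean) for i, y in enumerate(ys))
    var = sum((i - x_mean) ** 2 for i in range(n))
    return abs(cov / var) < slope


falling = [1.0 - 0.01 * i for i in range(20)]  # loss still dropping 0.01 per epoch
plateau = [0.5] * 20                           # loss has flattened out
print(should_stop(falling), should_stop(plateau))  # False True
```

A least-squares fit is less sensitive to a single noisy epoch than comparing only the first and last losses in the window.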

smart.train.train_SMART_MS(features, edges, triplet_samples_list, weights=[1, 1], emb_dim=64, n_epochs=500, lr=0.0001, weight_decay=1e-05, device=device(type='cpu'), window_size=20, slope=0.0001, Conv_Encoder=<class 'smart.layer.SAGEConv_Encoder'>, Conv_Decoder=<class 'smart.layer.SAGEConv_Decoder'>, margin=0.5, return_loss=False, laplacian_alpha=0)

Train the SMART-MS model with reconstruction, triplet, and optional Laplacian loss.

Parameters:
  • features (list of torch.Tensor) – Node feature matrices for each modality. Each element has shape [num_nodes, in_dim].

  • edges (list of torch.LongTensor) – Graph connectivity for each modality. Each element has shape [2, num_edges].

  • triplet_samples_list (list of tuple) – Each tuple contains (anchors, positives, negatives) indices for triplet loss.

  • weights (list of float, default=[1, 1, 1, 1]) –

    Loss weights in the following order (matching the implementation): [reconstruction_loss_modality1, reconstruction_loss_modality2, triplet_loss_modality1, triplet_loss_modality2].

    • reconstruction_loss_modalityX: weight for reconstruction (MSE) loss of modality X

    • triplet_loss_modalityX: weight for triplet loss of modality X

    Notes

    Let M be the number of modalities (len(features)). weights should have length 2*M:

    weights[0:M] -> reconstruction weights for each modality
    weights[M:2*M] -> triplet weights for each modality

  • emb_dim (int, default=64) – Dimension of shared latent embedding.

  • n_epochs (int, default=500) – Number of training epochs.

  • lr (float, default=0.0001) – Learning rate for Adam optimizer.

  • weight_decay (float, default=1e-5) – Weight decay for optimizer.

  • device (torch.device, optional) – Device to train on (default: GPU if available).

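  • window_size (int, default=20) – Number of most recent epochs used to estimate the slope of the loss curve for early stopping.

  • slope (float, default=0.0001) – Early-stopping threshold: training stops once the absolute slope of the loss over the last window_size epochs falls below this value.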

  • Conv_Encoder (class, default=SAGEConv_Encoder) – Graph encoder class.

  • Conv_Decoder (class, default=SAGEConv_Decoder) – Graph decoder class.

  • margin (float, default=0.5) – Margin for triplet loss.

  • return_loss (bool, default=False) – Whether to return the loss history along with the trained model.

  • laplacian_alpha (float, default=0) – Weight for Laplacian regularization term. Disabled if set to 0.

Returns:

  • model (SMART) – Trained SMART model.

  • loss_list (list of tuple, optional) – Only returned if return_loss=True. Contains (total_loss, tri_loss, rec_loss) per epoch.
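Both training functions take triplet_samples_list and margin. The sketch below shows one plausible way an (anchors, positives, negatives) tuple of node indices becomes a loss on an embedding matrix, using PyTorch's built-in torch.nn.functional.triplet_margin_loss. Whether SMART uses this function or its own variant is an assumption, and `triplet_term` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def triplet_term(z: torch.Tensor, anchors, positives, negatives, margin: float = 0.5) -> torch.Tensor:
    """Triplet margin loss on rows of an embedding matrix, selected by index."""
    anchors = torch.as_tensor(anchors, dtype=torch.long)
    positives = torch.as_tensor(positives, dtype=torch.long)
    negatives = torch.as_tensor(negatives, dtype=torch.long)
    # Mean over triplets of max(0, d(z_a, z_p) - d(z_a, z_n) + margin), Euclidean d.
    return F.triplet_margin_loss(z[anchors], z[positives], z[negatives], margin=margin)


z = torch.tensor([[0.0, 0.0], [0.1, 0.0], [3.0, 0.0]])
# First triplet is already satisfied (loss 0); second is violated (loss ~3.4).
loss = triplet_term(z, anchors=[0, 0], positives=[1, 2], negatives=[2, 1])
print(round(loss.item(), 4))  # 1.7
```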