smart.train
- smart.train.laplacian_regularization(x, edge_index)
Compute Laplacian regularization loss.
This term encourages smooth embeddings by penalizing differences between the embeddings of adjacent nodes.
- Parameters:
x (torch.Tensor) – Node embeddings of shape [num_nodes, emb_dim].
edge_index (torch.LongTensor) – Graph connectivity in COO format with shape [2, num_edges].
- Returns:
Scalar Laplacian regularization loss.
- Return type:
torch.Tensor
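One common form of this penalty is the mean squared distance between the embeddings at the two ends of each edge. The sketch below assumes that form. The name laplacian_regularization_sketch is hypothetical, and the actual smart.train implementation may normalize or weight edges differently. Summed over edges, this quantity equals tr(X^T L X) up to a constant factor, where L is the graph Laplacian, which is where the name comes from.

```python
import torch


def laplacian_regularization_sketch(x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: mean squared L2 distance between embeddings of connected nodes."""
    src, dst = edge_index             # two index tensors, each of shape [num_edges]
    diff = x[src] - x[dst]            # [num_edges, emb_dim]
    return diff.pow(2).sum(dim=1).mean()  # scalar tensor


# Three nodes on a path 0 -> 1 -> 2, edges in COO format [2, num_edges]
x = torch.tensor([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])
edge_index = torch.tensor([[0, 1], [1, 2]])
loss = laplacian_regularization_sketch(x, edge_index)
```

Identical embeddings on every node give a loss of zero, so the term only pulls neighboring embeddings toward each other and never pushes them apart.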
- smart.train.train_SMART(features, edges, triplet_samples_list, weights=[1, 1], emb_dim=64, n_epochs=500, lr=0.0001, weight_decay=1e-05, device=device(type='cpu'), window_size=20, slope=0.0001, Conv_Encoder=<class 'smart.layer.SAGEConv_Encoder'>, Conv_Decoder=<class 'smart.layer.SAGEConv_Decoder'>, margin=0.5, return_loss=False, laplacian_alpha=0)
Train the SMART model with reconstruction, triplet, and optional Laplacian loss.
- Parameters:
features (list of torch.Tensor) – Node feature matrices for each modality. Each element has shape [num_nodes, in_dim].
edges (list of torch.LongTensor) – Graph connectivity for each modality. Each element has shape [2, num_edges].
triplet_samples_list (list of tuple) – Each tuple contains the (anchors, positives, negatives) indices used by the triplet loss.
weights (list of float, default=[1, 1]) –
Loss weights, in the order used by the implementation. For two modalities: [reconstruction_loss_modality1, reconstruction_loss_modality2, triplet_loss_modality1, triplet_loss_modality2].
reconstruction_loss_modalityX: weight for the reconstruction (MSE) loss of modality X
triplet_loss_modalityX: weight for the triplet loss of modality X
Notes
Let M be the number of modalities (len(features)). weights should have length 2*M:
weights[0:M] -> reconstruction weights for each modality
weights[M:2*M] -> triplet weights for each modality
The signature default [1, 1] has length 2, which covers only a single modality, so pass weights explicitly for two or more modalities (e.g. [1, 1, 1, 1] for two).
emb_dim (int, default=64) – Dimension of shared latent embedding.
n_epochs (int, default=500) – Number of training epochs.
lr (float, default=0.0001) – Learning rate for the Adam optimizer.
weight_decay (float, default=1e-5) – Weight decay for the optimizer.
device (torch.device, optional) – Device to train on (default: GPU if available, otherwise CPU).
window_size (int, default=20) – Number of recent epochs over which the slope of the loss is estimated for early stopping.
slope (float, default=0.0001) – Early-stopping threshold: training continues only while the absolute slope over the last window_size epochs is at least this value.
Conv_Encoder (class, default=SAGEConv_Encoder) – Graph encoder class.
Conv_Decoder (class, default=SAGEConv_Decoder) – Graph decoder class.
margin (float, default=0.5) – Margin for triplet loss.
return_loss (bool, default=False) – Whether to return the per-epoch loss history along with the trained model.
laplacian_alpha (float, default=0) – Weight for Laplacian regularization term. Disabled if set to 0.
- Returns:
model (SMART) – Trained SMART model.
loss_list (list of tuple, optional) – Only returned if return_loss=True. Contains (total_loss, tri_loss, rec_loss) per epoch.
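The 2*M weight layout can be illustrated with plain Python floats. weighted_total_loss is a hypothetical helper, not part of smart.train. It assumes the total loss is the weighted sum of the per-modality reconstruction and triplet losses, plus laplacian_alpha times the Laplacian term, and it returns the values in the (total_loss, tri_loss, rec_loss) order used by loss_list.

```python
from typing import Sequence, Tuple


def weighted_total_loss(
    rec_losses: Sequence[float],
    tri_losses: Sequence[float],
    weights: Sequence[float],
    lap_loss: float = 0.0,
    laplacian_alpha: float = 0.0,
) -> Tuple[float, float, float]:
    """Hypothetical helper: combine per-modality losses using the 2*M weight layout."""
    m = len(rec_losses)  # number of modalities, M
    if len(tri_losses) != m or len(weights) != 2 * m:
        raise ValueError("expected M reconstruction losses, M triplet losses and 2*M weights")
    rec_w, tri_w = weights[:m], weights[m:]  # weights[0:M], weights[M:2*M]
    rec = sum(w * loss for w, loss in zip(rec_w, rec_losses))
    tri = sum(w * loss for w, loss in zip(tri_w, tri_losses))
    total = rec + tri
    if laplacian_alpha != 0:  # the Laplacian term is disabled when alpha is 0
        total += laplacian_alpha * lap_loss
    return total, tri, rec  # same order as the entries of loss_list


# Two modalities, so four weights
total, tri, rec = weighted_total_loss([0.5, 0.25], [0.1, 0.2], weights=[1, 1, 1, 1])
```

In this sketch, passing weights=[1, 1] with two modalities raises a ValueError, mirroring the length requirement in the Notes; whether train_SMART itself validates the length is not documented here.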
- smart.train.train_SMART_MS(features, edges, triplet_samples_list, weights=[1, 1], emb_dim=64, n_epochs=500, lr=0.0001, weight_decay=1e-05, device=device(type='cpu'), window_size=20, slope=0.0001, Conv_Encoder=<class 'smart.layer.SAGEConv_Encoder'>, Conv_Decoder=<class 'smart.layer.SAGEConv_Decoder'>, margin=0.5, return_loss=False, laplacian_alpha=0)
Train the SMART-MS model with reconstruction, triplet, and optional Laplacian loss.
- Parameters:
features (list of torch.Tensor) – Node feature matrices for each modality. Each element has shape [num_nodes, in_dim].
edges (list of torch.LongTensor) – Graph connectivity for each modality. Each element has shape [2, num_edges].
triplet_samples_list (list of tuple) – Each tuple contains the (anchors, positives, negatives) indices used by the triplet loss.
weights (list of float, default=[1, 1]) –
Loss weights, in the order used by the implementation. For two modalities: [reconstruction_loss_modality1, reconstruction_loss_modality2, triplet_loss_modality1, triplet_loss_modality2].
reconstruction_loss_modalityX: weight for the reconstruction (MSE) loss of modality X
triplet_loss_modalityX: weight for the triplet loss of modality X
Notes
Let M be the number of modalities (len(features)). weights should have length 2*M:
weights[0:M] -> reconstruction weights for each modality
weights[M:2*M] -> triplet weights for each modality
The signature default [1, 1] has length 2, which covers only a single modality, so pass weights explicitly for two or more modalities (e.g. [1, 1, 1, 1] for two).
emb_dim (int, default=64) – Dimension of shared latent embedding.
n_epochs (int, default=500) – Number of training epochs.
lr (float, default=0.0001) – Learning rate for the Adam optimizer.
weight_decay (float, default=1e-5) – Weight decay for the optimizer.
device (torch.device, optional) – Device to train on (default: GPU if available, otherwise CPU).
window_size (int, default=20) – Number of recent epochs over which the slope of the loss is estimated for early stopping.
slope (float, default=0.0001) – Early-stopping threshold: training continues only while the absolute slope over the last window_size epochs is at least this value.
Conv_Encoder (class, default=SAGEConv_Encoder) – Graph encoder class.
Conv_Decoder (class, default=SAGEConv_Decoder) – Graph decoder class.
margin (float, default=0.5) – Margin for triplet loss.
return_loss (bool, default=False) – Whether to return the per-epoch loss history along with the trained model.
laplacian_alpha (float, default=0) – Weight for Laplacian regularization term. Disabled if set to 0.
- Returns:
model (SMART) – Trained SMART model.
loss_list (list of tuple, optional) – Only returned if return_loss=True. Contains (total_loss, tri_loss, rec_loss) per epoch.
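Both train_SMART and train_SMART_MS expose window_size and slope for early stopping. One way to read these parameters is sketched below: fit a least-squares line to the last window_size loss values and stop once the absolute fitted slope falls below slope. The helper should_stop is hypothetical, and the criterion inside smart.train may be computed differently (for example, with a simple endpoint difference instead of a regression).

```python
from typing import Sequence


def should_stop(losses: Sequence[float], window_size: int = 20, slope: float = 1e-4) -> bool:
    """Hypothetical early-stopping rule based on the slope of the loss curve."""
    if window_size < 2:
        raise ValueError("window_size must be at least 2 to fit a slope")
    if len(losses) < window_size:
        return False  # not enough history yet: keep training
    window = list(losses[-window_size:])
    n = window_size
    mean_x = (n - 1) / 2  # mean of epoch offsets 0..n-1
    mean_y = sum(window) / n
    # Ordinary least-squares slope of loss vs. epoch offset within the window
    cov = sum((i - mean_x) * (y - mean_y) for i, y in enumerate(window))
    var = sum((i - mean_x) ** 2 for i in range(n))
    fitted_slope = cov / var
    # Stop once the loss curve has flattened below the threshold
    return abs(fitted_slope) < slope


history = [1.0 - 0.01 * i for i in range(20)] + [0.8] * 20
still_decreasing = should_stop(history[:20])  # slope of about -0.01 over the window
flattened = should_stop(history)              # last 20 epochs are flat
```

Using a regression over the whole window, rather than comparing two individual epochs, makes the rule less sensitive to a single noisy loss value.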