
glycontact.learning module


GINSweetNet(lib_size: int, num_classes: int = 1, hidden_dim: int = 128, num_components: int = 5)

Bases: Module

Given glycan graphs as input, predicts properties via a graph neural network.

forward(x: torch.Tensor, edge_index: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]

Forward pass through the model.

Args:
- x: Input node features [batch_size, num_nodes, hidden_dim]
- edge_index: Edge indices for the graph [2, num_edges]

Returns:
Tuple of:
- weights_logits: Logits for mixture weights [batch_size, 2, num_components]
- means: Mean angles in degrees [batch_size, 2, num_components]
- kappas: Concentration parameters [batch_size, 2, num_components]
- sasa_pred: Predicted SASA values [batch_size]
- flex_pred: Predicted flexibility values [batch_size]

VonMisesSweetNet(lib_size: int, num_classes: int = 1, hidden_dim: int = 128, num_components: int = 5)

Bases: Module

Given glycan graphs as input, predicts properties via a graph neural network.

forward(x: torch.Tensor, edge_index: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]

Forward pass through the model.

Args:
- x: Input node features [batch_size, num_nodes, hidden_dim]
- edge_index: Edge indices for the graph [2, num_edges]

Returns:
A tuple ((weights_logits, means, kappas), sasa_pred, flex_pred), where:
- weights_logits: Logits for mixture weights [batch_size, 2, num_components]
- means: Mean angles in degrees [batch_size, 2, num_components]
- kappas: Concentration parameters [batch_size, 2, num_components]
- sasa_pred: Predicted SASA values [batch_size]
- flex_pred: Predicted flexibility values [batch_size]

predict_von_mises_parameters(x: torch.Tensor, head: torch.nn.Module, fc_weights: torch.nn.Module, fc_means: torch.nn.Module, fc_kappas: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Predict mixture parameters for a given input tensor.

Args:
- x: Input tensor [batch_size, hidden_dim]
- head: Head module for the mixture model
- fc_weights: Fully connected layer for the mixture weights
- fc_means: Fully connected layer for the means
- fc_kappas: Fully connected layer for the concentration parameters (kappas)

Returns:
Tuple of:
- weights_logits: Logits for mixture weights [batch_size, 2, num_components]
- means: Mean angles in degrees [batch_size, 2, num_components]
- kappas: Concentration parameters [batch_size, 2, num_components]

angular_rmse(predicted_graphs: list[nx.DiGraph], true_graphs: list[nx.DiGraph]) -> tuple[float, float]

Calculate the root mean square error (RMSE) for phi and psi angles.

Args:
- predicted_graphs: List of predicted structure graphs
- true_graphs: List of true structure graphs

Returns:
Tuple of RMSE values for the phi and psi angles.
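The key detail of an angular RMSE is that differences must be taken on the circle, so that 179° versus -179° counts as a 2° error rather than 358°. A minimal sketch, assuming each structure graph is reduced to a {node: attribute_dict} mapping (for a networkx graph, `dict(g.nodes(data=True))` produces one) and that angles are stored under the hypothetical keys `phi_angle` and `psi_angle`; the real attribute names are not given here. Nodes missing a value in either graph are skipped.

```python
import math


def wrap_deg(d: float) -> float:
    """Wrap an angular difference in degrees into [-180, 180)."""
    return (d + 180.0) % 360.0 - 180.0


def angular_rmse_sketch(pred_graphs, true_graphs, keys=("phi_angle", "psi_angle")):
    """Periodic RMSE per angle type over matching nodes of paired graphs.

    Each 'graph' is a plain {node: attr_dict} mapping; the attribute names
    in `keys` are hypothetical placeholders.
    """
    sq_errors = {k: [] for k in keys}
    for pred, true in zip(pred_graphs, true_graphs):
        for node, true_attrs in true.items():
            pred_attrs = pred.get(node, {})
            for k in keys:
                if k in true_attrs and k in pred_attrs:
                    d = wrap_deg(pred_attrs[k] - true_attrs[k])
                    sq_errors[k].append(d * d)
    return tuple(
        math.sqrt(sum(v) / len(v)) if v else float("nan") for v in sq_errors.values()
    )


pred = [{1: {"phi_angle": 179.0, "psi_angle": 10.0}}]
true = [{1: {"phi_angle": -179.0, "psi_angle": 0.0}}]
print(angular_rmse_sketch(pred, true))  # (2.0, 10.0)
```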

build_baselines(data: list[nx.DiGraph], fn: callable = np.mean) -> tuple[callable, callable, callable, callable]

Build baseline functions that predict phi and psi angles, SASA, and flexibility based on monosaccharides.

Args:
- data: List of structure graphs.
- fn: Function used to aggregate values (e.g., np.mean, np.median).

Returns:
Tuple of predictor functions for phi, psi, SASA, and flexibility.
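Conceptually, a per-monosaccharide baseline groups observed values by residue type and aggregates each group with `fn`. The sketch below assumes the structure graphs have already been flattened into hypothetical (monosaccharide, value) records and uses `statistics.mean` in place of `np.mean`; the global fallback for unseen monosaccharides is an assumption, not documented behavior.

```python
import statistics
from collections import defaultdict
from typing import Callable, Iterable


def build_lookup_baseline(
    records: Iterable[tuple[str, float]],
    fn: Callable[[list[float]], float] = statistics.mean,
) -> Callable[[str], float]:
    """Return a predictor mapping a monosaccharide name to an aggregated value."""
    groups: dict[str, list[float]] = defaultdict(list)
    for mono, value in records:
        groups[mono].append(value)
    table = {mono: fn(values) for mono, values in groups.items()}
    # Hypothetical fallback: aggregate over all values for unseen monosaccharides.
    fallback = fn([v for values in groups.values() for v in values])
    return lambda mono: table.get(mono, fallback)


sasa_records = [("GlcNAc", 100.0), ("GlcNAc", 120.0), ("Man", 80.0)]
sasa_pred = build_lookup_baseline(sasa_records)
print(sasa_pred("GlcNAc"), sasa_pred("Man"), sasa_pred("Fuc"))  # 110.0 80.0 100.0
```

Passing `statistics.median` instead mirrors the `np.median` option mentioned above. Note that a plain arithmetic aggregate is not periodic-aware, which matters if the same approach is applied to phi and psi.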

clean_split(split: list[tuple[torch_geometric.data.Data, nx.DiGraph]], mode: Literal['mean', 'max'] = 'max') -> tuple[torch_geometric.data.Data, nx.DiGraph]

Clean the split data by condensing it to one conformer per glycan.

Args:
- split (list): A list of tuples containing the PyTorch Geometric Data object and the structure graph.
- mode (str): The mode for condensing the data: "mean" for the mean conformer, "max" for the maximum-weight conformer.

Returns:
list: A list of tuples containing the condensed PyTorch Geometric Data object and the structure graph.
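In "max" mode, condensing to one conformer per glycan amounts to keeping the highest-weight conformer of each glycan. A minimal sketch, assuming each entry is reduced to a hypothetical (glycan, weight, payload) triple; in `clean_split` the payload would correspond to the (Data, structure graph) pair, and how ties are broken there is not documented.

```python
from typing import Any


def keep_max_weight(entries: list[tuple[str, float, Any]]) -> list[tuple[str, float, Any]]:
    """Keep one entry per glycan: the one with the highest conformer weight."""
    best: dict[str, tuple[str, float, Any]] = {}
    for glycan, weight, payload in entries:
        # Strict '>' keeps the first-seen entry when weights tie.
        if glycan not in best or weight > best[glycan][1]:
            best[glycan] = (glycan, weight, payload)
    return list(best.values())


entries = [
    ("Man(a1-3)Man", 0.2, "conf_a"),
    ("Man(a1-3)Man", 0.7, "conf_b"),
    ("Gal(b1-4)Glc", 1.0, "conf_c"),
]
print(keep_max_weight(entries))
```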

create_dataset(fresh: bool = True, splits: list[float] = [0.8, 0.2])

Create a dataset of PyTorch Geometric Data objects from the structure graphs of glycans.

Args:
- fresh (bool): If True, fetches the latest data. If False, uses cached data.
- splits (list): A list of two or three floats representing the train-test split ratios.

Returns:
tuple: A tuple containing the training and testing datasets.
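Split ratios such as the default `[0.8, 0.2]` can be applied by shuffling a list and cutting it into consecutive slices. The sketch below is a generic illustration, assuming the split is made over glycan identifiers (so that conformers of one glycan never end up in both train and test); the seed, the rounding rule, and the function name `ratio_split` are hypothetical choices, not the library's documented behavior.

```python
import random


def ratio_split(items: list, splits: list[float], seed: int = 42) -> tuple[list, ...]:
    """Shuffle `items` and cut them into len(splits) parts with the given ratios."""
    if abs(sum(splits) - 1.0) > 1e-9:
        raise ValueError("split ratios must sum to 1")
    shuffled = items[:]
    random.Random(seed).shuffle(shuffled)
    parts, start = [], 0
    for i, ratio in enumerate(splits):
        # The last part takes the remainder so that no item is lost to rounding.
        end = len(shuffled) if i == len(splits) - 1 else start + round(ratio * len(shuffled))
        parts.append(shuffled[start:end])
        start = end
    return tuple(parts)


glycans = [f"glycan_{i}" for i in range(10)]
train, test = ratio_split(glycans, [0.8, 0.2])
print(len(train), len(test))  # 8 2
```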

encode_angles_sincos(angles_deg)

Encode angles given in degrees as sine/cosine pairs: an input of shape [..., N] becomes an output of shape [..., 2N], with one sin/cos pair per angle.
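Encoding an angle θ as (sin θ, cos θ) removes the discontinuity at ±180°, so that 179° and -179° map to nearby points. A minimal sketch on a flat list of angles, assuming the pairs are interleaved as [sin θ₁, cos θ₁, sin θ₂, cos θ₂, ...]; the actual ordering used by `encode_angles_sincos` is not stated here, and the library version operates on tensors with arbitrary leading dimensions. The decoder shows that `atan2` recovers the angle.

```python
import math


def encode_angles_sincos_sketch(angles_deg: list[float]) -> list[float]:
    """Map N angles in degrees to 2N values: interleaved (sin, cos) pairs."""
    out: list[float] = []
    for a in angles_deg:
        rad = math.radians(a)
        out.extend((math.sin(rad), math.cos(rad)))
    return out


def decode_angles_sincos_sketch(encoded: list[float]) -> list[float]:
    """Invert the encoding with atan2, returning angles in [-180, 180] degrees."""
    return [math.degrees(math.atan2(s, c)) for s, c in zip(encoded[::2], encoded[1::2])]


enc = encode_angles_sincos_sketch([90.0, -60.0])
print([round(v, 3) for v in enc])  # [1.0, 0.0, -0.866, 0.5]
```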

eval_baseline(nxgraphs: list[nx.DiGraph], phi_pred: callable, psi_pred: callable, sasa_pred: callable, flex_pred: callable) -> list[nx.DiGraph]

Evaluate the baseline model by predicting angles and properties for each graph.

Args:
- nxgraphs: List of structure graphs
- phi_pred: Function to predict phi angles
- psi_pred: Function to predict psi angles
- sasa_pred: Function to predict SASA
- flex_pred: Function to predict flexibility

Returns:
List of predicted structure graphs.

evaluate_model(model: torch.nn.Module | tuple[callable, callable, callable, callable], structures, count: int = 10)

Evaluate the model by sampling angles and properties from the structure graphs.

Args:
- model: The trained model, either a trained SweetNet or a tuple of baseline predictors for phi, psi, SASA, and flexibility.
- structures: List of structure graphs
- count: Number of samples to generate for each graph

Returns:
Tuple of RMSE values for phi, psi, SASA, and flexibility.

get_all_structure_graphs(glycan, stereo=None, libr=None)

Get all structure graphs for a given glycan.

Args:
- glycan (str): The glycan name.
- stereo (str, optional): The stereochemistry. If None, both alpha and beta are returned.
- libr (HashableDict, optional): A library of structures. If None, the default library is used.

Returns:
list: A list of tuples containing the PDB file name and the corresponding structure graph.

graph2pyg(g, weight, iupac, conformer)

Convert a structure graph to a PyTorch Geometric Data object.

Args:
- g (networkx.Graph): The structure graph.
- weight (float): The weight of the graph.
- iupac (str): The IUPAC name of the glycan.
- conformer (str): The conformer name.

Returns:
torch_geometric.data.Data: The PyTorch Geometric Data object.

mean_conformer(conformers: list[tuple[float, tuple[torch_geometric.data.Data, nx.DiGraph]]]) -> tuple[torch_geometric.data.Data, nx.DiGraph]

Calculate the mean conformer from a list of conformers.

Args:
- conformers (list): A list of (weight, (Data, structure graph)) tuples, pairing each conformer's weight with its PyTorch Geometric Data object and structure graph.

Returns:
tuple: A tuple containing the mean PyTorch Geometric Data object and the mean structure graph.
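Averaging torsion angles needs care: the arithmetic mean of 170° and -170° is 0°, although both conformers sit near 180°. A minimal sketch of weighted averaging that treats angles circularly (weighted sin/cos sums followed by `atan2`) and scalar properties such as SASA arithmetically. Whether `mean_conformer` averages angles this way is an assumption; the sketch only illustrates the technique on plain numbers.

```python
import math


def weighted_circular_mean(angles_deg: list[float], weights: list[float]) -> float:
    """Weighted mean direction of angles in degrees, computed via sin/cos sums."""
    s = sum(w * math.sin(math.radians(a)) for a, w in zip(angles_deg, weights))
    c = sum(w * math.cos(math.radians(a)) for a, w in zip(angles_deg, weights))
    return math.degrees(math.atan2(s, c))


def weighted_mean(values: list[float], weights: list[float]) -> float:
    """Ordinary weighted arithmetic mean, e.g. for SASA or flexibility."""
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)


phi = [170.0, -170.0]
weights = [0.5, 0.5]
print(round(abs(weighted_circular_mean(phi, weights)), 6))  # 180.0
print(weighted_mean([100.0, 140.0], [0.25, 0.75]))  # 130.0
```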

mixture_von_mises_nll(angles: torch.Tensor, weights_logits: torch.Tensor, mus: torch.Tensor, kappas: torch.Tensor, label_smoothing: float = 0.0) -> tuple[torch.Tensor, torch.Tensor]

Negative log-likelihood for a mixture of von Mises distributions.

Args:
- angles: True angles in degrees [batch_size, 2] (phi, psi)
- weights_logits: Raw logits for mixture weights [batch_size, 2, n_components]
- mus: Mean angles in degrees [batch_size, 2, n_components]
- kappas: Concentration parameters [batch_size, 2, n_components]

Returns:
Negative log-likelihood.
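For one angle θ, the mixture density is p(θ) = Σₖ wₖ exp(κₖ cos(θ - μₖ)) / (2π I₀(κₖ)), where w = softmax(weights_logits) and I₀ is the modified Bessel function of the first kind of order 0. The sketch below computes -log p(θ) for a single angle in pure Python, working in log space for numerical stability and evaluating log I₀ from its power series. It ignores batching and the `label_smoothing` parameter, and it is an illustration of the formula, not the library's implementation.

```python
import math


def logsumexp(xs: list[float]) -> float:
    """Numerically stable log(sum(exp(x) for x in xs))."""
    top = max(xs)
    return top + math.log(sum(math.exp(x - top) for x in xs))


def log_i0(kappa: float) -> float:
    """log I0(kappa) from the series I0(k) = sum_m (k/2)^(2m) / (m!)^2, in log space."""
    if kappa == 0.0:
        return 0.0  # I0(0) = 1
    half_log = math.log(kappa / 2.0)
    # Terms decay geometrically once m exceeds kappa/2; this cutoff is a generous heuristic.
    n_terms = int(kappa) + 60
    return logsumexp([2 * m * half_log - 2 * math.lgamma(m + 1) for m in range(n_terms)])


def mixture_von_mises_nll_sketch(angle_deg, weights_logits, mus_deg, kappas) -> float:
    """Negative log-likelihood of one angle (degrees) under a von Mises mixture."""
    theta = math.radians(angle_deg)
    log_norm = logsumexp(weights_logits)  # log-softmax normalizer for the weights
    terms = [
        (logit - log_norm)
        + k * math.cos(theta - math.radians(mu))
        - math.log(2 * math.pi)
        - log_i0(k)
        for logit, mu, k in zip(weights_logits, mus_deg, kappas)
    ]
    return -logsumexp(terms)


print(round(math.exp(log_i0(1.0)), 6))  # 1.266066
```

With κ = 0 every component is uniform on the circle, so the NLL of any angle is log(2π); this is a convenient sanity check for an implementation.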

node2y(attr)

Extract ML task labels from node attributes.

Args:
- attr (dict): Node attributes.

Returns:
list: A list of labels for the node, or None if all labels are zero.
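A minimal sketch of the documented behavior, assuming hypothetical label keys (the attribute names that `node2y` actually reads are not listed here): collect the label values in a fixed order, treat missing keys as 0, and return None when every label is zero.

```python
from typing import Optional

# Hypothetical label keys; the real attribute names may differ.
LABEL_KEYS = ("phi_angle", "psi_angle", "SASA", "flexibility")


def node2y_sketch(attr: dict, keys: tuple[str, ...] = LABEL_KEYS) -> Optional[list[float]]:
    """Return node labels in a fixed order, or None if all of them are zero."""
    labels = [float(attr.get(k, 0.0)) for k in keys]
    return None if all(v == 0.0 for v in labels) else labels


print(node2y_sketch({"phi_angle": -70.0, "psi_angle": 120.0, "SASA": 85.5}))
print(node2y_sketch({"monosaccharide": "Man"}))
```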

periodic_mse(pred: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]

Calculate the periodic mean squared error (MSE) for angles.

Args:
- pred: Predicted angles in degrees [batch_size, 2]
- target: True angles in degrees [batch_size, 2]

Returns:
Tuple of MSE values for the phi and psi angles.

periodic_rmse(pred: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]

Calculate the periodic root mean square error (RMSE) for angles.

Args:
- pred: Predicted angles in degrees [batch_size, 2]
- target: True angles in degrees [batch_size, 2]

Returns:
Tuple of RMSE values for the phi and psi angles.
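For both `periodic_mse` and `periodic_rmse`, the error between prediction and target should be the shortest signed distance on the circle, obtained by wrapping the raw difference into [-180, 180). A minimal pure-Python sketch for [batch_size, 2] inputs given as nested lists (phi in column 0, psi in column 1); it illustrates the wrapping technique rather than reproducing the library's tensor code.

```python
import math


def periodic_mse_sketch(pred: list[list[float]], target: list[list[float]]) -> tuple[float, float]:
    """Per-column MSE of wrapped angle differences (degrees) for [batch, 2] inputs."""
    sums = [0.0, 0.0]
    for p_row, t_row in zip(pred, target):
        for j in range(2):
            d = (p_row[j] - t_row[j] + 180.0) % 360.0 - 180.0  # wrap into [-180, 180)
            sums[j] += d * d
    n = len(pred)
    return sums[0] / n, sums[1] / n


def periodic_rmse_sketch(pred, target) -> tuple[float, float]:
    """Square root of the periodic MSE for phi and psi."""
    mse_phi, mse_psi = periodic_mse_sketch(pred, target)
    return math.sqrt(mse_phi), math.sqrt(mse_psi)


pred = [[175.0, -60.0], [-90.0, 10.0]]
target = [[-175.0, -50.0], [-80.0, 10.0]]
print(periodic_rmse_sketch(pred, target))
```

Without wrapping, the first phi pair (175° vs -175°) would contribute an error of 350° instead of 10°.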

sample_angle(weights: torch.Tensor, mus: torch.Tensor, kappas: torch.Tensor) -> torch.Tensor

Sample an angle from a mixture of von Mises distributions.

Args:
- weights: Mixture weights [n_components]
- mus: Mean angles in degrees [n_components]
- kappas: Concentration parameters [n_components]

Returns:
Sampled angle in degrees.
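Sampling from a mixture is a two-step process: choose a component with probability proportional to its weight, then draw from that component's von Mises distribution. A minimal sketch using Python's standard library (`random.choices` and `random.vonmisesvariate`), assuming `weights` are non-negative mixture weights rather than logits; the result is wrapped into [-180, 180).

```python
import math
import random


def sample_angle_sketch(weights, mus_deg, kappas, rng=random) -> float:
    """Draw one angle in degrees from a mixture of von Mises distributions."""
    # Step 1: pick a mixture component in proportion to its weight.
    k = rng.choices(range(len(weights)), weights=weights, k=1)[0]
    # Step 2: sample from that component; vonmisesvariate expects mu in [0, 2*pi)
    # (radians) and returns a value in [0, 2*pi).
    mu = math.radians(mus_deg[k]) % (2 * math.pi)
    theta = rng.vonmisesvariate(mu, kappas[k])
    return (math.degrees(theta) + 180.0) % 360.0 - 180.0


rng = random.Random(0)
samples = [sample_angle_sketch([0.9, 0.1], [-60.0, 60.0], [50.0, 50.0], rng) for _ in range(1000)]
near_first = sum(abs(s + 60.0) < 30.0 for s in samples) / len(samples)
print(f"fraction near -60 deg: {near_first:.2f}")
```

With κ = 50 each component is sharply peaked (roughly 8° spread), so about 90% of the samples should land near -60°, matching the first component's weight.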

sample_from_model(model: torch.nn.Module, structures: list[torch_geometric.data.Data, nx.DiGraph], count: int = 10)

Sample from the model using the provided structures.

Args:
- model: The trained model
- structures: List of structure graphs

Returns:
List of sampled angles.

value_rmse(predicted_graphs: list[nx.DiGraph], true_graphs: list[nx.DiGraph], name: Literal['SASA', 'flexibility']) -> float

Calculate the root mean square error (RMSE) for a specific property (SASA or flexibility).

Args:
- predicted_graphs: List of predicted structure graphs
- true_graphs: List of true structure graphs
- name: The property to calculate the RMSE for ("SASA" or "flexibility")

Returns:
RMSE value.
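For scalar node properties no wrapping is needed: the RMSE is the square root of the mean squared difference over paired nodes. A minimal sketch, assuming graphs reduced to {node: attribute_dict} mappings (for a networkx graph, `dict(g.nodes(data=True))`) and skipping nodes where either graph lacks the property; how `value_rmse` treats missing values is not documented here.

```python
import math


def value_rmse_sketch(pred_graphs, true_graphs, name: str) -> float:
    """RMSE of a scalar node property (e.g. "SASA") across paired graphs."""
    sq = []
    for pred, true in zip(pred_graphs, true_graphs):
        for node, true_attrs in true.items():
            if name in true_attrs and name in pred.get(node, {}):
                sq.append((pred[node][name] - true_attrs[name]) ** 2)
    return math.sqrt(sum(sq) / len(sq)) if sq else float("nan")


pred = [{1: {"SASA": 110.0}, 2: {"SASA": 40.0}}]
true = [{1: {"SASA": 100.0}, 2: {"SASA": 50.0}}]
print(value_rmse_sketch(pred, true, "SASA"))  # 10.0
```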