grafx.utils

get_node_ids_from_type(G: GRAFX, node_type: str)

Retrieves the IDs of all nodes of a specific type in the graph.

Parameters:
  • G (GRAFX) – The target graph.

  • node_type (str) – The node type to retrieve.

Returns:

A list of node IDs that match the given type.

Return type:

List[int]
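The lookup amounts to filtering nodes by their stored type. The sketch below assumes each node carries a `node_type` attribute. `ToyGraph`, `node_ids_of_type` and the type labels `"in"`, `"eq"`, `"reverb"`, `"out"` are hypothetical stand-ins used only for illustration; they are not the GRAFX API.

```python
from typing import Dict, List


class ToyGraph:
    """Hypothetical stand-in for GRAFX: maps each node ID to an attribute dict."""

    def __init__(self) -> None:
        self.nodes: Dict[int, Dict[str, str]] = {}

    def add_node(self, node_id: int, node_type: str) -> None:
        self.nodes[node_id] = {"node_type": node_type}


def node_ids_of_type(G: ToyGraph, node_type: str) -> List[int]:
    # Keep every node whose stored type matches, in insertion order.
    return [i for i, attrs in G.nodes.items() if attrs["node_type"] == node_type]


G = ToyGraph()
for i, t in enumerate(["in", "eq", "reverb", "eq", "out"]):
    G.add_node(i, t)

print(node_ids_of_type(G, "eq"))  # [1, 3]
```

A type that does not occur in the graph yields an empty list rather than an error in this sketch.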

count_nodes_per_type(G: GRAFX, types_to_count: list | None = None)

Counts the number of nodes for each specified type in the graph.

Parameters:
  • G (GRAFX) – The target graph.

  • types_to_count (list, optional) – A list of node types to count. If None, counts all types present in the graph (default: None).

Returns:

A dictionary with node types as keys and counts as values.

Return type:

Dict[str, int]
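The counting behaviour can be sketched with `collections.Counter`. For brevity this hypothetical `count_nodes_per_type_sketch` takes a plain sequence of node types instead of a GRAFX graph. It also assumes that a requested type absent from the graph is reported with a count of 0; whether the library does the same is an assumption.

```python
from collections import Counter
from typing import Dict, Iterable, List, Optional


def count_nodes_per_type_sketch(
    node_types: Iterable[str], types_to_count: Optional[List[str]] = None
) -> Dict[str, int]:
    counts = Counter(node_types)
    if types_to_count is None:
        # Default: report every type present in the graph.
        return dict(counts)
    # Assumption: requested types that never occur are reported as 0.
    return {t: counts.get(t, 0) for t in types_to_count}


types = ["in", "eq", "reverb", "eq", "out"]
print(count_nodes_per_type_sketch(types))
# {'in': 1, 'eq': 2, 'reverb': 1, 'out': 1}
print(count_nodes_per_type_sketch(types, ["eq", "compressor"]))
# {'eq': 2, 'compressor': 0}
```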

create_empty_parameters(processors, G, std=0.01)

Creates and initializes parameter tensors in a nested dictionary format from a given graph and processors. The tensor values are sampled from a normal distribution \(\mathcal{N}(0, \sigma^2)\), where the standard deviation \(\sigma\) is given by the std argument.

Parameters:
  • processors (Mapping) – A dictionary of processors, either a dict or an nn.ModuleDict, whose keys are node types and whose values are the corresponding processors.

  • G (GRAFX) – The graph containing nodes whose parameters are to be initialized.

  • std (float, optional) – Standard deviation for the parameter initialization (default: 0.01).

Returns:

A module dictionary with initialized parameters for each node type in the graph.

Return type:

nn.ModuleDict
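The nested layout and the \(\mathcal{N}(0, \sigma^2)\) sampling can be illustrated without PyTorch. This is a minimal sketch under several assumptions: each processor is reduced to a hypothetical mapping from parameter name to per-node size (`PARAM_SIZES`), each leaf is a list of rows with one row per node of that type, and node types without nodes in the graph are skipped. The real function returns an `nn.ModuleDict` of tensors instead of nested lists.

```python
import random
from typing import Dict, List

# Hypothetical per-type parameter sizes that a processor might declare.
PARAM_SIZES: Dict[str, Dict[str, int]] = {
    "eq": {"gains": 4},
    "reverb": {"decay": 1, "mix": 1},
    "compressor": {"threshold": 1},
}


def create_empty_parameters_sketch(
    param_sizes: Dict[str, Dict[str, int]],
    nodes_per_type: Dict[str, int],
    std: float = 0.01,
    seed: int = 0,
) -> Dict[str, Dict[str, List[List[float]]]]:
    rng = random.Random(seed)  # seeded so the result is reproducible
    params: Dict[str, Dict[str, List[List[float]]]] = {}
    for node_type, sizes in param_sizes.items():
        n = nodes_per_type.get(node_type, 0)
        if n == 0:
            continue  # no nodes of this type in the graph: nothing to initialize
        params[node_type] = {
            # One row per node; each entry is drawn from N(0, std^2).
            name: [[rng.gauss(0.0, std) for _ in range(size)] for _ in range(n)]
            for name, size in sizes.items()
        }
    return params


params = create_empty_parameters_sketch(
    PARAM_SIZES, {"in": 1, "eq": 2, "reverb": 1, "out": 1}, std=0.01
)
print(sorted(params))  # ['eq', 'reverb']
print(len(params["eq"]["gains"]), len(params["eq"]["gains"][0]))  # 2 4
```

In a PyTorch version each leaf would more naturally be a tensor such as `torch.randn(n, size) * std` wrapped in `nn.Parameter`, with the per-type dictionaries collected into an `nn.ModuleDict`.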

permute_grafx_tensor(G_t, node_id, node_attrs=['node_types', 'rendering_orders'], id_attrs=['edge_indices'])

Permutes the node and edge attributes of a GRAFXTensor according to the given node ordering. Attributes not listed in node_attrs or id_attrs are left unchanged.

Parameters:
  • G_t (GRAFXTensor) – The graph tensor to permute.

  • node_id (LongTensor) – The permutation index given by the node IDs.

  • node_attrs (List[str], optional) – List of node attributes to permute (default: ["node_types", "rendering_orders"]).

  • id_attrs (List[str], optional) – List of attributes whose values are node IDs and are therefore relabeled under the permutation (default: ["edge_indices"]).

Returns:

The permuted graph tensor.

Return type:

GRAFXTensor
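The two kinds of attributes are treated differently: per-node attributes are reordered, while ID-valued attributes keep their order but have every ID relabeled. The sketch below makes this concrete under stated assumptions. The graph tensor is modelled as a plain dict of lists, `node_id[new] = old` (the node placed at position `new` was previously node `old`), and `edge_indices` is a two-row list of source and target IDs. `permute_tensor_sketch` and the extra `counter` attribute are hypothetical and not part of the library.

```python
from typing import Any, Dict, Sequence


def permute_tensor_sketch(
    G_t: Dict[str, Any],
    node_id: Sequence[int],
    node_attrs: Sequence[str] = ("node_types", "rendering_orders"),
    id_attrs: Sequence[str] = ("edge_indices",),
) -> Dict[str, Any]:
    # Inverse permutation: inverse[old] gives the node's new ID.
    inverse = [0] * len(node_id)
    for new, old in enumerate(node_id):
        inverse[old] = new
    out = dict(G_t)  # shallow copy; unlisted attributes are carried over unchanged
    for key in node_attrs:
        # Per-node attributes: gather entries in the new node order.
        out[key] = [G_t[key][old] for old in node_id]
    for key in id_attrs:
        # ID-valued attributes: keep order, relabel every node ID.
        out[key] = [[inverse[i] for i in row] for row in G_t[key]]
    return out


G_t = {
    "node_types": ["in", "eq", "out"],
    "rendering_orders": [0, 1, 2],
    "edge_indices": [[0, 1], [1, 2]],  # edges in->eq and eq->out
    "counter": 3,
}
P = permute_tensor_sketch(G_t, node_id=[2, 0, 1])
print(P["node_types"])    # ['out', 'in', 'eq']
print(P["edge_indices"])  # [[1, 2], [2, 0]]
```

After the permutation the edges still read in->eq (1 to 2) and eq->out (2 to 0), so the graph structure is preserved while the node labels change. Tuples are used as defaults to sidestep Python's shared mutable default lists.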