
GNN

relationalai.semantics.reasoners.predictive.estimator
GNN(
    *,
    exp_database: str,
    exp_schema: str,
    database: Optional[str] = None,
    schema: Optional[str] = None,
    graph: Optional[Graph] = None,
    property_transformer: Optional[PropertyTransformer] = None,
    train: Optional[b.Relationship | b.Fragment | b.Chain] = None,
    validation: Optional[b.Relationship | b.Fragment | b.Chain] = None,
    source_concept: Optional[b.Concept] = None,
    target_concept: Optional[b.Concept] = None,
    task_type: Optional[str] = None,
    eval_metric: Optional[str] = None,
    use_current_time: bool = True,
    has_time_column: Optional[bool] = None,
    model_database: Optional[str] = None,
    model_schema: Optional[str] = None,
    model_name: Optional[str] = None,
    version_name: Optional[str] = None,
    model_run_id: Optional[str] = None,
    test_batch_size: Optional[int] = None,
    stream_logs: bool = True,
    extract_embeddings: bool = False,
    dataset_alias: Optional[str] = None,
    parallel_reasoners_init: bool = True,
    **train_params: Any
)

Train, load, register and predict with a Graph Neural Network.

GNN supports two workflows:

  • Fit workflow — provide graph, train, validation, and a task_type, then call GNN.fit followed by GNN.predictions.
  • Load workflow — provide a previously trained model via model_run_id or the four registry parameters (model_database, model_schema, model_name, version_name), then call GNN.load followed by GNN.predictions.

In either workflow you can register the trained model in the Snowflake Model Registry using the GNN.register_model method.
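The rules above for which constructor arguments select which workflow can be summarized as a small decision function. This is a hypothetical sketch, not part of relationalai: infer_workflow is an illustrative name, and treating a partial set of the four registry parameters as an error is an assumption.

```python
from typing import Optional


def infer_workflow(
    *,
    graph: Optional[object] = None,
    train: Optional[object] = None,
    validation: Optional[object] = None,
    task_type: Optional[str] = None,
    model_run_id: Optional[str] = None,
    model_database: Optional[str] = None,
    model_schema: Optional[str] = None,
    model_name: Optional[str] = None,
    version_name: Optional[str] = None,
) -> str:
    """Hypothetical helper (not relationalai API): return "fit" or "load"."""
    registry = (model_database, model_schema, model_name, version_name)
    given = [part is not None for part in registry]
    # Assumption: the four registry parameters must be supplied together.
    if any(given) and not all(given):
        raise ValueError("provide all four registry parameters or none of them")
    # Load workflow: a run ID or a complete registry reference.
    if model_run_id is not None or all(given):
        return "load"
    # Fit workflow: graph, train, validation, and task_type are all required.
    if all(arg is not None for arg in (graph, train, validation, task_type)):
        return "fit"
    raise ValueError("arguments match neither the fit nor the load workflow")


print(infer_workflow(model_run_id="run-123"))  # load
```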

Parameters:

  • exp_database

    (str) - Snowflake database for experiment storage.
  • exp_schema

    (str) - Snowflake schema for experiment storage.
  • database

    (str, default: None) - Snowflake database to save predictions in.
  • schema

    (str, default: None) - Snowflake schema to save predictions in.
  • graph

    (Graph, default: None) - The knowledge graph with edges defined. Required for the fit workflow.
  • property_transformer

    (PropertyTransformer, default: None) - Column-level semantic type annotations. If omitted, all column types are auto-inferred.
  • train

    (Relationship or Fragment or Chain, default: None) - Training split relationship. Required for the fit workflow.
  • validation

    (Relationship or Fragment or Chain, default: None) - Validation split relationship. Required for the fit workflow.
  • source_concept

    (Concept, default: None) - Source concept for the load workflow (inferred from train in the fit workflow).
  • target_concept

    (Concept, default: None) - Target concept for link-prediction tasks in the load workflow (inferred from train in the fit workflow).
  • task_type

    (str, default: None) - One of "binary_classification", "multiclass_classification", "multilabel_classification", "regression", "link_prediction", or "repeated_link_prediction". Required for the fit workflow. In the load workflow, inferred automatically from the registered model if not provided.
  • eval_metric

    (str, default: None) - Evaluation metric compatible with the chosen task_type (e.g. "roc_auc", "accuracy", "rmse", "link_prediction_precision@5").
  • use_current_time

    (bool, default: True) - Use the current timestamp as the prediction time.
  • has_time_column

    (bool, default: None) - Set to True when the task relationships use the at keyword for temporal ordering. In the load workflow, inferred automatically from the registered model if not provided.
  • model_database

    (str, default: None) - Snowflake database of a registered model (load workflow).
  • model_schema

    (str, default: None) - Snowflake schema of a registered model (load workflow).
  • model_name

    (str, default: None) - Name of the registered model (load workflow).
  • version_name

    (str, default: None) - Version of the registered model (load workflow).
  • model_run_id

    (str, default: None) - Run ID of a previously trained model (load workflow).
  • test_batch_size

    (int, default: None) - Batch size used during inference.
  • stream_logs

    (bool, default: True) - Stream training logs to stdout.
  • extract_embeddings

    (bool, default: False) - Extract node embeddings during prediction.
  • dataset_alias

    (str, default: None) - User-chosen alias for the dataset.
  • parallel_reasoners_init

    (bool, default: True) - Initialize the Predictive and Logic reasoners in parallel.
  • **train_params

    (Any, default: {}) - Additional hyperparameters forwarded to the GNN trainer (e.g. n_epochs, lr, train_batch_size, device, head_layers). Ignored in the load workflow.
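Because every named parameter is keyword-only (the bare * in the signature) and **train_params collects the rest, any keyword that is not a named parameter ends up in train_params. The hypothetical function below reproduces only that Python mechanism with a trimmed signature; it does not call relationalai, and collect_train_params is an illustrative name.

```python
from typing import Any, Dict, Optional


def collect_train_params(
    *,
    exp_database: str,
    exp_schema: str,
    task_type: Optional[str] = None,
    eval_metric: Optional[str] = None,
    **train_params: Any,
) -> Dict[str, Any]:
    """Hypothetical, trimmed stand-in for the GNN signature."""
    # Named keyword-only parameters are bound first; every other keyword lands here.
    return dict(train_params)


extra = collect_train_params(
    exp_database="EXPERIMENTS_DB",
    exp_schema="EXPERIMENTS_SCHEMA",
    task_type="binary_classification",
    eval_metric="roc_auc",
    device="cuda",
    n_epochs=5,
)
print(extra)  # {'device': 'cuda', 'n_epochs': 5}

# A misspelled named parameter is not rejected at call time:
# it is silently collected alongside the hyperparameters.
typo = collect_train_params(exp_database="DB", exp_schema="SCHEMA", task_typ="regression")
print(typo)  # {'task_typ': 'regression'}
```

Whether the GNN trainer itself rejects unknown hyperparameter names is not stated above, so it may be worth double-checking the spelling of named arguments such as task_type.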

Examples:

Fit a model and predict, assuming the setup from the module-level Quick Start (relationalai.semantics.reasoners.predictive):

gnn = GNN(
    exp_database="EXPERIMENTS_DB",
    exp_schema="EXPERIMENTS_SCHEMA",
    graph=gnn_graph,
    property_transformer=property_transformer,
    source_concept=Students,
    train=Train,
    validation=Validation,
    task_type="binary_classification",
    eval_metric="roc_auc",
    device="cuda",
    n_epochs=5,
)
gnn.fit()
Students.predictions = gnn.predictions(domain=Test)

Load a registered model and predict:

gnn = GNN(
    exp_database="EXPERIMENTS_DB",
    exp_schema="EXPERIMENTS_SCHEMA",
    model_database="MODELS_DB",
    model_schema="MODELS_SCHEMA",
    model_name="MY_MODEL",
    version_name="V1",
    source_concept=Students,
)
gnn.load()
Students.predictions = gnn.predictions(domain=Test)
GNN.fit() -> GNN

Train the GNN model.

Materializes graph and task tables, creates a trainer, and submits a training job. If training has already completed or is in progress, calling fit() again is a no-op.

Returns:

  • GNN - This instance, so that calls can be chained (e.g. gnn.fit().predictions(domain=Test)).

Raises:

  • ValueError - If called in a load workflow.
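The no-op behavior of fit() amounts to a guard on the training state. The class below is a hypothetical model of that bookkeeping (FitLifecycle, State, and mark_finished are illustrative names, not relationalai API); a counter stands in for materializing tables and submitting the training job.

```python
from enum import Enum


class State(Enum):
    NOT_STARTED = "not_started"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"


class FitLifecycle:
    """Hypothetical model of fit() bookkeeping; not the relationalai GNN class."""

    def __init__(self, workflow: str) -> None:
        self.workflow = workflow
        self.state = State.NOT_STARTED
        self.jobs_submitted = 0

    def fit(self) -> "FitLifecycle":
        if self.workflow == "load":
            raise ValueError("fit() is not available in the load workflow")
        if self.state is not State.NOT_STARTED:
            return self  # training in progress or completed: no-op
        # Stand-in for materializing tables and submitting a training job.
        self.jobs_submitted += 1
        self.state = State.IN_PROGRESS
        return self  # returning self enables chaining

    def mark_finished(self) -> None:
        self.state = State.COMPLETED


model = FitLifecycle("fit")
model.fit().fit()  # the second call is a no-op
model.mark_finished()
model.fit()  # still a no-op after completion
print(model.jobs_submitted)  # 1
```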
GNN.load() -> GNN

Load a previously trained model for prediction.

Resolves the model from the Snowflake Model Registry (when model_database, model_schema, model_name, and version_name were provided) or by model_run_id, and prepares the trainer for prediction. If the model is already loaded, calling load() again is a no-op.

Returns:

  • GNN - This instance, so that calls can be chained (e.g. gnn.load().predictions(domain=Test)).

Raises:

  • ValueError - If called in a fit workflow, or if the specified model or version does not exist.
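Model resolution in load() can be pictured as a lookup against two sources. In this hypothetical sketch, plain dictionaries and sets stand in for the Snowflake Model Registry and the store of past runs; resolve_model, registry, and known_runs are illustrative names, and checking the registry parameters before model_run_id is an assumption, since the precedence is not specified above.

```python
from typing import Dict, Optional, Set, Tuple

RegistryKey = Tuple[str, str, str, str]


def resolve_model(
    registry: Dict[RegistryKey, str],
    known_runs: Set[str],
    *,
    model_database: Optional[str] = None,
    model_schema: Optional[str] = None,
    model_name: Optional[str] = None,
    version_name: Optional[str] = None,
    model_run_id: Optional[str] = None,
) -> str:
    """Hypothetical: return the run ID that load() would prepare for prediction."""
    if (
        model_database is not None
        and model_schema is not None
        and model_name is not None
        and version_name is not None
    ):
        key = (model_database, model_schema, model_name, version_name)
        if key not in registry:
            raise ValueError(f"model or version does not exist: {key}")
        return registry[key]
    if model_run_id is not None:
        if model_run_id not in known_runs:
            raise ValueError(f"unknown model_run_id: {model_run_id}")
        return model_run_id
    raise ValueError("provide model_run_id or all four registry parameters")


registry = {("MODELS_DB", "MODELS_SCHEMA", "MY_MODEL", "V1"): "run-42"}
print(resolve_model(
    registry, set(),
    model_database="MODELS_DB", model_schema="MODELS_SCHEMA",
    model_name="MY_MODEL", version_name="V1",
))  # run-42
```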
GNN.predictions(domain: b.Relationship | b.Fragment | b.Chain) -> b.Relationship

Generate predictions on a test domain.

Materializes the test table, submits a prediction job, and returns a Relationship that can be assigned to a concept field for downstream querying.

The prediction attributes available on the returned relationship depend on the task type:

  • Classification: .probs, .predicted_labels
  • Regression: .predicted_value
  • Link prediction: .rank, .scores, .predicted_<target>
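The attribute list above can be encoded as a lookup table keyed by task_type. This is a hypothetical helper, not relationalai API; applying the classification attributes to all three classification task types, treating repeated_link_prediction like link_prediction, and filling the <target> placeholder verbatim with the target's name are assumptions.

```python
from typing import Dict, Optional, Tuple

_CLASSIFICATION = ("probs", "predicted_labels")
_LINK_PREDICTION = ("rank", "scores", "predicted_<target>")

PREDICTION_ATTRIBUTES: Dict[str, Tuple[str, ...]] = {
    "binary_classification": _CLASSIFICATION,
    "multiclass_classification": _CLASSIFICATION,
    "multilabel_classification": _CLASSIFICATION,
    "regression": ("predicted_value",),
    "link_prediction": _LINK_PREDICTION,
    "repeated_link_prediction": _LINK_PREDICTION,  # assumption
}


def prediction_attributes(task_type: str, target: Optional[str] = None) -> Tuple[str, ...]:
    """Hypothetical: attribute names expected on the returned predictions relationship."""
    if task_type not in PREDICTION_ATTRIBUTES:
        raise ValueError(f"unknown task_type: {task_type!r}")
    names = PREDICTION_ATTRIBUTES[task_type]
    if "predicted_<target>" in names:
        if target is None:
            raise ValueError("link-prediction tasks need a target name")
        # Assumption: the placeholder is replaced verbatim by the target name.
        names = tuple(name.replace("<target>", target) for name in names)
    return names


print(prediction_attributes("link_prediction", target="course"))
# ('rank', 'scores', 'predicted_course')
```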

Parameters:

  • domain

    (Relationship or Fragment or Chain) - The test split relationship (e.g. the Test relationship defined during data modeling).

Returns:

  • Relationship - A prediction relationship to be assigned to the source concept (e.g. User.predictions = gnn.predictions(domain=Test)).

Raises:

  • TypeError - If domain is not a Relationship, Fragment, or Chain.
  • ValueError - If the model has not been fitted or loaded, or if the test domain schema does not match the training schema.
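For link prediction, .scores and .rank are related by sorting, and a metric named link_prediction_precision@5 presumably looks at the five top-ranked candidates. The sketch below uses the common textbook definition of precision@k (true links among the top k ranked candidates, divided by k) on plain dictionaries with made-up scores; the library's exact definition, tie-breaking, and averaging across source entities may differ.

```python
from typing import Dict, List, Set


def rank_candidates(scores: Dict[str, float]) -> List[str]:
    """Order candidates by descending score; rank 1 is the first element.

    Python's sort is stable, so candidates with equal scores keep insertion order.
    """
    return sorted(scores, key=lambda candidate: scores[candidate], reverse=True)


def precision_at_k(scores: Dict[str, float], relevant: Set[str], k: int) -> float:
    """Fraction of the top-k ranked candidates that are true links."""
    if k <= 0:
        raise ValueError("k must be positive")
    top_k = rank_candidates(scores)[:k]
    hits = sum(1 for candidate in top_k if candidate in relevant)
    return hits / k  # divides by k even when fewer than k candidates exist


# Hypothetical scores for one student over six candidate courses.
scores = {"c1": 0.9, "c2": 0.1, "c3": 0.8, "c4": 0.4, "c5": 0.3, "c6": 0.7}
print(rank_candidates(scores)[:5])                 # ['c1', 'c3', 'c6', 'c4', 'c5']
print(precision_at_k(scores, {"c1", "c4"}, k=5))   # 0.4
```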
GNN.register_model(
    model_database: str,
    model_schema: str,
    model_name: str,
    version_name: str,
    *,
    comment: Optional[str] = None
) -> None

Register a trained model in the Snowflake Model Registry.

After registration the model can be loaded in a later session by passing the same model_database, model_schema, model_name, and version_name to the GNN constructor.

Parameters:

  • model_database

    (str) - Snowflake database for the model registry entry.
  • model_schema

    (str) - Snowflake schema for the model registry entry.
  • model_name

    (str) - Name under which to register the model.
  • version_name

    (str) - Version label (e.g. "v1").
  • comment

    (str, default: None) - Free-text comment stored alongside the registry entry.

Raises:

  • ValueError - If the model is already registered, or if it has not been fitted or loaded.

Examples:

gnn.fit()
gnn.register_model(
    model_database="MODELS_DB",
    model_schema="MODELS_SCHEMA",
    model_name="STUDENT_CHURN",
    version_name="V1",
    comment="First baseline model",
)
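The registration rules (an entry keyed by model_database, model_schema, model_name, and version_name, and no registration before the model is fitted or loaded) can be mimicked with an in-memory registry. This is a hypothetical sketch, not the Snowflake Model Registry API; InMemoryRegistry, lookup, and the run_id argument are illustrative, and reading "already registered" as "this trained run already has a registry entry" is an assumption.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple

Key = Tuple[str, str, str, str]


@dataclass
class InMemoryRegistry:
    """Hypothetical stand-in for a model registry; not the Snowflake API."""

    entries: Dict[Key, Dict[str, Optional[str]]] = field(default_factory=dict)

    def register(
        self,
        model_database: str,
        model_schema: str,
        model_name: str,
        version_name: str,
        *,
        run_id: Optional[str],
        comment: Optional[str] = None,
    ) -> None:
        if run_id is None:
            raise ValueError("model has not been fitted or loaded")
        # Assumption: a trained run may be registered only once.
        if any(entry["run_id"] == run_id for entry in self.entries.values()):
            raise ValueError(f"run {run_id} is already registered")
        key = (model_database, model_schema, model_name, version_name)
        if key in self.entries:
            raise ValueError(f"{key} already exists")
        self.entries[key] = {"run_id": run_id, "comment": comment}

    def lookup(self, model_database: str, model_schema: str, model_name: str, version_name: str) -> Optional[str]:
        # A later session would pass these same four values to the GNN constructor.
        return self.entries[(model_database, model_schema, model_name, version_name)]["run_id"]


registry = InMemoryRegistry()
registry.register("MODELS_DB", "MODELS_SCHEMA", "STUDENT_CHURN", "V1", run_id="run-1", comment="First baseline model")
print(registry.lookup("MODELS_DB", "MODELS_SCHEMA", "STUDENT_CHURN", "V1"))  # run-1
```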
GNN.visualize_dataset(show_dtypes: bool = False)

Visualize the dataset graph.

Returns a graph object that can be rendered in a notebook to inspect the node and edge structure of the prepared dataset.

Parameters:

  • show_dtypes

    (bool, default: False) - Include column data types in the visualization.

Returns:

  • object - A graph visualization object that can be rendered in a notebook.

Raises:

  • ValueError - If no dataset has been prepared yet (i.e. GNN.fit has not been called).

Examples:

from IPython.display import Image, display
gnn.fit()
graph = gnn.visualize_dataset(show_dtypes=True)
display(Image(graph.create_png()))