Source code for kgsaf_jdex.loaders.pytorch.dataset

#!/usr/bin/env python3

from pathlib import Path

import json
import torch
from rdflib import OWL, URIRef
from torch.utils.data import Dataset

import kgsaf_jdex.utils.conventions.ids as idc
import kgsaf_jdex.utils.conventions.paths as pc


class KnowledgeGraph(Dataset):
    """
    Knowledge Graph Dataset Loader.

    A PyTorch `Dataset` abstraction for ontology-enhanced knowledge graphs,
    supporting ABox, TBox, and RBox components.

    The dataset structure assumes:
        - Precomputed URI-to-ID mappings for individuals, classes, and
          object properties
        - ABox triples split into train/validation/test sets
        - Optional class assertions (rdf:type)
        - TBox taxonomy (class subsumption)
        - RBox axioms including object property hierarchies and
          domain/range constraints

    All symbolic components are loaded from disk and exposed as
    `torch.Tensor` objects (dtype ``int64``) suitable for downstream
    learning and reasoning tasks. Optional components whose file is absent
    are exposed as empty two-dimensional tensors, so the query methods
    return empty results instead of failing.

    The implementation supports limited OWL constructs such as
    `owl:unionOf` in domain and range definitions while skipping
    unsupported complex expressions.

    NOTE(review): no `__len__` / `__getitem__` is defined in this class as
    shown; the data is accessed through the split properties (`train`,
    `valid`, `test`) and the query methods.
    """

    def __init__(
        self,
        path: str,
    ):
        """
        Initialize the knowledge graph dataset.

        Args:
            path (str): Base path of the dataset directory.
        """
        super().__init__()

        # Dataset BASE folder (`resolve()` already yields an absolute path).
        self.base_path = Path(path).resolve()

        # Mapping URIs to IDs and back
        self._individual_to_id = self._load_mappings(pc.INDIVIDUAL_MAPPINGS)
        self._class_to_id = self._load_mappings(pc.CLASS_MAPPINGS)
        self._obj_prop_to_id = self._load_mappings(pc.OBJ_PROP_MAPPINGS)

        # Both "Thing" URIs are aliases of the same top-class ID.
        self._class_to_id[str(OWL.Thing)] = idc.THING
        self._class_to_id["http://schema.org/Thing"] = idc.THING

        # When several URIs share an ID (as the Thing aliases do), the
        # reverse map keeps the URI that comes last in iteration order.
        self._id_to_individual = {v: k for k, v in self._individual_to_id.items()}
        self._id_to_class = {v: k for k, v in self._class_to_id.items()}
        self._id_to_obj_prop = {v: k for k, v in self._obj_prop_to_id.items()}

        # ABox
        self._train_triples = self._load_abox_triples(pc.TRAIN)
        self._test_triples = self._load_abox_triples(pc.TEST)
        self._valid_triples = self._load_abox_triples(pc.VALID)
        self._class_assertions = self._load_abox_class_assertions()

        # TBox
        self._taxonomy = self._load_tbox_taxonomy()

        # RBox
        self._obj_prop_domain_range = self._load_rbox_domain_range()
        self._obj_prop_hierarchy = self._load_rbox_hierarchy()

    # General Functions

    def _warning(self, count: int):
        """
        Print a warning for skipped complex URIs.

        Args:
            count (int): Number of skipped URIs.
        """
        print(f"WARNING: {count} complex URIs detected. These classes are skipped during loading.")

    @staticmethod
    def _to_tensor(rows: list, width: int) -> torch.Tensor:
        """
        Convert a list of fixed-width integer rows into an int64 tensor.

        Args:
            rows (list): List of rows, each a list of `width` integers.
            width (int): Number of columns each row has.

        Returns:
            torch.Tensor: Tensor of shape ``(len(rows), width)``. An empty
            `rows` yields shape ``(0, width)`` rather than ``(0,)`` so that
            column indexing such as ``t[:, 0]`` remains valid.
        """
        if not rows:
            return torch.empty((0, width), dtype=torch.int64)
        return torch.tensor(rows, dtype=torch.int64)

    def _load_json(self, file_location: str):
        """
        Parse a JSON file located relative to the dataset base path.

        Args:
            file_location (str): Path relative to `base_path`.

        Returns:
            The decoded JSON content.
        """
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(self.base_path / file_location, "r", encoding="utf-8") as handle:
            return json.load(handle)

    def _load_mappings(self, file_location: str):
        """
        Load a URI-to-ID mapping from disk.

        Args:
            file_location (str): Relative path to mapping file.

        Returns:
            The decoded JSON content of the mapping file.
        """
        return self._load_json(file_location)

    # URI <-> ID conversions

    def individual_to_id(self, individual_uri: str) -> int:
        """
        Convert individual URI to ID.

        Raises:
            KeyError: If the URI is not in the individual mapping.
        """
        return self._individual_to_id[individual_uri]

    def class_to_id(self, class_uri: str) -> int:
        """
        Convert class URI to ID.

        Raises:
            KeyError: If the URI is not in the class mapping.
        """
        return self._class_to_id[class_uri]

    def obj_prop_to_id(self, obj_prop_uri: str) -> int:
        """
        Convert object property URI to ID.

        Raises:
            KeyError: If the URI is not in the object property mapping.
        """
        return self._obj_prop_to_id[obj_prop_uri]

    def id_to_individual(self, individual_id: int) -> str:
        """
        Convert individual ID to URI.

        Raises:
            KeyError: If the ID is not in the reverse individual mapping.
        """
        return self._id_to_individual[individual_id]

    def id_to_class(self, class_id: int) -> str:
        """
        Convert class ID to URI.

        Raises:
            KeyError: If the ID is not in the reverse class mapping.
        """
        return self._id_to_class[class_id]

    def id_to_obj_prop(self, obj_prop_id: int) -> str:
        """
        Convert object property ID to URI.

        Raises:
            KeyError: If the ID is not in the reverse property mapping.
        """
        return self._id_to_obj_prop[obj_prop_id]

    # Queries

    def individual_classes(self, individual_id: int) -> torch.Tensor:
        """
        Get classes asserted for an individual.

        Selects column 1 of the `class_assertions` rows whose column 0
        equals `individual_id`.
        """
        assertions = self.class_assertions
        return assertions[assertions[:, 0] == individual_id, 1]

    def sup_classes(self, class_id: int) -> torch.Tensor:
        """
        Get superclasses of a class.

        Only the pairs stored in `taxonomy` are consulted; no transitive
        closure is computed here.
        """
        taxonomy = self.taxonomy
        return taxonomy[taxonomy[:, 0] == class_id, 1]

    def sub_classes(self, class_id: int) -> torch.Tensor:
        """
        Get subclasses of a class.

        Only the pairs stored in `taxonomy` are consulted; no transitive
        closure is computed here.
        """
        taxonomy = self.taxonomy
        return taxonomy[taxonomy[:, 1] == class_id, 0]

    def is_leaf(self, class_id: int) -> bool:
        """
        Check if class has no subclasses.
        """
        return len(self.sub_classes(class_id)) == 0

    def sup_obj_prop(self, obj_prop_id: int) -> torch.Tensor:
        """
        Get super-properties of an object property.
        """
        hierarchy = self.obj_props_hierarchy
        return hierarchy[hierarchy[:, 0] == obj_prop_id, 1]

    def sub_obj_prop(self, obj_prop_id: int) -> torch.Tensor:
        """
        Get sub-properties of an object property.
        """
        hierarchy = self.obj_props_hierarchy
        return hierarchy[hierarchy[:, 1] == obj_prop_id, 0]

    def obj_prop_domain(self, obj_prop_id: int) -> torch.Tensor:
        """
        Get domain classes of an object property.
        """
        domains = self.obj_props_domain
        return domains[domains[:, 0] == obj_prop_id, 1]

    def obj_prop_range(self, obj_prop_id: int) -> torch.Tensor:
        """
        Get range classes of an object property.
        """
        ranges = self.obj_props_range
        return ranges[ranges[:, 0] == obj_prop_id, 1]

    # Getters

    @property
    def dataset_location(self) -> str:
        """Resolved dataset base directory, as a string."""
        return str(self.base_path)

    @property
    def train(self) -> torch.Tensor:
        """Training triples, one ``[subject, property, object]`` row each."""
        return self._train_triples

    @property
    def valid(self) -> torch.Tensor:
        """Validation triples, one ``[subject, property, object]`` row each."""
        return self._valid_triples

    @property
    def test(self) -> torch.Tensor:
        """Test triples, one ``[subject, property, object]`` row each."""
        return self._test_triples

    @property
    def class_assertions(self) -> torch.Tensor:
        """Class assertions, one ``[individual_id, class_id]`` row each."""
        return self._class_assertions

    @property
    def taxonomy(self) -> torch.Tensor:
        """Class subsumptions, one ``[class_id, superclass_id]`` row each."""
        return self._taxonomy

    @property
    def obj_props_hierarchy(self) -> torch.Tensor:
        """Property subsumptions, one ``[property_id, super_property_id]`` row each."""
        return self._obj_prop_hierarchy

    @property
    def obj_props_domain(self) -> torch.Tensor:
        """Domain constraints, one ``[property_id, class_id]`` row each."""
        table = self._obj_prop_domain_range
        # Keep the rows tagged as DOMAIN, then drop the tag column.
        return table[table[:, 1] == idc.DOMAIN][:, [0, 2]]

    @property
    def obj_props_range(self) -> torch.Tensor:
        """Range constraints, one ``[property_id, class_id]`` row each."""
        table = self._obj_prop_domain_range
        # Keep the rows tagged as RANGE, then drop the tag column.
        return table[table[:, 1] == idc.RANGE][:, [0, 2]]

    @property
    def obj_props_domains_range(self) -> torch.Tensor:
        """Raw domain/range table, one ``[property_id, kind, class_id]`` row each,
        where ``kind`` is ``idc.DOMAIN`` or ``idc.RANGE``."""
        return self._obj_prop_domain_range

    # ABOX Loading Functions

    def _load_abox_triples(self, file_location: str) -> torch.Tensor:
        """
        Load a tab-separated triple file and encode it with the ID mappings.

        Each non-blank line is split on tabs; the first three fields are
        read as subject URI, property URI and object URI.

        Args:
            file_location (str): Path of the split file relative to `base_path`.

        Returns:
            torch.Tensor: int64 tensor of shape ``(n_triples, 3)``.

        Raises:
            FileNotFoundError: If the split file does not exist.
            KeyError: If a URI is missing from the corresponding mapping.
            IndexError: If a non-blank line has fewer than three fields.
        """
        triples = []
        with open(self.base_path / file_location, "r", encoding="utf-8") as triples_txt:
            for line in triples_txt:
                stripped = line.strip()
                if not stripped:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                fields = stripped.split("\t")
                triples.append(
                    [
                        self.individual_to_id(fields[0]),
                        self.obj_prop_to_id(fields[1]),
                        self.individual_to_id(fields[2]),
                    ]
                )
        return self._to_tensor(triples, 3)

    def _load_abox_class_assertions(self) -> torch.Tensor:
        """
        Load class assertions, if the file is present.

        The JSON maps an individual URI to a list of class URIs. Individuals
        absent from the individual mapping are ignored.

        Returns:
            torch.Tensor: int64 tensor of shape ``(n, 2)`` with rows
            ``[individual_id, class_id]``; empty when the file is missing.

        Raises:
            KeyError: If a class URI is missing from the class mapping.
        """
        if not (self.base_path / pc.CLASS_ASSERTIONS).exists():
            return self._to_tensor([], 2)

        data = self._load_json(pc.CLASS_ASSERTIONS)
        assertions = []
        for ind_uri, class_uris in data.items():
            if ind_uri not in self._individual_to_id:
                continue
            ind_id = self.individual_to_id(ind_uri)
            for class_uri in class_uris:
                assertions.append([ind_id, self.class_to_id(class_uri)])
        return self._to_tensor(assertions, 2)

    # TBOX Loading Functions

    def _load_tbox_taxonomy(self) -> torch.Tensor:
        """
        Load the class taxonomy, if the file is present.

        The JSON maps a class URI to a list of superclasses. Entries that
        are dicts (complex class expressions) are counted and skipped.

        Returns:
            torch.Tensor: int64 tensor of shape ``(n, 2)`` with rows
            ``[class_id, superclass_id]``; empty when the file is missing.

        Raises:
            KeyError: If a class URI is missing from the class mapping.
        """
        if not (self.base_path / pc.TAXONOMY).exists():
            return self._to_tensor([], 2)

        data = self._load_json(pc.TAXONOMY)
        taxonomy = []
        complex_uri = 0
        for c_uri, superclasses in data.items():
            c_id = self.class_to_id(c_uri)
            for sup_c_uri in superclasses:
                if isinstance(sup_c_uri, dict):
                    complex_uri += 1
                else:
                    taxonomy.append([c_id, self.class_to_id(sup_c_uri)])

        if complex_uri > 0:
            self._warning(complex_uri)
        return self._to_tensor(taxonomy, 2)

    # RBOX Loading Functions

    def _load_rbox_domain_range(self) -> torch.Tensor:
        """
        Load object property domain/range constraints, if the file is present.

        The JSON maps a property URI to an object with ``"domain"`` and
        ``"range"`` lists, each processed by `_compute_domain_range`.

        Returns:
            torch.Tensor: int64 tensor of shape ``(n, 3)`` with rows
            ``[property_id, idc.DOMAIN | idc.RANGE, class_id]``; empty when
            the file is missing.

        Raises:
            KeyError: If a URI is missing from its mapping, or a property
                entry lacks the ``"domain"`` or ``"range"`` key.
        """
        # Check the file that is actually opened below (the original tested
        # the hierarchy file here by mistake).
        if not (self.base_path / pc.OBJ_PROP_DOMAIN_RANGE).exists():
            return self._to_tensor([], 3)

        data = self._load_json(pc.OBJ_PROP_DOMAIN_RANGE)
        rows = []
        complex_uri = 0
        for r_uri, constraints in data.items():
            r_id = self.obj_prop_to_id(r_uri)
            domain_ids, ex_d = self._compute_domain_range(constraints["domain"])
            range_ids, ex_r = self._compute_domain_range(constraints["range"])
            complex_uri += ex_d + ex_r

            rows.extend([r_id, idc.DOMAIN, class_id] for class_id in domain_ids)
            rows.extend([r_id, idc.RANGE, class_id] for class_id in range_ids)

        if complex_uri > 0:
            self._warning(complex_uri)
        return self._to_tensor(rows, 3)

    def _compute_domain_range(self, subdata):
        """
        Resolve one domain or range list into class IDs.

        Args:
            subdata: Iterable whose elements are either class URIs or dicts
                describing complex class expressions.

        Returns:
            tuple: ``(class_ids, excluded)`` where `class_ids` is the list
            of resolved class IDs and `excluded` counts the dict elements
            that are not `owl:unionOf` expressions and were skipped.

        Raises:
            KeyError: If a class URI is missing from the class mapping.
        """
        union_key = str(OWL.unionOf)
        class_ids = []
        excluded = 0
        for elem in subdata:
            if not isinstance(elem, dict):
                class_ids.append(self.class_to_id(elem))
            elif union_key in elem:
                # A union contributes every one of its member classes.
                class_ids.extend(self.class_to_id(member) for member in elem[union_key])
            else:
                excluded += 1
        return class_ids, excluded

    def _load_rbox_hierarchy(self) -> torch.Tensor:
        """
        Load the object property hierarchy, if the file is present.

        The JSON maps a property URI to a list of super-properties. Entries
        that are dicts (complex expressions) are counted and skipped.

        Returns:
            torch.Tensor: int64 tensor of shape ``(n, 2)`` with rows
            ``[property_id, super_property_id]``; empty when the file is
            missing.

        Raises:
            KeyError: If a property URI is missing from the property mapping.
        """
        if not (self.base_path / pc.OBJ_PROP_HIERARCHY).exists():
            return self._to_tensor([], 2)

        data = self._load_json(pc.OBJ_PROP_HIERARCHY)
        hierarchy = []
        complex_uri = 0
        for r_uri, super_props in data.items():
            r_id = self.obj_prop_to_id(r_uri)
            for sup_r_uri in super_props:
                if isinstance(sup_r_uri, dict):
                    complex_uri += 1
                else:
                    hierarchy.append([r_id, self.obj_prop_to_id(sup_r_uri)])

        if complex_uri > 0:
            self._warning(complex_uri)
        return self._to_tensor(hierarchy, 2)