from pykeen.datasets.base import Dataset
from pykeen.triples import TriplesFactory
from pathlib import Path
import json
from extension.constants import (
ENTITY_TO_ID_FILENAME,
RELATION_TO_ID_FILENAME,
TRAIN_SPLIT_FILENAME,
TEST_SPLIT_FILENAME,
VALID_SPLIT_FILENAME,
CLASS_MEMBERSHIP_METADATA_FILENAME,
DOMAIN_RANGE_METATDATA_FILENAME,
)
class OnMemoryDataset(Dataset):
    """Dataset loaded from an on-disk folder of already-split RDF triples.

    The folder must contain the train/test/validation splits together with the
    label-to-ID mappings used to index all of them.

    ### Folder Structure
    - train.txt : Training triples in "h r t" format using RDF names
    - test.txt : Testing triples in "h r t" format using RDF names
    - valid.txt : Validation triples in "h r t" format using RDF names
    - entity_to_id.json : JSON object mapping entity names to IDs
    - relation_to_id.json : JSON object mapping relation names to IDs
    - entities_classes.json : Additional metadata of class membership for each
      entity; needs to have the format:
    ```json
    {
        "<ENTITY_NAME>": [
            "<CLASS_NAME_1>",
            ...,
            "<CLASS_NAME_N>"
        ]
    }
    ```
    - relation_domain_range.json : Additional metadata of domain and range
      classes for each relation; needs to have the format:
    ```json
    {
        "<RELATION_NAME>": {
            "domain": "<CLASS_NAME_DOMAIN>" OR "None",
            "range": "<CLASS_NAME_RANGE>" OR "None"
        }
    }
    ```
    All JSON files are read as UTF-8.
    """

    def __init__(
        self,
        data_path: str | Path | None = None,
        load_entity_classes: bool = True,
        load_domain_range: bool = True,
        **kwargs,
    ):
        """Initialize the dataset from an on-disk folder.

        Args:
            data_path (str | Path, optional): Dataset folder path. The default of
                None is kept only for signature compatibility; a path is required.
            load_entity_classes (bool, optional): Load the entity class membership
                metadata into ``self.entity_id_to_classes``. Defaults to True.
            load_domain_range (bool, optional): Load the relation domain and range
                classes metadata into ``self.relation_id_to_domain_range``.
                Defaults to True.
            **kwargs: Accepted for interface compatibility; currently ignored.

        Raises:
            ValueError: If ``data_path`` is None.
            FileNotFoundError: If one of the JSON files is missing (raised by ``open``).
        """
        if data_path is None:
            # Path(None) would fail with an opaque TypeError; say what is missing.
            raise ValueError("data_path is required: pass the dataset folder path")
        self.data_path = Path(data_path)

        entity_id_mapping = self._read_json(ENTITY_TO_ID_FILENAME)
        relation_id_mapping = self._read_json(RELATION_TO_ID_FILENAME)

        # Every split is indexed with the same mappings so IDs agree across splits.
        self.training = self._load_split(
            TRAIN_SPLIT_FILENAME, entity_id_mapping, relation_id_mapping
        )
        self.testing = self._load_split(
            TEST_SPLIT_FILENAME, entity_id_mapping, relation_id_mapping
        )
        self.validation = self._load_split(
            VALID_SPLIT_FILENAME, entity_id_mapping, relation_id_mapping
        )

        # The metadata loaders translate names through self.entity_to_id /
        # self.relation_to_id, so they must run after self.training is set.
        if load_entity_classes:
            self.entity_id_to_classes = self._load_entity_classes()
        if load_domain_range:
            self.relation_id_to_domain_range = self._load_relation_domain_range()

    def _read_json(self, filename: str):
        """Parse a JSON file located in the dataset folder.

        The file is decoded as UTF-8 (the encoding JSON mandates) instead of the
        platform locale encoding, so non-ASCII RDF names load the same on every OS.
        """
        with open(self.data_path / filename, "r", encoding="utf-8") as f:
            return json.load(f)

    def _load_split(
        self,
        filename: str,
        entity_to_id: dict,
        relation_to_id: dict,
    ) -> TriplesFactory:
        """Load one split's triples file, indexed with the shared label-to-ID mappings."""
        return TriplesFactory.from_path(
            path=self.data_path / filename,
            create_inverse_triples=False,
            entity_to_id=entity_to_id,
            relation_to_id=relation_to_id,
        )

    def _load_entity_classes(self) -> dict:
        """Load the entity class membership metadata from the provided JSON file.

        Entity names are transformed to IDs via ``self.entity_to_id``. Entries
        whose entity name is not in that mapping are skipped, the same way
        ``_load_relation_domain_range`` skips unknown relations.

        Returns:
            dict: Dictionary of entity ID to list of class names.
        """
        data = self._read_json(CLASS_MEMBERSHIP_METADATA_FILENAME)
        entity_to_id = self.entity_to_id  # look the mapping up once
        return {
            entity_to_id[name]: classes
            for name, classes in data.items()
            if name in entity_to_id
        }

    def _load_relation_domain_range(self) -> dict:
        """Load the relation domain and range classes from the provided JSON file.

        Relation names are transformed to IDs via ``self.relation_to_id``;
        entries whose relation name is not in that mapping are skipped.

        Returns:
            dict: Dictionary of relation ID to dict with domain and range classes.
        """
        data = self._read_json(DOMAIN_RANGE_METATDATA_FILENAME)
        relation_to_id = self.relation_to_id  # look the mapping up once
        return {
            relation_to_id[name]: domain_range
            for name, domain_range in data.items()
            if name in relation_to_id
        }