Source code for extension.filtering

import torch
from pykeen.sampling.filtering import PythonSetFilterer


class NullPythonSetFilterer(PythonSetFilterer):
    """Python-set-based filterer that also flags placeholder negatives.

    Extension of :class:`PythonSetFilterer` that, in addition to the known
    triples held by the base class, treats any triple containing a negative
    index as "contained" (i.e. it will be filtered out). Negative indices mark
    manually inserted, invalid negative samples.
    """

    def __init__(self, mapped_triples):
        """Initialize the filterer.

        :param mapped_triples: ID-based triples, passed unchanged to
            :class:`PythonSetFilterer`.
        """
        super().__init__(mapped_triples)

    def contains(self, batch):
        """Check which triples in ``batch`` should be filtered.

        :param batch: integer tensor of shape ``(..., 3)``; the last dimension
            holds one ID triple (presumably ``(head, relation, tail)`` per the
            PyKEEN convention -- confirm against callers).
        :return: boolean tensor of shape ``batch.shape[:-1]`` on
            ``batch.device``. An entry is ``True`` if its triple contains any
            negative index or is one of the base class's known triples.
        """
        # reshape (rather than view) also accepts non-contiguous batches.
        flat = batch.reshape(-1, 3)
        # Any negative ID marks an invalid placeholder. Checking only for -1
        # would let other negative values through unfiltered, where they could
        # wrap around if later used as tensor indices.
        has_negative = (flat < 0).any(dim=-1)
        # Membership test against the base class's collection of known triples.
        known = torch.as_tensor(
            data=[tuple(triple) in self.triples for triple in flat.tolist()],
            dtype=torch.bool,
            device=batch.device,
        )
        return (has_negative | known).view(*batch.shape[:-1])