# Source code for tensor_theorem_prover.similarity

from __future__ import annotations
from typing import Callable, Iterable, Union

# optional dependency numpy
try:
    import numpy as np
    from numpy.linalg import norm

    has_numpy = True
except ImportError:
    has_numpy = False


from tensor_theorem_prover.types.Constant import Constant
from tensor_theorem_prover.types.Predicate import Predicate


# Signature of a similarity function: it takes two items (each either a
# Constant or a Predicate) and returns a float similarity score.
SimilarityFunc = Callable[
    [Union[Constant, Predicate], Union[Constant, Predicate]], float
]


def symbol_compare(item1: Constant | Predicate, item2: Constant | Predicate) -> float:
    """
    exact-match similarity: scores 1.0 when both items carry the same symbol
    string and 0.0 otherwise. no fuzzy matching is performed
    """
    if item1.symbol == item2.symbol:
        return 1.0
    return 0.0
def cosine_similarity(
    item1: Constant | Predicate, item2: Constant | Predicate
) -> float:
    """
    use cosine similarity to calculate a similarity score between the items.
    falls back to symbol comparison if either item is missing an embedding.

    an embedding with zero magnitude has no direction, so its cosine
    similarity is undefined; 0.0 is returned in that case instead of
    dividing by zero and producing nan.

    raises ImportError if both items have embeddings but numpy is not installed
    """
    if item1.embedding is None or item2.embedding is None:
        return symbol_compare(item1, item2)
    if not has_numpy:
        raise ImportError("cosine_similarity requires numpy, but it is not installed")

    magnitude = norm(item1.embedding) * norm(item2.embedding)
    if magnitude == 0:
        # 0 / 0 would yield nan (plus a numpy RuntimeWarning), and nan does
        # not compare sensibly against other scores, e.g. inside max()
        return 0.0
    return np.dot(item1.embedding, item2.embedding) / magnitude
def max_similarity(funcs: Iterable[SimilarityFunc]) -> SimilarityFunc:
    """
    returns a function that calls all the given functions and returns the
    maximum similarity score.

    the given iterable is read once, when this function is called, so a
    generator or other one-shot iterator may be passed safely. if no
    functions were given, calling the returned function raises ValueError
    (from max() on an empty sequence)
    """
    # snapshot the functions up front: iterating `funcs` lazily on every call
    # would exhaust a one-shot iterator after the first comparison
    similarity_funcs = tuple(funcs)

    def combined(item1: Constant | Predicate, item2: Constant | Predicate) -> float:
        return max(func(item1, item2) for func in similarity_funcs)

    return combined