Advanced usage
Custom manifold
To correct each sample \(X_{i\cdot}\) in the source domain, the adapter uses the following function:
\[X_{i\cdot} \mapsto f\left(f^{-1}(X_{i\cdot}) + g\left(f^{-1}(X_{i\cdot})\right)\right)\]
where \(f\) is a differentiable function that maps each row of an input matrix onto a manifold, and \(g\) is a multi-layer perceptron that takes \(p\)-dimensional vectors as input and produces vectors of the same size. Intuitively, \(g\) explicitly learns the bias between the source and target domains.
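The correction can be sketched directly in PyTorch. This is a minimal illustration, not dagip's implementation. It assumes the correction map \(X_{i\cdot} \mapsto f(f^{-1}(X_{i\cdot}) + g(f^{-1}(X_{i\cdot})))\), uses the softmax and a clamped logarithm for \(f\) and \(f^{-1}\), and uses a small MLP with an arbitrary hidden size for \(g\). The names f, f_inv, g and correct are hypothetical.

```python
import torch
import torch.nn as nn

# Hypothetical f / f^{-1} pair (softmax and clamped log), as in the simplex example
def f(Z: torch.Tensor) -> torch.Tensor:
    return torch.softmax(Z, dim=1)          # Euclidean space -> manifold

def f_inv(X: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return torch.log(torch.clamp(X, eps, 1))  # manifold -> Euclidean space

p = 5
# Hypothetical g: an MLP mapping p-dimensional vectors to p-dimensional vectors
g = nn.Sequential(nn.Linear(p, 32), nn.ReLU(), nn.Linear(32, p))

def correct(X: torch.Tensor) -> torch.Tensor:
    """Apply X_i -> f(f^{-1}(X_i) + g(f^{-1}(X_i))) to every row of X."""
    Z = f_inv(X)          # move the source samples to Euclidean space
    return f(Z + g(Z))    # add the learned bias, then map back onto the manifold

torch.manual_seed(0)
X = torch.softmax(torch.randn(10, p), dim=1)  # toy source samples on the simplex
X_corrected = correct(X)                      # shape (10, p), rows sum to 1
```

Because \(g\) only adds an offset in Euclidean space, a \(g\) whose output is zero leaves every sample unchanged. Samples move only to the extent that \(g\) has learned a nonzero bias.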
To implement a new manifold, both \(f\) and \(f^{-1}\) must be defined. They correspond, respectively, to the _transform and _inverse_transform abstract methods of the dagip.retraction.base.Manifold base class.
As an example, consider the multinomial manifold, i.e. the manifold of matrices whose elements are positive and whose rows each sum to 1:
```python
import torch

from dagip.retraction.base import Manifold


class ProbabilitySimplex(Manifold):

    def __init__(self, eps: float = 1e-6):
        # Constant used to prevent numerical issues
        self.eps: float = eps

    def _transform(self, X: torch.Tensor) -> torch.Tensor:
        # Project X onto the multinomial manifold
        return torch.softmax(X, dim=1)

    def _inverse_transform(self, X: torch.Tensor) -> torch.Tensor:
        X = torch.clamp(X, self.eps, 1)  # Avoid taking the log of zero
        return torch.log(X)  # Map X back to Euclidean space
```
_transform projects data from Euclidean space onto the manifold, while _inverse_transform performs the reverse mapping, from the manifold back to Euclidean space. In this example, the logarithm does act as the inverse of the softmax: \(f^{-1}\) is called before \(f\), and \(X\) is assumed to already have rows summing to 1, so that \(\mathrm{softmax}(\log X_{i\cdot}) = X_{i\cdot} / \sum_k X_{ik} = X_{i\cdot}\). More generally, assuming that \(X\) already lies on the manifold makes it easy to implement functions that would otherwise be non-invertible. The softmax, for instance, is unchanged when a constant is added to a row, so it has no inverse on the whole Euclidean space.
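The invertibility argument can be checked numerically. The sketch below reimplements the two mappings as plain functions so that it runs without dagip; transform and inverse_transform are illustrative names, not part of the library.

```python
import torch

eps = 1e-6

def transform(Z: torch.Tensor) -> torch.Tensor:          # f: Euclidean -> simplex
    return torch.softmax(Z, dim=1)

def inverse_transform(X: torch.Tensor) -> torch.Tensor:  # f^{-1}: simplex -> Euclidean
    return torch.log(torch.clamp(X, eps, 1))

torch.manual_seed(0)
X = torch.softmax(torch.randn(4, 3), dim=1)  # rows are positive and sum to 1

# On the manifold, f(f^{-1}(X)) recovers X: softmax(log X)_ij = X_ij / sum_k X_ik = X_ij
roundtrip = transform(inverse_transform(X))

# The other direction is not an identity: softmax discards any per-row constant,
# so log(softmax(Z)) equals Z only up to a row-wise offset (minus the row's logsumexp).
Z = torch.randn(4, 3)
offset = inverse_transform(transform(Z)) - Z
```

Every entry of a given row of offset is the same value, which is exactly the information the softmax throws away.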
Custom distance metric
Implementing a custom metric is more straightforward, as only the pairwise_distances method needs to be defined. Given X and Y of shapes (n, p) and (m, p), respectively, the method should return a PyTorch tensor of shape (n, m) whose (i, j) entry is the distance between the i-th row of X and the j-th row of Y. The computation must be differentiable.
```python
import torch

from dagip.spatial.base import BaseDistance


class ManhattanDistance(BaseDistance):

    def pairwise_distances(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        # L1 (Manhattan) distance between every row of X and every row of Y
        return torch.cdist(X, Y, p=1)
```
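A quick sanity check of the pairwise_distances contract (output of shape (n, m), gradients flowing back to the inputs) can be written without dagip by comparing torch.cdist with p=1 against an explicit broadcasting implementation. The helper manhattan_pairwise is illustrative only.

```python
import torch

def manhattan_pairwise(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    # Broadcast (n, 1, p) against (1, m, p), then sum |x - y| over the feature axis
    return (X[:, None, :] - Y[None, :, :]).abs().sum(dim=-1)

torch.manual_seed(0)
X = torch.randn(4, 3, requires_grad=True)  # n = 4 samples, p = 3 features
Y = torch.randn(6, 3)                      # m = 6 samples

D_ref = manhattan_pairwise(X, Y)
D = torch.cdist(X, Y, p=1)                 # what ManhattanDistance returns

# Gradients must reach X for the metric to be usable during optimization
D.sum().backward()
```

Both versions give the same (4, 6) matrix. The broadcasting version materializes an (n, m, p) intermediate tensor, which torch.cdist avoids, so cdist is the better choice for large inputs.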