summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ml_exp/kernels.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/ml_exp/kernels.py b/ml_exp/kernels.py
index d593d83fd..3318fe6cf 100644
--- a/ml_exp/kernels.py
+++ b/ml_exp/kernels.py
@@ -21,6 +21,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import numpy as np
+from scipy.stats import wasserstein_distance
try:
import tensorflow as tf
TF_AV = True
@@ -141,3 +142,28 @@ def laplacian_kernel(X1,
K[i, :] = np.exp(i_sigma * norm)
return K
+
+
+def wasserstein_kernel(X1,
+ X2,
+ alpha):
+ """
+ Calculates the Wasserstein Kernel.
+ X1: first representations.
+ X2: second representations.
+ alpha: wasserstein kernel parameter.
+ NOTE: this doesn't work with tensorflow.
+ """
+
+ if X2.ndim == 3:
+ raise TypeError('Representations must be 1D.')
+
+ X1_size = X1.shape[0]
+ X2_size = X2.shape[0]
+
+ K = np.zeros((X1_size, X2_size), dtype=np.float64)
+ for i in range(X1_size):
+ norm = np.array([X2[j] - X1[i] for j in range(X2_size)], dtype=np.float64)
+ K[i, :] = np.exp(- alpha * norm)
+
+ return K