From 902c05813a8d83a917ca3608cbfe59afe5e10616 Mon Sep 17 00:00:00 2001 From: David Luevano Alvarado <55825613+luevano@users.noreply.github.com> Date: Thu, 26 Mar 2020 16:07:05 -0700 Subject: Fix wasserstein kernel --- ml_exp/kernels.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ml_exp/kernels.py b/ml_exp/kernels.py index 3318fe6cf..0af68f103 100644 --- a/ml_exp/kernels.py +++ b/ml_exp/kernels.py @@ -21,7 +21,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ import numpy as np -from scipy.stats import wasserstein_distance +from scipy.stats import wasserstein_distance as was_dist try: import tensorflow as tf TF_AV = True @@ -163,7 +163,8 @@ def wasserstein_kernel(X1, K = np.zeros((X1_size, X2_size), dtype=np.float64) for i in range(X1_size): - norm = np.array([X2[j] - X1[i] for j in range(X2_size)], dtype=np.float64) + norm = np.array([was_dist(X1[i], X2[j]) for j in range(X2_size)], + dtype=np.float64) K[i, :] = np.exp(- alpha * norm) return K -- cgit v1.2.3-70-g09d2