summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Luevano Alvarado <55825613+luevano@users.noreply.github.com>2020-03-26 16:07:05 -0700
committerDavid Luevano Alvarado <55825613+luevano@users.noreply.github.com>2020-03-26 16:07:05 -0700
commit902c05813a8d83a917ca3608cbfe59afe5e10616 (patch)
treec06210ffd343b8e9cdaa90f1d1abd74aeb753595
parent9b26afc10f47866ddb176d9e2c8609b9b180105d (diff)
Fix wasserstein kernel
-rw-r--r--ml_exp/kernels.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/ml_exp/kernels.py b/ml_exp/kernels.py
index 3318fe6cf..0af68f103 100644
--- a/ml_exp/kernels.py
+++ b/ml_exp/kernels.py
@@ -21,7 +21,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import numpy as np
-from scipy.stats import wasserstein_distance
+from scipy.stats import wasserstein_distance as was_dist
try:
import tensorflow as tf
TF_AV = True
@@ -163,7 +163,8 @@ def wasserstein_kernel(X1,
K = np.zeros((X1_size, X2_size), dtype=np.float64)
for i in range(X1_size):
- norm = np.array([X2[j] - X1[i] for j in range(X2_size)], dtype=np.float64)
+ norm = np.array([was_dist(X1[i], X2[j]) for j in range(X2_size)],
+ dtype=np.float64)
K[i, :] = np.exp(- alpha * norm)
return K