summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lj_matrix/do_ml.py18
1 files changed, 9 insertions, 9 deletions
diff --git a/lj_matrix/do_ml.py b/lj_matrix/do_ml.py
index ba88a6fd8..bb954a0ae 100644
--- a/lj_matrix/do_ml.py
+++ b/lj_matrix/do_ml.py
@@ -79,17 +79,17 @@ def do_ml(desc_data,
printc('\tTest size: {}'.format(test_size), 'CYAN')
printc('\tSigma: {}'.format(sigma), 'CYAN')
- Xcm_training = desc_data[:training_size]
- Ycm_training = energy_data[:training_size]
- Kcm_training = gauss_kernel(Xcm_training, Xcm_training, sigma)
- alpha_cm = cholesky_solve(Kcm_training, Ycm_training)
+ X_training = desc_data[:training_size]
+ Y_training = energy_data[:training_size]
+ K_training = gauss_kernel(X_training, X_training, sigma)
+ alpha_ = cholesky_solve(K_training, Y_training)
- Xcm_test = desc_data[-test_size:]
- Ycm_test = energy_data[-test_size:]
- Kcm_test = gauss_kernel(Xcm_test, Xcm_training, sigma)
- Ycm_predicted = np.dot(Kcm_test, alpha_cm)
+ X_test = desc_data[-test_size:]
+ Y_test = energy_data[-test_size:]
+ K_test = gauss_kernel(X_test, X_training, sigma)
+ Y_predicted = np.dot(K_test, alpha_)
- mae = np.mean(np.abs(Ycm_predicted - Ycm_test))
+ mae = np.mean(np.abs(Y_predicted - Y_test))
if show_msgs:
printc('\tMAE for {}: {:.4f}'.format(desc_type, mae), 'GREEN')