summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--do_ml.py30
1 files changed, 22 insertions, 8 deletions
diff --git a/do_ml.py b/do_ml.py
index b92b15a99..561b33904 100644
--- a/do_ml.py
+++ b/do_ml.py
@@ -36,25 +36,34 @@ def printc(text, color):
def do_ml(desc_data,
- desc_type,
energy_data,
training_size,
test_size,
- sigma=1000.0):
+ sigma=1000.0,
+ desc_type=None,
+ show_msgs=True):
"""
Does the ML methodology.
desc_data: descriptor (or representation) data.
- desc_type: string with the name of the descriptor used.
energy_data: energy data associated with desc_data.
training_size: size of the training set to use.
test_size: size of the test set to use.
sigma: depth of the kernel.
+ desc_type: string with the name of the descriptor used.
+ show_msgs: Show debug messages or not.
NOTE: desc_type is just a string and is only for identification purposes.
Also, training is done with the first part of the data and
testing with the ending part of the data.
"""
+ if not desc_type:
+ desc_type = 'NOT SPECIFIED'
+
tic = time.perf_counter()
- printc('{} ML started.'.format(desc_type), Fore.CYAN)
+ if show_msgs:
+ printc('{} ML started, with parameters:'.format(desc_type), Fore.CYAN)
+ printc('\tTraining size: {}'.format(training_size), Fore.BLUE)
+ printc('\tTest size: {}'.format(test_size), Fore.BLUE)
+ printc('\tSigma: {}'.format(sigma), Fore.BLUE)
Xcm_training = desc_data[:training_size]
Ycm_training = energy_data[:training_size]
@@ -66,9 +75,14 @@ def do_ml(desc_data,
Kcm_test = gauss_kernel(Xcm_test, Xcm_training, sigma)
Ycm_predicted = np.dot(Kcm_test, alpha_cm)
- print('\tMAE for {}: {}'.format(desc_type,
- np.mean(np.abs(Ycm_predicted - Ycm_test))))
+ mae = np.mean(np.abs(Ycm_predicted - Ycm_test))
+ if show_msgs:
+ print('\tMAE for {}: {}'.format(desc_type, mae))
toc = time.perf_counter()
- printc('\t{} ML took {:.4f} seconds.'.format(desc_type, toc-tic),
- Fore.GREEN)
+ tictoc = toc - tic
+ if show_msgs:
+ printc('\t{} ML took {:.4f} seconds.'.format(desc_type, tictoc),
+ Fore.GREEN)
+
+ return mae, tictoc