From 5ac02bfda4b3ff0700344a1b8757ec3d586ea7a5 Mon Sep 17 00:00:00 2001 From: David Luevano <55825613+luevano@users.noreply.github.com> Date: Thu, 12 Dec 2019 19:20:24 -0700 Subject: Reformat test size --- do_ml.py | 18 ++++++++++++------ main.py | 14 ++++++++++++-- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/do_ml.py b/do_ml.py index 25cf04e01..63a6fc671 100644 --- a/do_ml.py +++ b/do_ml.py @@ -30,7 +30,7 @@ from cholesky_solve import cholesky_solve def do_ml(desc_data, energy_data, training_size, - test_size, + test_size=None, sigma=1000.0, desc_type=None, show_msgs=True): @@ -39,7 +39,8 @@ def do_ml(desc_data, desc_data: descriptor (or representation) data. energy_data: energy data associated with desc_data. training_size: size of the training set to use. - test_size: size of the test set to use. + test_size: size of the test set to use. If no size is given, + the last remaining molecules are used. sigma: depth of the kernel. desc_type: string with the name of the descriptor used. show_msgs: Show debug messages or not. @@ -47,20 +48,25 @@ def do_ml(desc_data, Also, training is done with the first part of the data and testing with the ending part of the data. """ + # Initial calculations for later use. + d_len = len(desc_data) + e_len = len(energy_data) + if not desc_type: desc_type = 'NOT SPECIFIED' - d_len = len(desc_data) - e_len = len(energy_data) if d_len != e_len: printc(''.join(['ERROR. Descriptor data size different ', 'than energy data size.']), 'RED') return None - if training_size > d_len or test_size > d_len: - printc('ERROR. Training or test size greater than data size.', 'RED') + if training_size >= d_len: + printc('ERROR. Training size greater or equal than data size.', 'RED') return None + if not test_size: + test_size = d_len - training_size + tic = time.perf_counter() if show_msgs: printc('{} ML started, with parameters:'.format(desc_type), 'CYAN') diff --git a/main.py b/main.py index c3da14bed..88734d57f 100644 --- a/main.py +++ b/main.py @@ -54,8 +54,18 @@ cm_data = c_matrix_multiple(molecules, nuclear_charge, as_eig=True) ljm_data = lj_matrix_multiple(molecules, nuclear_charge, as_eig=True) # ML calculation. -do_ml(cm_data, energy_pbe0, 1000, 100, sigma=1000.0, desc_type='CM') -do_ml(ljm_data, energy_pbe0, 1000, 100, sigma=1000.0, desc_type='L-JM') +do_ml(cm_data, + energy_pbe0, + 1000, + test_size=100, + sigma=1000.0, + desc_type='CM') +do_ml(ljm_data, + energy_pbe0, + 1000, + test_size=100, + sigma=1000.0, + desc_type='L-JM') # End of program end_time = time.perf_counter() -- cgit v1.2.3-70-g09d2