summary refs log tree commit diff
path: root/do_ml.py
diff options
context:
space:
mode:
Diffstat (limited to 'do_ml.py')
-rw-r--r-- do_ml.py 18
1 files changed, 12 insertions, 6 deletions
diff --git a/do_ml.py b/do_ml.py
index 25cf04e01..63a6fc671 100644
--- a/do_ml.py
+++ b/do_ml.py
@@ -30,7 +30,7 @@ from cholesky_solve import cholesky_solve
def do_ml(desc_data,
energy_data,
training_size,
- test_size,
+ test_size=None,
sigma=1000.0,
desc_type=None,
show_msgs=True):
@@ -39,7 +39,8 @@ def do_ml(desc_data,
desc_data: descriptor (or representation) data.
energy_data: energy data associated with desc_data.
training_size: size of the training set to use.
- test_size: size of the test set to use.
+ test_size: size of the test set to use. If no size is given,
+ the last remaining molecules are used.
sigma: depth of the kernel.
desc_type: string with the name of the descriptor used.
show_msgs: Show debug messages or not.
@@ -47,20 +48,25 @@ def do_ml(desc_data,
Also, training is done with the first part of the data and
testing with the ending part of the data.
"""
+ # Initial calculations for later use.
+ d_len = len(desc_data)
+ e_len = len(energy_data)
+
if not desc_type:
desc_type = 'NOT SPECIFIED'
- d_len = len(desc_data)
- e_len = len(energy_data)
if d_len != e_len:
printc(''.join(['ERROR. Descriptor data size different ',
'than energy data size.']), 'RED')
return None
- if training_size > d_len or test_size > d_len:
- printc('ERROR. Training or test size greater than data size.', 'RED')
+ if training_size >= d_len:
+ printc('ERROR. Training size greater or equal than data size.', 'RED')
return None
+ if not test_size:
+ test_size = d_len - training_size
+
tic = time.perf_counter()
if show_msgs:
printc('{} ML started, with parameters:'.format(desc_type), 'CYAN')