summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Luevano <55825613+luevano@users.noreply.github.com>2019-12-12 19:20:24 -0700
committerDavid Luevano <55825613+luevano@users.noreply.github.com>2019-12-12 19:20:24 -0700
commit5ac02bfda4b3ff0700344a1b8757ec3d586ea7a5 (patch)
treee8c6b3b5e427a1f52eaea8e803c556240726be96
parent79315ef03abeb43dfb5c4cb1b68e1a33349f2bfb (diff)
Reformat test size
-rw-r--r--do_ml.py18
-rw-r--r--main.py14
2 files changed, 24 insertions, 8 deletions
diff --git a/do_ml.py b/do_ml.py
index 25cf04e01..63a6fc671 100644
--- a/do_ml.py
+++ b/do_ml.py
@@ -30,7 +30,7 @@ from cholesky_solve import cholesky_solve
def do_ml(desc_data,
energy_data,
training_size,
- test_size,
+ test_size=None,
sigma=1000.0,
desc_type=None,
show_msgs=True):
@@ -39,7 +39,8 @@ def do_ml(desc_data,
desc_data: descriptor (or representation) data.
energy_data: energy data associated with desc_data.
training_size: size of the training set to use.
- test_size: size of the test set to use.
+ test_size: size of the test set to use. If no size is given,
+ the last remaining molecules are used.
sigma: depth of the kernel.
desc_type: string with the name of the descriptor used.
show_msgs: Show debug messages or not.
@@ -47,20 +48,25 @@ def do_ml(desc_data,
Also, training is done with the first part of the data and
testing with the ending part of the data.
"""
+ # Initial calculations for later use.
+ d_len = len(desc_data)
+ e_len = len(energy_data)
+
if not desc_type:
desc_type = 'NOT SPECIFIED'
- d_len = len(desc_data)
- e_len = len(energy_data)
if d_len != e_len:
printc(''.join(['ERROR. Descriptor data size different ',
'than energy data size.']), 'RED')
return None
- if training_size > d_len or test_size > d_len:
- printc('ERROR. Training or test size greater than data size.', 'RED')
+ if training_size >= d_len:
+ printc('ERROR. Training size greater or equal than data size.', 'RED')
return None
+ if not test_size:
+ test_size = d_len - training_size
+
tic = time.perf_counter()
if show_msgs:
printc('{} ML started, with parameters:'.format(desc_type), 'CYAN')
diff --git a/main.py b/main.py
index c3da14bed..88734d57f 100644
--- a/main.py
+++ b/main.py
@@ -54,8 +54,18 @@ cm_data = c_matrix_multiple(molecules, nuclear_charge, as_eig=True)
ljm_data = lj_matrix_multiple(molecules, nuclear_charge, as_eig=True)
# ML calculation.
-do_ml(cm_data, energy_pbe0, 1000, 100, sigma=1000.0, desc_type='CM')
-do_ml(ljm_data, energy_pbe0, 1000, 100, sigma=1000.0, desc_type='L-JM')
+do_ml(cm_data,
+ energy_pbe0,
+ 1000,
+ test_size=100,
+ sigma=1000.0,
+ desc_type='CM')
+do_ml(ljm_data,
+ energy_pbe0,
+ 1000,
+ test_size=100,
+ sigma=1000.0,
+ desc_type='L-JM')
# End of program
end_time = time.perf_counter()