summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Luevano Alvarado <55825613+luevano@users.noreply.github.com>2020-02-27 15:59:38 -0700
committerDavid Luevano Alvarado <55825613+luevano@users.noreply.github.com>2020-02-27 15:59:38 -0700
commitb937df41c1c5e996be94a3a690908ea989e281dc (patch)
treeff8d8faa4f606afe27830df12f8124c3c333e34e
parenta8dc304d0c85b3db76986f109bdcd7a530bd3045 (diff)
Placeholder function, test
-rw-r--r--ml_exp/do_ml.py56
1 files changed, 54 insertions, 2 deletions
diff --git a/ml_exp/do_ml.py b/ml_exp/do_ml.py
index de4d4061c..f32480367 100644
--- a/ml_exp/do_ml.py
+++ b/ml_exp/do_ml.py
@@ -25,7 +25,7 @@ import numpy as np
from ml_exp.misc import printc
from ml_exp.kernels import gaussian_kernel
from ml_exp.math import cholesky_solve
-# from ml_exp.qm7db import qm7db
+from ml_exp.qm7db import qm7db
def simple_ml(descriptors,
@@ -44,7 +44,7 @@ def simple_ml(descriptors,
test_size: size of the test set to use. If no size is given,
the last remaining molecules are used.
sigma: depth of the kernel.
- show_msgs: if debu messages should be shown.
+ show_msgs: if debug messages should be shown.
NOTE: identifier is just a string and is only for identification purposes.
Also, training is done with the first part of the data and
testing with the ending part of the data.
@@ -97,3 +97,55 @@ def simple_ml(descriptors,
printc(f'\t\tSigma: {sigma}', 'CYAN')
return mae, tictoc
+
+
+def do_ml(db_path='data',
+ is_shuffled=True,
+ r_seed=111,
+ show_msgs=True):
+ """
+ Main function that does the whole ML process.
+ training_size: minimum training size.
+ test_size: size of the test set to use. If no size is given,
+ the last remaining molecules are used.
+ ljm_diag_value: if a special diagonal value should be used in lj matrix.
+ ljm_sigma: sigma value for lj matrix.
+ ljm_epsilon: epsilon value for lj matrix.
+ r_seed: random seed to use for the shuffling.
+ save_benchmarks: if benchmarks should be saved.
+ size: maximum amount of atoms in molecule.
+ as_eig: if data should be returned as matrix or array of eigenvalues.
+ bohr_radius_units: if units should be in bohr's radius units.
+ sigma: depth of the kernel.
+ show_msgs: if debug messages should be shown.
+ """
+ init_time = time.perf_counter()
+
+ # Data reading.
+ tic = time.perf_counter()
+ compounds, energy_pbe0, energy_delta = qm7db(db_path=db_path,
+ is_shuffled=is_shuffled,
+ r_seed=r_seed)
+ toc = time.perf_counter()
+ tictoc = toc - tic
+ if show_msgs:
+ printc(f'Data reading took {tictoc:.4f} seconds.', 'CYAN')
+
+ # Matrices calculation.
+ tic = time.perf_counter()
+ for compound in compounds:
+ compound.gen_cm()
+ compound.gen_ljm()
+ compound.gen_am()
+ toc = time.perf_counter()
+ tictoc = toc - tic
+ if show_msgs:
+ printc(f'Matrices calculation took {tictoc:.4f} seconds.', 'CYAN')
+
+ # ML calculation.
+ # PLHLDR
+
+ # End of program
+ end_time = time.perf_counter()
+ totaltime = end_time - init_time
+ printc(f'Program took {totaltime:.4f} seconds.', 'CYAN')