diff options
author | David Luevano Alvarado <55825613+luevano@users.noreply.github.com> | 2020-02-27 15:59:38 -0700 |
---|---|---|
committer | David Luevano Alvarado <55825613+luevano@users.noreply.github.com> | 2020-02-27 15:59:38 -0700 |
commit | b937df41c1c5e996be94a3a690908ea989e281dc (patch) | |
tree | ff8d8faa4f606afe27830df12f8124c3c333e34e | |
parent | a8dc304d0c85b3db76986f109bdcd7a530bd3045 (diff) |
Placeholder function, test
-rw-r--r-- | ml_exp/do_ml.py | 56 |
1 files changed, 54 insertions, 2 deletions
diff --git a/ml_exp/do_ml.py b/ml_exp/do_ml.py index de4d4061c..f32480367 100644 --- a/ml_exp/do_ml.py +++ b/ml_exp/do_ml.py @@ -25,7 +25,7 @@ import numpy as np from ml_exp.misc import printc from ml_exp.kernels import gaussian_kernel from ml_exp.math import cholesky_solve -# from ml_exp.qm7db import qm7db +from ml_exp.qm7db import qm7db def simple_ml(descriptors, @@ -44,7 +44,7 @@ def simple_ml(descriptors, test_size: size of the test set to use. If no size is given, the last remaining molecules are used. sigma: depth of the kernel. - show_msgs: if debu messages should be shown. + show_msgs: if debug messages should be shown. NOTE: identifier is just a string and is only for identification purposes. Also, training is done with the first part of the data and testing with the ending part of the data. @@ -97,3 +97,55 @@ def simple_ml(descriptors, printc(f'\t\tSigma: {sigma}', 'CYAN') return mae, tictoc + + +def do_ml(db_path='data', + is_shuffled=True, + r_seed=111, + show_msgs=True): + """ + Main function that does the whole ML process. + training_size: minimum training size. + test_size: size of the test set to use. If no size is given, + the last remaining molecules are used. + ljm_diag_value: if a special diagonal value should be used in lj matrix. + ljm_sigma: sigma value for lj matrix. + ljm_epsilon: epsilon value for lj matrix. + r_seed: random seed to use for the shuffling. + save_benchmarks: if benchmarks should be saved. + size: maximum amount of atoms in molecule. + as_eig: if data should be returned as matrix or array of eigenvalues. + bohr_radius_units: if units should be in bohr's radius units. + sigma: depth of the kernel. + show_msgs: if debug messages should be shown. + """ + init_time = time.perf_counter() + + # Data reading. + tic = time.perf_counter() + compounds, energy_pbe0, energy_delta = qm7db(db_path=db_path, + is_shuffled=is_shuffled, + r_seed=r_seed) + toc = time.perf_counter() + tictoc = toc - tic + if show_msgs: + printc(f'Data reading took {tictoc:.4f} seconds.', 'CYAN') + + # Matrices calculation. + tic = time.perf_counter() + for compound in compounds: + compound.gen_cm() + compound.gen_ljm() + compound.gen_am() + toc = time.perf_counter() + tictoc = toc - tic + if show_msgs: + printc(f'Matrices calculation took {tictoc:.4f} seconds.', 'CYAN') + + # ML calculation. + # PLHLDR + + # End of program + end_time = time.perf_counter() + totaltime = end_time - init_time + printc(f'Program took {totaltime:.4f} seconds.', 'CYAN') |