From b14c581ca5fdab47d7e1c0b688331368cb7f29d0 Mon Sep 17 00:00:00 2001 From: David Luevano <55825613+luevano@users.noreply.github.com> Date: Mon, 23 Dec 2019 13:11:12 -0700 Subject: Refactor ml code --- lj_matrix/do_ml.py | 104 +++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 93 insertions(+), 11 deletions(-) (limited to 'lj_matrix') diff --git a/lj_matrix/do_ml.py b/lj_matrix/do_ml.py index bb954a0ae..ac044cfb3 100644 --- a/lj_matrix/do_ml.py +++ b/lj_matrix/do_ml.py @@ -22,19 +22,22 @@ SOFTWARE. """ import time import numpy as np +from multiprocessing import Process, Pipe from lj_matrix.misc import printc from lj_matrix.gauss_kernel import gauss_kernel from lj_matrix.cholesky_solve import cholesky_solve - - -def do_ml(desc_data, - energy_data, - training_size, - desc_type=None, - pipe=None, - test_size=None, - sigma=1000.0, - show_msgs=True): +from lj_matrix.read_qm7_data import read_qm7_data +from lj_matrix.parallel_create_matrices import parallel_create_matrices + + +def ml(desc_data, + energy_data, + training_size, + desc_type=None, + pipe=None, + test_size=None, + sigma=1000.0, + show_msgs=True): """ Does the ML methodology. desc_data: descriptor (or representation) data. @@ -51,6 +54,7 @@ def do_ml(desc_data, Also, training is done with the first part of the data and testing with the ending part of the data. """ + tic = time.perf_counter() # Initial calculations for later use. d_len = len(desc_data) e_len = len(energy_data) @@ -72,7 +76,6 @@ def do_ml(desc_data, if test_size > 1500: test_size = 1500 - tic = time.perf_counter() if show_msgs: printc('{} ML started.'.format(desc_type), 'GREEN') printc('\tTraining size: {}'.format(training_size), 'CYAN') @@ -106,3 +109,82 @@ def do_ml(desc_data, pipe.send([desc_type, training_size, test_size, sigma, mae, tictoc]) return mae, tictoc + + +# Test +def do_ml(min_training_size, + max_training_size=None, + training_increment_size=None, + ljm_sigma=1.0, + ljm_epsilon=1.0, + save_benchmarks=False): + """ + Main function that does the whole ML process. + min_training_size: minimum training size. + max_training_size: maximum training size. + training_increment_size: training increment size. + ljm_sigma: sigma value for lj matrix. + ljm_epsilon: epsilon value for lj matrix. + save_benchmarks: if benchmarks should be saved. + """ + # Initialization time. + init_time = time.perf_counter() + + # Data reading. + molecules, nuclear_charge, energy_pbe0, energy_delta = read_qm7_data() + + # Matrices calculation. + cm_data, ljm_data = parallel_create_matrices(molecules, + nuclear_charge, + ljm_sigma, + ljm_epsilon) + + # ML calculation. + procs = [] + cm_pipes = [] + ljm_pipes = [] + for i in range(min_training_size, + max_training_size + 1, + training_increment_size): + cm_recv, cm_send = Pipe(False) + p1 = Process(target=ml, + args=(cm_data, energy_pbe0, i, 'CM', cm_send)) + procs.append(p1) + cm_pipes.append(cm_recv) + p1.start() + + ljm_recv, ljm_send = Pipe(False) + p2 = Process(target=ml, + args=(ljm_data, energy_pbe0, i, 'L-JM', ljm_send)) + procs.append(p2) + ljm_pipes.append(ljm_recv) + p2.start() + + cm_bench_results = [] + ljm_bench_results = [] + for cd_pipe, ljd_pipe in zip(cm_pipes, ljm_pipes): + cm_bench_results.append(cd_pipe.recv()) + ljm_bench_results.append(ljd_pipe.recv()) + + for proc in procs: + proc.join() + + if save_benchmarks: + with open('data\\benchmarks.csv', 'a') as save_file: + # save_file.write(''.join(['ml_type,tr_size,te_size,kernel_s,', + # 'mae,time,lj_s,lj_e,date_ran\n'])) + ltime = time.localtime()[:3][::-1] + ljm_se = ',' + str(ljm_sigma) + ',' + str(ljm_epsilon) + ',' + date = '/'.join([str(field) for field in ltime]) + for cm, ljm, in zip(cm_bench_results, ljm_bench_results): + cm_text = ','.join([str(field) for field in cm])\ + + ',' + date + '\n' + ljm_text = ','.join([str(field) for field in ljm])\ + + ljm_se + date + '\n' + save_file.write(cm_text) + save_file.write(ljm_text) + + # End of program + end_time = time.perf_counter() + printc('Program took {:.4f} seconds.'.format(end_time - init_time), + 'CYAN') -- cgit v1.2.3-54-g00ecf