diff options
Diffstat (limited to 'ml_exp/do_ml.py')
-rw-r--r-- | ml_exp/do_ml.py | 60 |
1 files changed, 41 insertions, 19 deletions
diff --git a/ml_exp/do_ml.py b/ml_exp/do_ml.py index d8ee415bf..5efd13690 100644 --- a/ml_exp/do_ml.py +++ b/ml_exp/do_ml.py @@ -33,6 +33,7 @@ def simple_ml(descriptors, training_size, test_size=None, sigma=1000.0, + opt=True, identifier=None, show_msgs=True): """ @@ -43,6 +44,7 @@ def simple_ml(descriptors, test_size: size of the test set to use. If no size is given, the last remaining molecules are used. sigma: depth of the kernel. + opt: if the optimized algorithm should be used. For benchmarking purposes. identifier: string with the name of the descriptor used. show_msgs: if debug messages should be shown. NOTE: identifier is just a string and is only for identification purposes. @@ -76,13 +78,21 @@ def simple_ml(descriptors, X_training = descriptors[:training_size] Y_training = energies[:training_size] - K_training = gaussian_kernel(X_training, X_training, sigma) - alpha = cholesky_solve(K_training, Y_training) + K_training = gaussian_kernel(X_training, + X_training, + sigma, + opt=opt) + alpha = cholesky_solve(K_training, + Y_training) X_test = descriptors[-test_size:] Y_test = energies[-test_size:] - K_test = gaussian_kernel(X_test, X_training, sigma) - Y_predicted = np.dot(K_test, alpha) + K_test = gaussian_kernel(X_test, + X_training, + sigma, + opt=opt) + Y_predicted = np.dot(K_test, + alpha) mae = np.mean(np.abs(Y_predicted - Y_test)) if show_msgs: @@ -113,7 +123,8 @@ def do_ml(db_path='data', training_size=1500, test_size=None, sigma=1000.0, - identifiers=["CM"], + opt=True, + identifiers=['CM'], show_msgs=True): """ Main function that does the whole ML process. @@ -132,6 +143,7 @@ def do_ml(db_path='data', test_size: size of the test set to use. If no size is given, the last remaining molecules are used. sigma: depth of the kernel. + opt: if the optimized algorithm should be used. For benchmarking purposes. identifiers: list of names (strings) of descriptors to use. show_msgs: if debug messages should be shown. """ @@ -164,23 +176,27 @@ def do_ml(db_path='data', size=size, as_eig=as_eig, bohr_ru=bohr_ru) + """ if 'AM' in identifiers: - compound.gen_am(use_forces=use_forces, - size=size, + compound.gen_hd(size=size, bohr_ru=bohr_ru) - if 'BOS' in identifiers: - compound.gen_bos(size=size, - stuff=stuff) + compound.gen_am(use_forces=use_forces, + size=size) + """ + if 'BOB' in identifiers: + compound.gen_bob(size=size) # Create a numpy array for the descriptors. if 'CM' in identifiers: - cm_data = np.array([comp.cm for comp in compounds], dtype=float) + cm_data = np.array([comp.cm for comp in compounds], dtype=np.float64) if 'LJM' in identifiers: - ljm_data = np.array([comp.ljm for comp in compounds], dtype=float) + ljm_data = np.array([comp.ljm for comp in compounds], dtype=np.float64) + """ if 'AM' in identifiers: - am_data = np.array([comp.cm for comp in compounds], dtype=float) - if 'BOS' in identifiers: - bos_data = np.array([comp.bos for comp in compounds], dtype=float) + am_data = np.array([comp.cm for comp in compounds], dtype=np.float64) + """ + if 'BOB' in identifiers: + bob_data = np.array([comp.bob for comp in compounds], dtype=np.float64) toc = time.perf_counter() tictoc = toc - tic @@ -194,6 +210,7 @@ def do_ml(db_path='data', training_size=training_size, test_size=test_size, sigma=sigma, + opt=opt, identifier='CM', show_msgs=show_msgs) if 'LJM' in identifiers: @@ -202,23 +219,28 @@ def do_ml(db_path='data', training_size=training_size, test_size=test_size, sigma=sigma, + opt=opt, identifier='LJM', show_msgs=show_msgs) + """ if 'AM' in identifiers: am_mae, am_tictoc = simple_ml(am_data, energy_pbe0, training_size=training_size, test_size=test_size, sigma=sigma, - identifier='CM', + opt=opt, + identifier='AM', show_msgs=show_msgs) - if 'BOS' in identifiers: - bos_mae, bos_tictoc = simple_ml(bos_data, + """ + if 'BOB' in identifiers: + bob_mae, bob_tictoc = simple_ml(bob_data, energy_pbe0, training_size=training_size, test_size=test_size, sigma=sigma, - identifier='CM', + opt=opt, + identifier='BOB', show_msgs=show_msgs) # End of program |