diff options
author | David Luevano Alvarado <55825613+luevano@users.noreply.github.com> | 2020-03-07 08:47:45 -0700 |
---|---|---|
committer | David Luevano Alvarado <55825613+luevano@users.noreply.github.com> | 2020-03-07 08:47:45 -0700 |
commit | 00301d1a9a8a7f975b64fe4ef85458f6a40776f7 (patch) | |
tree | 1cbe7e0ae1e40f67523d2eee5cd6c1f8c26cae43 /ml_exp/do_ml.py | |
parent | 9d33a74915382159ce97e3b4142743a7e3e1c72d (diff) |
Use AM again, add tf handler
Diffstat (limited to 'ml_exp/do_ml.py')
-rw-r--r-- | ml_exp/do_ml.py | 21 |
1 files changed, 14 insertions, 7 deletions
diff --git a/ml_exp/do_ml.py b/ml_exp/do_ml.py index 5ff3c1105..f4742b66d 100644 --- a/ml_exp/do_ml.py +++ b/ml_exp/do_ml.py @@ -23,7 +23,12 @@ SOFTWARE. import time import numpy as np from scipy import linalg as LA -import tensorflow as tf +try: + import tensorflow as tf + TF_AV = True +except ImportError: + print('Tensorflow couldn\'t be imported. Maybe it is not installed.') + TF_AV = False from ml_exp.misc import printc from ml_exp.kernels import gaussian_kernel from ml_exp.qm7db import qm7db @@ -67,6 +72,10 @@ def simple_ml(descriptors, if training_size >= data_size: raise ValueError('Training size is greater or equal to the data size.') + # If tf is to be used but couldn't be imported, don't try to use it. + if use_tf and not TF_AV: + use_tf = False + # If test_size is not set, it is set to a maximum size of 1500. if not test_size: test_size = data_size - training_size @@ -186,6 +195,10 @@ def do_ml(db_path='data', if type(identifiers) != list: raise TypeError('\'identifiers\' is not a list.') + # If tf is to be used but couldn't be imported, don't try to use it. + if use_tf and not TF_AV: + use_tf = False + init_time = time.perf_counter() # Data reading. @@ -213,13 +226,11 @@ def do_ml(db_path='data', size=size, as_eig=as_eig, bohr_ru=bohr_ru) - """ if 'AM' in identifiers: compound.gen_hd(size=size, bohr_ru=bohr_ru) compound.gen_am(use_forces=use_forces, size=size) - """ if 'BOB' in identifiers: compound.gen_bob(size=size) @@ -228,10 +239,8 @@ def do_ml(db_path='data', cm_data = np.array([comp.cm for comp in compounds], dtype=np.float64) if 'LJM' in identifiers: ljm_data = np.array([comp.ljm for comp in compounds], dtype=np.float64) - """ if 'AM' in identifiers: am_data = np.array([comp.cm for comp in compounds], dtype=np.float64) - """ if 'BOB' in identifiers: bob_data = np.array([comp.bob for comp in compounds], dtype=np.float64) @@ -273,7 +282,6 @@ def do_ml(db_path='data', identifier='LJM', use_tf=use_tf, show_msgs=show_msgs) - """ if 'AM' in identifiers: am_mae, am_tictoc = simple_ml(am_data, energy_pbe0, @@ -283,7 +291,6 @@ def do_ml(db_path='data', identifier='AM', use_tf=use_tf, show_msgs=show_msgs) - """ if 'BOB' in identifiers: bob_mae, bob_tictoc = simple_ml(bob_data, energy_pbe0, |