summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Luevano Alvarado <55825613+luevano@users.noreply.github.com>2020-03-07 08:47:45 -0700
committerDavid Luevano Alvarado <55825613+luevano@users.noreply.github.com>2020-03-07 08:47:45 -0700
commit00301d1a9a8a7f975b64fe4ef85458f6a40776f7 (patch)
tree1cbe7e0ae1e40f67523d2eee5cd6c1f8c26cae43
parent9d33a74915382159ce97e3b4142743a7e3e1c72d (diff)
Use AM again, add tf handler
-rw-r--r--ml_exp/do_ml.py21
-rw-r--r--ml_exp/kernels.py11
-rw-r--r--ml_exp/qm7db.py11
3 files changed, 34 insertions, 9 deletions
diff --git a/ml_exp/do_ml.py b/ml_exp/do_ml.py
index 5ff3c1105..f4742b66d 100644
--- a/ml_exp/do_ml.py
+++ b/ml_exp/do_ml.py
@@ -23,7 +23,12 @@ SOFTWARE.
import time
import numpy as np
from scipy import linalg as LA
-import tensorflow as tf
+try:
+ import tensorflow as tf
+ TF_AV = True
+except ImportError:
+ print('Tensorflow couldn\'t be imported. Maybe it is not installed.')
+ TF_AV = False
from ml_exp.misc import printc
from ml_exp.kernels import gaussian_kernel
from ml_exp.qm7db import qm7db
@@ -67,6 +72,10 @@ def simple_ml(descriptors,
if training_size >= data_size:
raise ValueError('Training size is greater or equal to the data size.')
+ # If tf is to be used but couldn't be imported, don't try to use it.
+ if use_tf and not TF_AV:
+ use_tf = False
+
# If test_size is not set, it is set to a maximum size of 1500.
if not test_size:
test_size = data_size - training_size
@@ -186,6 +195,10 @@ def do_ml(db_path='data',
if type(identifiers) != list:
raise TypeError('\'identifiers\' is not a list.')
+ # If tf is to be used but couldn't be imported, don't try to use it.
+ if use_tf and not TF_AV:
+ use_tf = False
+
init_time = time.perf_counter()
# Data reading.
@@ -213,13 +226,11 @@ def do_ml(db_path='data',
size=size,
as_eig=as_eig,
bohr_ru=bohr_ru)
- """
if 'AM' in identifiers:
compound.gen_hd(size=size,
bohr_ru=bohr_ru)
compound.gen_am(use_forces=use_forces,
size=size)
- """
if 'BOB' in identifiers:
compound.gen_bob(size=size)
@@ -228,10 +239,8 @@ def do_ml(db_path='data',
cm_data = np.array([comp.cm for comp in compounds], dtype=np.float64)
if 'LJM' in identifiers:
ljm_data = np.array([comp.ljm for comp in compounds], dtype=np.float64)
- """
if 'AM' in identifiers:
am_data = np.array([comp.cm for comp in compounds], dtype=np.float64)
- """
if 'BOB' in identifiers:
bob_data = np.array([comp.bob for comp in compounds], dtype=np.float64)
@@ -273,7 +282,6 @@ def do_ml(db_path='data',
identifier='LJM',
use_tf=use_tf,
show_msgs=show_msgs)
- """
if 'AM' in identifiers:
am_mae, am_tictoc = simple_ml(am_data,
energy_pbe0,
@@ -283,7 +291,6 @@ def do_ml(db_path='data',
identifier='AM',
use_tf=use_tf,
show_msgs=show_msgs)
- """
if 'BOB' in identifiers:
bob_mae, bob_tictoc = simple_ml(bob_data,
energy_pbe0,
diff --git a/ml_exp/kernels.py b/ml_exp/kernels.py
index c203af30e..abc71f7af 100644
--- a/ml_exp/kernels.py
+++ b/ml_exp/kernels.py
@@ -22,7 +22,12 @@ SOFTWARE.
"""
# import math
import numpy as np
-import tensorflow as tf
+try:
+ import tensorflow as tf
+ TF_AV = True
+except ImportError:
+ print('Tensorflow couldn\'t be imported. Maybe it is not installed.')
+ TF_AV = False
def gaussian_kernel(X1,
@@ -36,6 +41,10 @@ def gaussian_kernel(X1,
sigma: kernel width.
use_tf: if tensorflow should be used.
"""
+ # If tf is to be used but couldn't be imported, don't try to use it.
+ if use_tf and not TF_AV:
+ use_tf = False
+
X1_size = X1.shape[0]
X2_size = X2.shape[0]
i_sigma = -0.5 / (sigma*sigma)
diff --git a/ml_exp/qm7db.py b/ml_exp/qm7db.py
index c20df018e..d4734f4ee 100644
--- a/ml_exp/qm7db.py
+++ b/ml_exp/qm7db.py
@@ -22,7 +22,12 @@ SOFTWARE.
"""
from ml_exp.compound import Compound
import numpy as np
-import tensorflow as tf
+try:
+ import tensorflow as tf
+ TF_AV = True
+except ImportError:
+ print('Tensorflow couldn\'t be imported. Maybe it is not installed.')
+ TF_AV = False
import random
@@ -37,6 +42,10 @@ def qm7db(db_path='data',
r_seed: random seed to use for the shuffling.
use_tf: if tensorflow should be used.
"""
+ # If tf is to be used but couldn't be imported, don't try to use it.
+ if use_tf and not TF_AV:
+ use_tf = False
+
fname = f'{db_path}/hof_qm7.txt'
with open(fname, 'r') as f:
lines = f.readlines()