From 00301d1a9a8a7f975b64fe4ef85458f6a40776f7 Mon Sep 17 00:00:00 2001
From: David Luevano Alvarado <55825613+luevano@users.noreply.github.com>
Date: Sat, 7 Mar 2020 08:47:45 -0700
Subject: Use AM again, add tf handler

---
 ml_exp/do_ml.py   | 21 ++++++++++++++-------
 ml_exp/kernels.py | 11 ++++++++++-
 ml_exp/qm7db.py   | 11 ++++++++++-
 3 files changed, 34 insertions(+), 9 deletions(-)

diff --git a/ml_exp/do_ml.py b/ml_exp/do_ml.py
index 5ff3c1105..f4742b66d 100644
--- a/ml_exp/do_ml.py
+++ b/ml_exp/do_ml.py
@@ -23,7 +23,12 @@ SOFTWARE.
 import time
 import numpy as np
 from scipy import linalg as LA
-import tensorflow as tf
+try:
+    import tensorflow as tf
+    TF_AV = True
+except ImportError:
+    print('Tensorflow couldn\'t be imported. Maybe it is not installed.')
+    TF_AV = False
 from ml_exp.misc import printc
 from ml_exp.kernels import gaussian_kernel
 from ml_exp.qm7db import qm7db
@@ -67,6 +72,10 @@ def simple_ml(descriptors,
     if training_size >= data_size:
         raise ValueError('Training size is greater or equal to the data size.')
 
+    # If tf is to be used but couldn't be imported, don't try to use it.
+    if use_tf and not TF_AV:
+        use_tf = False
+
     # If test_size is not set, it is set to a maximum size of 1500.
     if not test_size:
         test_size = data_size - training_size
@@ -186,6 +195,10 @@ def do_ml(db_path='data',
     if type(identifiers) != list:
         raise TypeError('\'identifiers\' is not a list.')
 
+    # If tf is to be used but couldn't be imported, don't try to use it.
+    if use_tf and not TF_AV:
+        use_tf = False
+
     init_time = time.perf_counter()
 
     # Data reading.
@@ -213,13 +226,11 @@ def do_ml(db_path='data',
                              size=size,
                              as_eig=as_eig,
                              bohr_ru=bohr_ru)
-        """
         if 'AM' in identifiers:
             compound.gen_hd(size=size,
                             bohr_ru=bohr_ru)
             compound.gen_am(use_forces=use_forces,
                             size=size)
-        """
         if 'BOB' in identifiers:
             compound.gen_bob(size=size)
 
@@ -228,10 +239,8 @@ def do_ml(db_path='data',
         cm_data = np.array([comp.cm for comp in compounds], dtype=np.float64)
     if 'LJM' in identifiers:
         ljm_data = np.array([comp.ljm for comp in compounds], dtype=np.float64)
-    """
     if 'AM' in identifiers:
         am_data = np.array([comp.cm for comp in compounds], dtype=np.float64)
-    """
     if 'BOB' in identifiers:
         bob_data = np.array([comp.bob for comp in compounds], dtype=np.float64)
 
@@ -273,7 +282,6 @@ def do_ml(db_path='data',
                                         identifier='LJM',
                                         use_tf=use_tf,
                                         show_msgs=show_msgs)
-    """
     if 'AM' in identifiers:
         am_mae, am_tictoc = simple_ml(am_data,
                                       energy_pbe0,
@@ -283,7 +291,6 @@ def do_ml(db_path='data',
                                       identifier='AM',
                                       use_tf=use_tf,
                                       show_msgs=show_msgs)
-    """
     if 'BOB' in identifiers:
         bob_mae, bob_tictoc = simple_ml(bob_data,
                                         energy_pbe0,
diff --git a/ml_exp/kernels.py b/ml_exp/kernels.py
index c203af30e..abc71f7af 100644
--- a/ml_exp/kernels.py
+++ b/ml_exp/kernels.py
@@ -22,7 +22,12 @@ SOFTWARE.
 """
 # import math
 import numpy as np
-import tensorflow as tf
+try:
+    import tensorflow as tf
+    TF_AV = True
+except ImportError:
+    print('Tensorflow couldn\'t be imported. Maybe it is not installed.')
+    TF_AV = False
 
 
 def gaussian_kernel(X1,
@@ -36,6 +41,10 @@ def gaussian_kernel(X1,
     sigma: kernel width.
     use_tf: if tensorflow should be used.
     """
+    # If tf is to be used but couldn't be imported, don't try to use it.
+    if use_tf and not TF_AV:
+        use_tf = False
+
     X1_size = X1.shape[0]
     X2_size = X2.shape[0]
     i_sigma = -0.5 / (sigma*sigma)
diff --git a/ml_exp/qm7db.py b/ml_exp/qm7db.py
index c20df018e..d4734f4ee 100644
--- a/ml_exp/qm7db.py
+++ b/ml_exp/qm7db.py
@@ -22,7 +22,12 @@ SOFTWARE.
 """
 from ml_exp.compound import Compound
 import numpy as np
-import tensorflow as tf
+try:
+    import tensorflow as tf
+    TF_AV = True
+except ImportError:
+    print('Tensorflow couldn\'t be imported. Maybe it is not installed.')
+    TF_AV = False
 import random
 
 
@@ -37,6 +42,10 @@ def qm7db(db_path='data',
     r_seed: random seed to use for the shuffling.
     use_tf: if tensorflow should be used.
     """
+    # If tf is to be used but couldn't be imported, don't try to use it.
+    if use_tf and not TF_AV:
+        use_tf = False
+
     fname = f'{db_path}/hof_qm7.txt'
     with open(fname, 'r') as f:
         lines = f.readlines()
-- 
cgit v1.2.3-70-g09d2