summaryrefslogtreecommitdiff
path: root/ml_exp/qm7db.py
diff options
context:
space:
mode:
authorDavid Luevano Alvarado <55825613+luevano@users.noreply.github.com>2020-03-03 22:49:31 -0700
committerDavid Luevano Alvarado <55825613+luevano@users.noreply.github.com>2020-03-03 22:49:31 -0700
commit52383ddeb87312708eeb1da765b175fb603f2802 (patch)
tree8ecb2ae66eecfa6ab4fa361167bc013c4e0b0521 /ml_exp/qm7db.py
parent1647f76052b016e4102a3af234ac47401e04819d (diff)
Possible tf addition, needs bugfixing
Diffstat (limited to 'ml_exp/qm7db.py')
-rw-r--r--ml_exp/qm7db.py9
1 files changed, 7 insertions, 2 deletions
diff --git a/ml_exp/qm7db.py b/ml_exp/qm7db.py
index 29bda6a59..c20df018e 100644
--- a/ml_exp/qm7db.py
+++ b/ml_exp/qm7db.py
@@ -56,7 +56,12 @@ def qm7db(db_path='data',
e_delta = np.array([comp.delta for comp in compounds], dtype=np.float64)
if use_tf:
- e_pbe0 = tf.convert_to_tensor(e_pbe0)
- e_delta = tf.convert_to_tensor(e_delta)
+ # Check if there's a gpu available and use the first one.
+ if tf.config.experimental.list_physical_devices('GPU'):
+ with tf.device('GPU:0'):
+ e_pbe0 = tf.convert_to_tensor(e_pbe0)
+ e_delta = tf.convert_to_tensor(e_delta)
+ else:
+ raise TypeError('No GPU found, could not create Tensor objects.')
return compounds, e_pbe0, e_delta