summaryrefslogtreecommitdiff
path: root/ml_exp/qm7db.py
diff options
context:
space:
mode:
Diffstat (limited to 'ml_exp/qm7db.py')
-rw-r--r--ml_exp/qm7db.py9
1 files changed, 7 insertions, 2 deletions
diff --git a/ml_exp/qm7db.py b/ml_exp/qm7db.py
index 29bda6a59..c20df018e 100644
--- a/ml_exp/qm7db.py
+++ b/ml_exp/qm7db.py
@@ -56,7 +56,12 @@ def qm7db(db_path='data',
e_delta = np.array([comp.delta for comp in compounds], dtype=np.float64)
if use_tf:
- e_pbe0 = tf.convert_to_tensor(e_pbe0)
- e_delta = tf.convert_to_tensor(e_delta)
+ # Check if there's a gpu available and use the first one.
+ if tf.config.experimental.list_physical_devices('GPU'):
+ with tf.device('GPU:0'):
+ e_pbe0 = tf.convert_to_tensor(e_pbe0)
+ e_delta = tf.convert_to_tensor(e_delta)
+ else:
+ raise TypeError('No GPU found, could not create Tensor objects.')
return compounds, e_pbe0, e_delta