diff options
Diffstat (limited to 'ml_exp/qm7db.py')
-rw-r--r-- | ml_exp/qm7db.py | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/ml_exp/qm7db.py b/ml_exp/qm7db.py index 29bda6a59..c20df018e 100644 --- a/ml_exp/qm7db.py +++ b/ml_exp/qm7db.py @@ -56,7 +56,12 @@ def qm7db(db_path='data', e_delta = np.array([comp.delta for comp in compounds], dtype=np.float64) if use_tf: - e_pbe0 = tf.convert_to_tensor(e_pbe0) - e_delta = tf.convert_to_tensor(e_delta) + # Check if there's a gpu available and use the first one. + if tf.config.experimental.list_physical_devices('GPU'): + with tf.device('GPU:0'): + e_pbe0 = tf.convert_to_tensor(e_pbe0) + e_delta = tf.convert_to_tensor(e_delta) + else: + raise TypeError('No GPU found, could not create Tensor objects.') return compounds, e_pbe0, e_delta |