summaryrefslogtreecommitdiff
path: root/ml_exp/qm7db.py
diff options
context:
space:
mode:
Diffstat (limited to 'ml_exp/qm7db.py')
-rw-r--r--ml_exp/qm7db.py9
1 files changed, 8 insertions, 1 deletions
diff --git a/ml_exp/qm7db.py b/ml_exp/qm7db.py
index 3ba2c5814..29bda6a59 100644
--- a/ml_exp/qm7db.py
+++ b/ml_exp/qm7db.py
@@ -22,17 +22,20 @@ SOFTWARE.
"""
from ml_exp.compound import Compound
import numpy as np
+import tensorflow as tf
import random
def qm7db(db_path='data',
is_shuffled=True,
- r_seed=111):
+ r_seed=111,
+ use_tf=True):
"""
Creates a list of compounds with the qm7 database.
db_path: path to the database directory.
is_shuffled: if the resulting list of compounds should be shuffled.
r_seed: random seed to use for the shuffling.
+ use_tf: if tensorflow should be used.
"""
fname = f'{db_path}/hof_qm7.txt'
with open(fname, 'r') as f:
@@ -52,4 +55,8 @@ def qm7db(db_path='data',
e_pbe0 = np.array([comp.pbe0 for comp in compounds], dtype=np.float64)
e_delta = np.array([comp.delta for comp in compounds], dtype=np.float64)
+ if use_tf:
+ e_pbe0 = tf.convert_to_tensor(e_pbe0)
+ e_delta = tf.convert_to_tensor(e_delta)
+
return compounds, e_pbe0, e_delta