diff options
Diffstat (limited to 'ml_exp/qm7db.py')
-rw-r--r-- | ml_exp/qm7db.py | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/ml_exp/qm7db.py b/ml_exp/qm7db.py index 3ba2c5814..29bda6a59 100644 --- a/ml_exp/qm7db.py +++ b/ml_exp/qm7db.py @@ -22,17 +22,20 @@ SOFTWARE. """ from ml_exp.compound import Compound import numpy as np +import tensorflow as tf import random def qm7db(db_path='data', is_shuffled=True, - r_seed=111): + r_seed=111, + use_tf=True): """ Creates a list of compounds with the qm7 database. db_path: path to the database directory. is_shuffled: if the resulting list of compounds should be shuffled. r_seed: random seed to use for the shuffling. + use_tf: if tensorflow should be used. """ fname = f'{db_path}/hof_qm7.txt' with open(fname, 'r') as f: @@ -52,4 +55,8 @@ def qm7db(db_path='data', e_pbe0 = np.array([comp.pbe0 for comp in compounds], dtype=np.float64) e_delta = np.array([comp.delta for comp in compounds], dtype=np.float64) + if use_tf: + e_pbe0 = tf.convert_to_tensor(e_pbe0) + e_delta = tf.convert_to_tensor(e_delta) + return compounds, e_pbe0, e_delta |