summaryrefslogtreecommitdiff
path: root/ml_exp/qm7db.py
diff options
context:
space:
mode:
Diffstat (limited to 'ml_exp/qm7db.py')
-rw-r--r--ml_exp/qm7db.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/ml_exp/qm7db.py b/ml_exp/qm7db.py
index d4734f4ee..1e78b8d56 100644
--- a/ml_exp/qm7db.py
+++ b/ml_exp/qm7db.py
@@ -53,16 +53,16 @@ def qm7db(db_path='data',
compounds = []
for i, line in enumerate(lines):
line = line.split()
- compounds.append(Compound(f'{db_path}/{line[0]}'))
- compounds[i].pbe0 = np.float64(line[1])
- compounds[i].delta = np.float64(line[1]) - np.float64(line[2])
+ compounds.append(Compound(f'{db_path}/{line[0]}', db='qm7'))
+ compounds[i].qm7pbe0 = np.float64(line[1])
+ compounds[i].qm7delta = np.float64(line[1]) - np.float64(line[2])
if is_shuffled:
random.seed(r_seed)
random.shuffle(compounds)
- e_pbe0 = np.array([comp.pbe0 for comp in compounds], dtype=np.float64)
- e_delta = np.array([comp.delta for comp in compounds], dtype=np.float64)
+ e_pbe0 = np.array([comp.qm7pbe0 for comp in compounds], dtype=np.float64)
+ e_delta = np.array([comp.qm7delta for comp in compounds], dtype=np.float64)
if use_tf:
# Check if there's a gpu available and use the first one.