summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ml_exp/qm7db.py17
1 files changed, 6 insertions, 11 deletions
diff --git a/ml_exp/qm7db.py b/ml_exp/qm7db.py
index e584bc123..4a7beaed6 100644
--- a/ml_exp/qm7db.py
+++ b/ml_exp/qm7db.py
@@ -27,17 +27,14 @@ import random
# 'hof_qm7.txt.txt' retrieved from
# https://github.com/qmlcode/tutorial
-def qm7db(zi_data,
+def qm7db(nc,
data_path,
- r_seed=111,
- return_atoms=False):
+ r_seed=111):
"""
- Reads molecule database and extracts
- its contents as usable variables.
- zi_data: dictionary containing nuclear charge data.
+ Creates a list of compounds with the qm7 database.
+ nc: dictionary containing nuclear charge data.
data_path: path to the data directory.
r_seed: random seed to use for the shuffling.
- return_atoms: if atom list should be returned.
"""
os.chdir(data_path)
@@ -85,7 +82,7 @@ def qm7db(zi_data,
line_list = line.split()
atoms_temp.append(line_list[0])
- mol_nc_temp_data[j] = float(zi_data[line_list[0]])
+ mol_nc_temp_data[j] = float(nc[line_list[0]])
line_data = np.array(np.asarray(line_list[1:4], dtype=float))
mol_temp_data.append(line_data)
@@ -101,6 +98,4 @@ def qm7db(zi_data,
energy_delta = np.array([energy_temp_shuffled[k][1]
for k in energy_temp_shuffled.keys()])
- if return_atoms:
- return molecules, nuclear_charge, energy_pbe0, energy_delta, atoms
- return molecules, nuclear_charge, energy_pbe0, energy_delta
+ return molecules, nuclear_charge, energy_pbe0, energy_delta, atoms