summaryrefslogtreecommitdiff
path: root/ml_exp/read_qm7_data.py
diff options
context:
space:
mode:
Diffstat (limited to 'ml_exp/read_qm7_data.py')
-rw-r--r--ml_exp/read_qm7_data.py29
1 files changed, 23 insertions, 6 deletions
diff --git a/ml_exp/read_qm7_data.py b/ml_exp/read_qm7_data.py
index 51c0eaee7..06de44b02 100644
--- a/ml_exp/read_qm7_data.py
+++ b/ml_exp/read_qm7_data.py
@@ -53,13 +53,15 @@ def read_nc_data(data_path):
# https://github.com/qmlcode/tutorial
def read_db_data(zi_data,
data_path,
- r_seed=111):
+ r_seed=111,
+ return_atoms=False):
"""
Reads molecule database and extracts
its contents as usable variables.
zi_data: dictionary containing nuclear charge data.
data_path: path to the data directory.
r_seed: random seed to use for the shuffling.
+ return_atoms: if atom list should be returned.
"""
os.chdir(data_path)
@@ -93,6 +95,7 @@ def read_db_data(zi_data,
mol_data = []
mol_nc_data = []
+ atoms = []
# Actual reading of the xyz files.
for i, k in enumerate(energy_temp_shuffled.keys()):
with open(k, 'r') as xyz_file:
@@ -101,15 +104,18 @@ def read_db_data(zi_data,
len_lines = len(lines)
mol_temp_data = []
mol_nc_temp_data = np.array(np.zeros(len_lines-2))
+ atoms_temp = []
for j, line in enumerate(lines[2:len_lines]):
line_list = line.split()
+ atoms_temp.append(line_list[0])
mol_nc_temp_data[j] = float(zi_data[line_list[0]])
line_data = np.array(np.asarray(line_list[1:4], dtype=float))
mol_temp_data.append(line_data)
mol_data.append(mol_temp_data)
mol_nc_data.append(mol_nc_temp_data)
+ atoms.append(atoms_temp)
# Convert everything to a numpy array.
molecules = np.array([np.array(mol) for mol in mol_data])
@@ -119,27 +125,38 @@ def read_db_data(zi_data,
energy_delta = np.array([energy_temp_shuffled[k][1]
for k in energy_temp_shuffled.keys()])
+ if return_atoms:
+ return molecules, nuclear_charge, energy_pbe0, energy_delta, atoms
return molecules, nuclear_charge, energy_pbe0, energy_delta
-def read_qm7_data(r_seed=111):
+def read_qm7_data(data_path='data',
+ r_seed=111,
+ return_atoms=False):
"""
Reads all the qm7 data.
+ data_path: path to the data directory.
r_seed: random seed to use for the shuffling.
+ return_atoms: if atom list should be returned.
"""
tic = time.perf_counter()
printc('Data reading started.', 'CYAN')
init_path = os.getcwd()
- os.chdir('data')
+ os.chdir(data_path)
data_path = os.getcwd()
zi_data = read_nc_data(data_path)
- molecules, nuclear_charge, energy_pbe0, energy_delta = \
- read_db_data(zi_data, data_path, r_seed)
+ if return_atoms:
+ molecules, nuclear_charge, energy_pbe0, energy_delta, atoms = \
+ read_db_data(zi_data, data_path, r_seed, return_atoms)
+ else:
+ molecules, nuclear_charge, energy_pbe0, energy_delta = \
+ read_db_data(zi_data, data_path, r_seed)
os.chdir(init_path)
toc = time.perf_counter()
printc('\tData reading took {:.4f} seconds.'.format(toc-tic), 'GREEN')
-
+ if return_atoms:
+ return molecules, nuclear_charge, energy_pbe0, energy_delta, atoms
return molecules, nuclear_charge, energy_pbe0, energy_delta