diff options
Diffstat (limited to 'ml_exp/read_qm7_data.py')
-rw-r--r-- | ml_exp/read_qm7_data.py | 29 |
1 files changed, 23 insertions, 6 deletions
diff --git a/ml_exp/read_qm7_data.py b/ml_exp/read_qm7_data.py index 51c0eaee7..06de44b02 100644 --- a/ml_exp/read_qm7_data.py +++ b/ml_exp/read_qm7_data.py @@ -53,13 +53,15 @@ def read_nc_data(data_path): # https://github.com/qmlcode/tutorial def read_db_data(zi_data, data_path, - r_seed=111): + r_seed=111, + return_atoms=False): """ Reads molecule database and extracts its contents as usable variables. zi_data: dictionary containing nuclear charge data. data_path: path to the data directory. r_seed: random seed to use for the shuffling. + return_atoms: if atom list should be returned. """ os.chdir(data_path) @@ -93,6 +95,7 @@ def read_db_data(zi_data, mol_data = [] mol_nc_data = [] + atoms = [] # Actual reading of the xyz files. for i, k in enumerate(energy_temp_shuffled.keys()): with open(k, 'r') as xyz_file: @@ -101,15 +104,18 @@ def read_db_data(zi_data, len_lines = len(lines) mol_temp_data = [] mol_nc_temp_data = np.array(np.zeros(len_lines-2)) + atoms_temp = [] for j, line in enumerate(lines[2:len_lines]): line_list = line.split() + atoms_temp.append(line_list[0]) mol_nc_temp_data[j] = float(zi_data[line_list[0]]) line_data = np.array(np.asarray(line_list[1:4], dtype=float)) mol_temp_data.append(line_data) mol_data.append(mol_temp_data) mol_nc_data.append(mol_nc_temp_data) + atoms.append(atoms_temp) # Convert everything to a numpy array. molecules = np.array([np.array(mol) for mol in mol_data]) @@ -119,27 +125,38 @@ def read_db_data(zi_data, energy_delta = np.array([energy_temp_shuffled[k][1] for k in energy_temp_shuffled.keys()]) + if return_atoms: + return molecules, nuclear_charge, energy_pbe0, energy_delta, atoms return molecules, nuclear_charge, energy_pbe0, energy_delta -def read_qm7_data(r_seed=111): +def read_qm7_data(data_path='data', + r_seed=111, + return_atoms=False): """ Reads all the qm7 data. + data_path: path to the data directory. r_seed: random seed to use for the shuffling. + return_atoms: if atom list should be returned. """ tic = time.perf_counter() printc('Data reading started.', 'CYAN') init_path = os.getcwd() - os.chdir('data') + os.chdir(data_path) data_path = os.getcwd() zi_data = read_nc_data(data_path) - molecules, nuclear_charge, energy_pbe0, energy_delta = \ - read_db_data(zi_data, data_path, r_seed) + if return_atoms: + molecules, nuclear_charge, energy_pbe0, energy_delta, atoms = \ + read_db_data(zi_data, data_path, r_seed, return_atoms) + else: + molecules, nuclear_charge, energy_pbe0, energy_delta = \ + read_db_data(zi_data, data_path, r_seed) os.chdir(init_path) toc = time.perf_counter() printc('\tData reading took {:.4f} seconds.'.format(toc-tic), 'GREEN') - + if return_atoms: + return molecules, nuclear_charge, energy_pbe0, energy_delta, atoms return molecules, nuclear_charge, energy_pbe0, energy_delta |