From 5a85580b5f99484b957b65326ec5a19ab627d801 Mon Sep 17 00:00:00 2001
From: David Luevano Alvarado <55825613+luevano@users.noreply.github.com>
Date: Wed, 26 Feb 2020 05:26:21 -0700
Subject: Optimize qm7db reading for the rewrite

---
 ml_exp/__init__.py |  2 --
 ml_exp/compound.py |  3 ++-
 ml_exp/qm7db.py    | 68 +++++++++++-------------------------------------------
 3 files changed, 16 insertions(+), 57 deletions(-)

diff --git a/ml_exp/__init__.py b/ml_exp/__init__.py
index e9dd83eae..d685e7ccc 100644
--- a/ml_exp/__init__.py
+++ b/ml_exp/__init__.py
@@ -25,8 +25,6 @@ from ml_exp.representations import coulomb_matrix, lennard_jones_matrix,\
     first_neighbor_matrix, adjacency_matrix, check_bond, bag_of_stuff
 from ml_exp.math import cholesky_solve
 
-# If somebody does "from package import *", this is what they will
-# be able to access:
 __all__ = ['Compound',
            'coulomb_matrix',
            'lennard_jones_matrix',
diff --git a/ml_exp/compound.py b/ml_exp/compound.py
index 8b6af0ae9..0a8b89610 100644
--- a/ml_exp/compound.py
+++ b/ml_exp/compound.py
@@ -33,7 +33,7 @@ class Compound:
         Initialization of the Compound.
         xyz: (path to) the xyz file.
         """
-        # empty_array = np.asarray([], dtype=float)
+        self.name = None
 
         self.n = None
         self.atoms = None
@@ -137,6 +137,7 @@ class Compound:
         with open(filename, 'r') as f:
             lines = f.readlines()
 
+        self.name = filename.split('/')[-1]
         self.n = int(lines[0])
         self.atoms = []
         self.atoms_nc = np.empty(self.n, dtype=int)
diff --git a/ml_exp/qm7db.py b/ml_exp/qm7db.py
index 1f1115ba0..f9950c317 100644
--- a/ml_exp/qm7db.py
+++ b/ml_exp/qm7db.py
@@ -20,6 +20,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 """
+from ml_exp.compound import Compound
 import numpy as np
 import random
 
@@ -28,72 +29,31 @@ import random
 # https://github.com/qmlcode/tutorial
 def qm7db(nc,
           db_path='data',
+          is_shuffled=True,
           r_seed=111):
     """
     Creates a list of compounds with the qm7 database.
     nc: dictionary containing nuclear charge data.
     db_path: path to the database directory.
+    is_shuffled: if the resulting list of compounds should be shuffled.
     r_seed: random seed to use for the shuffling.
     """
-
     fname = f'{db_path}/hof_qm7.txt'
     with open(fname, 'r') as f:
         lines = f.readlines()
 
-    # Temporary energy dictionary.
-    energy_temp = dict()
-
-    for line in lines:
-        xyz_data = line.split()
-
-        xyz_name = xyz_data[0]
-        hof = float(xyz_data[1])
-        dftb = float(xyz_data[2])
-        # print(xyz_name, hof, dftb)
+    compounds = []
+    for i, line in enumerate(lines):
+        line = line.split()
+        compounds.append(Compound(f'{db_path}/{line[0]}'))
+        compounds[i].pbe0 = float(line[1])
+        compounds[i].delta = float(line[1]) - float(line[2])
 
-        energy_temp[xyz_name] = np.array([hof, hof - dftb])
-
-    # Use a random seed.
+    # Shuffle the compounds list
     random.seed(r_seed)
+    random.shuffle(compounds)
 
-    et_keys = list(energy_temp.keys())
-    random.shuffle(et_keys)
-
-    # Temporary energy dictionary, shuffled.
-    energy_temp_shuffled = dict()
-    for key in et_keys:
-        energy_temp_shuffled.update({key: energy_temp[key]})
-
-    mol_data = []
-    mol_nc_data = []
-    atoms = []
-    # Actual reading of the xyz files.
-    for i, k in enumerate(energy_temp_shuffled.keys()):
-        with open(k, 'r') as xyz_file:
-            lines = xyz_file.readlines()
-
-        len_lines = len(lines)
-        mol_temp_data = []
-        mol_nc_temp_data = np.array(np.zeros(len_lines-2))
-        atoms_temp = []
-        for j, line in enumerate(lines[2:len_lines]):
-            line_list = line.split()
-
-            atoms_temp.append(line_list[0])
-            mol_nc_temp_data[j] = float(nc[line_list[0]])
-            line_data = np.array(np.asarray(line_list[1:4], dtype=float))
-            mol_temp_data.append(line_data)
-
-        mol_data.append(mol_temp_data)
-        mol_nc_data.append(mol_nc_temp_data)
-        atoms.append(atoms_temp)
-
-    # Convert everything to a numpy array.
-    molecules = np.array([np.array(mol) for mol in mol_data])
-    nuclear_charge = np.array([nc_d for nc_d in mol_nc_data])
-    energy_pbe0 = np.array([energy_temp_shuffled[k][0]
-                            for k in energy_temp_shuffled.keys()])
-    energy_delta = np.array([energy_temp_shuffled[k][1]
-                             for k in energy_temp_shuffled.keys()])
+    e_pbe0 = np.array([compound.pbe0 for compound in compounds], dtype=float)
+    e_delta = np.array([compound.delta for compound in compounds], dtype=float)
 
-    return molecules, nuclear_charge, energy_pbe0, energy_delta, atoms
+    return compounds, e_pbe0, e_delta
-- 
cgit v1.2.3-70-g09d2