From f0ea2db1bfc21f35aa37de1ec0b0da497f5e30eb Mon Sep 17 00:00:00 2001
From: David Luevano Alvarado <55825613+luevano@users.noreply.github.com>
Date: Tue, 10 Mar 2020 14:50:26 -0700
Subject: Prepare for qm9 db

---
 ml_exp/compound.py | 31 ++++++++++++++++++++++---------
 ml_exp/qm7db.py    | 10 +++++-----
 2 files changed, 27 insertions(+), 14 deletions(-)

diff --git a/ml_exp/compound.py b/ml_exp/compound.py
index eae280d38..6d9790c3b 100644
--- a/ml_exp/compound.py
+++ b/ml_exp/compound.py
@@ -28,20 +28,31 @@ from ml_exp.representations import coulomb_matrix, lennard_jones_matrix,\
 
 class Compound:
     def __init__(self,
-                 xyz=None):
+                 xyz=None,
+                 db='qm7'):
         """
         Initialization of the Compound.
         xyz: (path to) the xyz file.
+        db: which db is the xyz file based on.
         """
-        # General compound data.
+        # xyz and nc data.
         self.name = None
         self.n = None
-        self.extra = None
+        self.comment = None
+        self.coordinates = None
         self.atoms = None
         self.nc = None
-        self.coordinates = None
-        self.pbe0 = None
-        self.delta = None
+
+        # qm7 data.
+        self.qm7pbe0 = None
+        self.qm7delta = None
+
+        # qm9 data.
+        self.qm9prop = None
+        self.qm9Mulliken = None
+        self.qm9frec = None
+        self.qm9SMILES = None
+        self.qm9InChI = None
 
         # Computed data.
         self.cm = None
@@ -61,7 +72,7 @@ class Compound:
         self.bonds_f = None
 
         if xyz is not None:
-            self.read_xyz(xyz)
+            self.read_xyz(xyz, db=db)
 
     def gen_cm(self,
                size=23,
@@ -155,17 +166,19 @@ class Compound:
                                 size=size)
 
     def read_xyz(self,
-                 filename):
+                 filename,
+                 db='qm7'):
         """
         Reads an xyz file and adds the corresponding data to the Compound.
         filename: (path to) the xyz file.
+        db: which db is the xyz file based on.
         """
         with open(filename, 'r') as f:
             lines = f.readlines()
 
         self.name = filename.split('/')[-1]
         self.n = np.int32(lines[0])
-        self.extra = lines[1]
+        self.comment = lines[1]
         self.atoms = []
         self.nc = np.empty(self.n, dtype=np.int64)
         self.coordinates = np.empty((self.n, 3), dtype=np.float64)
diff --git a/ml_exp/qm7db.py b/ml_exp/qm7db.py
index d4734f4ee..1e78b8d56 100644
--- a/ml_exp/qm7db.py
+++ b/ml_exp/qm7db.py
@@ -53,16 +53,16 @@ def qm7db(db_path='data',
     compounds = []
     for i, line in enumerate(lines):
         line = line.split()
-        compounds.append(Compound(f'{db_path}/{line[0]}'))
-        compounds[i].pbe0 = np.float64(line[1])
-        compounds[i].delta = np.float64(line[1]) - np.float64(line[2])
+        compounds.append(Compound(f'{db_path}/{line[0]}', db='qm7'))
+        compounds[i].qm7pbe0 = np.float64(line[1])
+        compounds[i].qm7delta = np.float64(line[1]) - np.float64(line[2])
 
     if is_shuffled:
         random.seed(r_seed)
         random.shuffle(compounds)
 
-    e_pbe0 = np.array([comp.pbe0 for comp in compounds], dtype=np.float64)
-    e_delta = np.array([comp.delta for comp in compounds], dtype=np.float64)
+    e_pbe0 = np.array([comp.qm7pbe0 for comp in compounds], dtype=np.float64)
+    e_delta = np.array([comp.qm7delta for comp in compounds], dtype=np.float64)
 
     if use_tf:
         # Check if there's a gpu available and use the first one.
-- 
cgit v1.2.3-70-g09d2