From 8dd406df884d4eb1b07a7aeb7c2a0d6d1adfe5ee Mon Sep 17 00:00:00 2001
From: David Luevano <55825613+luevano@users.noreply.github.com>
Date: Thu, 23 Jan 2020 20:16:26 -0700
Subject: Add adjacency matrix

---
 ml_exp/__main__.py      | 15 ++++++++-
 ml_exp/adj_matrix.py    | 90 +++++++++++++++++++++++++++++++++++++++++++++++++
 ml_exp/read_qm7_data.py | 29 ++++++++++++----
 3 files changed, 127 insertions(+), 7 deletions(-)
 create mode 100644 ml_exp/adj_matrix.py

diff --git a/ml_exp/__main__.py b/ml_exp/__main__.py
index 167f0fa7b..63708e75c 100644
--- a/ml_exp/__main__.py
+++ b/ml_exp/__main__.py
@@ -20,10 +20,13 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 """
-from ml_exp.do_ml import do_ml
+# from ml_exp.do_ml import do_ml
 # from ml_exp.misc import plot_benchmarks
+from ml_exp.read_qm7_data import read_qm7_data
+from ml_exp.adj_matrix import fneig_matrix, adj_matrix
 
 if __name__ == '__main__':
+    """
     do_ml(min_training_size=1500,
           max_training_size=2000,
           training_increment_size=500,
@@ -34,5 +37,15 @@ if __name__ == '__main__':
           r_seed=111,
           save_benchmarks=False,
           show_msgs=True)
+    """
     # plot_benchmarks()
+    xyz, nc, pbe0, delta, atoms = read_qm7_data(return_atoms=True)
+    for i in range(1):
+        fnm, bonds = fneig_matrix(atoms[i], xyz[i])
+        am = adj_matrix(bonds)
+
+        print(f'{i} first neighbor matrix\n{fnm}')
+        print(f'{i} bond list\n{bonds}')
+        print(f'{i} adjacency matrix\n{am}')
+        print('-'*30)
     print('OK!')
diff --git a/ml_exp/adj_matrix.py b/ml_exp/adj_matrix.py
new file mode 100644
index 000000000..1fbbe0555
--- /dev/null
+++ b/ml_exp/adj_matrix.py
@@ -0,0 +1,90 @@
+"""MIT License
+
+Copyright (c) 2019 David Luevano Alvarado
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+"""
+from numpy import array, zeros
+from numpy.linalg import norm
+
+
+def fneig_matrix(atoms,
+                 xyz):
+    """
+    Creates the first neighbor matrix of the given molecule data.
+    atoms: list of atoms.
+    xyz: matrix of atomic coords.
+    NOTE: Bond distance of carbon to other elements
+        are (for atoms present in the qm7 dataset):
+            H: 1.06 - 1.12 A
+            O: 1.43 - 2.15 A
+            N: 1.47 - 2.10 A
+            S: 1.81 - 2.55 A
+    """
+    # Possible bonds.
+    ch_bond = sorted(['C', 'H'])
+    co_bond = sorted(['C', 'O'])
+    cn_bond = sorted(['C', 'N'])
+    cs_bond = sorted(['C', 'S'])
+
+    # Number of atoms, empty matrix and bond list.
+    n = len(atoms)
+    fnm = array(zeros((n, n)))
+    bonds = []
+    for i, xyz_i in enumerate(xyz):
+        for j, xyz_j in enumerate(xyz):
+            # Ignore the diagonal.
+            if i != j:
+                bond = sorted([atoms[i], atoms[j]])
+                r = norm(xyz_i - xyz_j)
+                # Check for each type of bond.
+                if (ch_bond == bond) and (r >= 1.06 and r <= 1.12):
+                    fnm[i, j] = 1
+                    if j > i:
+                        bonds.append((i, j))
+                elif (co_bond == bond) and (r >= 1.43 and r <= 2.15):
+                    fnm[i, j] = 1
+                    if j > i:
+                        bonds.append((i, j))
+                elif (cn_bond == bond) and (r >= 1.47 and r <= 2.10):
+                    fnm[i, j] = 1
+                    if j > i:
+                        bonds.append((i, j))
+                elif (cs_bond == bond) and (r >= 1.81 and r <= 2.55):
+                    fnm[i, j] = 1
+                    if j > i:
+                        bonds.append((i, j))
+    return fnm, bonds
+
+
+def adj_matrix(bonds):
+    """
+    Calculates the adjacency matrix given the bond list.
+    bonds: list of bonds (tuple of indexes).
+    """
+    n = len(bonds)
+    am = array(zeros((n, n)))
+    for i, bond_i in enumerate(bonds):
+        for j, bond_j in enumerate(bonds):
+            # Ignore the diagonal.
+            if i != j:
+                if (bond_i[0] in bond_j) or (bond_i[1] in bond_j):
+                    am[i, j] = 1
+
+    return am
diff --git a/ml_exp/read_qm7_data.py b/ml_exp/read_qm7_data.py
index 51c0eaee7..06de44b02 100644
--- a/ml_exp/read_qm7_data.py
+++ b/ml_exp/read_qm7_data.py
@@ -53,13 +53,15 @@ def read_nc_data(data_path):
 # https://github.com/qmlcode/tutorial
 def read_db_data(zi_data,
                  data_path,
-                 r_seed=111):
+                 r_seed=111,
+                 return_atoms=False):
     """
     Reads molecule database and extracts
     its contents as usable variables.
     zi_data: dictionary containing nuclear charge data.
     data_path: path to the data directory.
     r_seed: random seed to use for the shuffling.
+    return_atoms: if atom list should be returned.
     """
     os.chdir(data_path)
 
@@ -93,6 +95,7 @@ def read_db_data(zi_data,
 
     mol_data = []
     mol_nc_data = []
+    atoms = []
     # Actual reading of the xyz files.
     for i, k in enumerate(energy_temp_shuffled.keys()):
         with open(k, 'r') as xyz_file:
@@ -101,15 +104,18 @@ def read_db_data(zi_data,
         len_lines = len(lines)
         mol_temp_data = []
         mol_nc_temp_data = np.array(np.zeros(len_lines-2))
+        atoms_temp = []
         for j, line in enumerate(lines[2:len_lines]):
             line_list = line.split()
 
+            atoms_temp.append(line_list[0])
             mol_nc_temp_data[j] = float(zi_data[line_list[0]])
             line_data = np.array(np.asarray(line_list[1:4], dtype=float))
             mol_temp_data.append(line_data)
 
         mol_data.append(mol_temp_data)
         mol_nc_data.append(mol_nc_temp_data)
+        atoms.append(atoms_temp)
 
     # Convert everything to a numpy array.
     molecules = np.array([np.array(mol) for mol in mol_data])
@@ -119,27 +125,38 @@ def read_db_data(zi_data,
     energy_delta = np.array([energy_temp_shuffled[k][1]
                              for k in energy_temp_shuffled.keys()])
 
+    if return_atoms:
+        return molecules, nuclear_charge, energy_pbe0, energy_delta, atoms
     return molecules, nuclear_charge, energy_pbe0, energy_delta
 
 
-def read_qm7_data(r_seed=111):
+def read_qm7_data(data_path='data',
+                  r_seed=111,
+                  return_atoms=False):
     """
     Reads all the qm7 data.
+    data_path: path to the data directory.
     r_seed: random seed to use for the shuffling.
+    return_atoms: if atom list should be returned.
     """
     tic = time.perf_counter()
     printc('Data reading started.', 'CYAN')
 
     init_path = os.getcwd()
-    os.chdir('data')
+    os.chdir(data_path)
     data_path = os.getcwd()
 
     zi_data = read_nc_data(data_path)
-    molecules, nuclear_charge, energy_pbe0, energy_delta = \
-        read_db_data(zi_data, data_path, r_seed)
+    if return_atoms:
+        molecules, nuclear_charge, energy_pbe0, energy_delta, atoms = \
+            read_db_data(zi_data, data_path, r_seed, return_atoms)
+    else:
+        molecules, nuclear_charge, energy_pbe0, energy_delta = \
+            read_db_data(zi_data, data_path, r_seed)
 
     os.chdir(init_path)
     toc = time.perf_counter()
     printc('\tData reading took {:.4f} seconds.'.format(toc-tic), 'GREEN')
-
+    if return_atoms:
+        return molecules, nuclear_charge, energy_pbe0, energy_delta, atoms
     return molecules, nuclear_charge, energy_pbe0, energy_delta
-- 
cgit v1.2.3-70-g09d2