From 6ad25bb4cd5822fa84d1c4cec54952c6fd8df2f6 Mon Sep 17 00:00:00 2001
From: David Luevano Alvarado <55825613+luevano@users.noreply.github.com>
Date: Thu, 27 Feb 2020 21:41:41 -0700
Subject: Add rest of the descriptors

---
 ml_exp/do_ml.py | 95 ++++++++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 70 insertions(+), 25 deletions(-)

diff --git a/ml_exp/do_ml.py b/ml_exp/do_ml.py
index 40243d413..d8ee415bf 100644
--- a/ml_exp/do_ml.py
+++ b/ml_exp/do_ml.py
@@ -106,10 +106,14 @@ def do_ml(db_path='data',
           lj_sigma=1.0,
           lj_epsilon=1.0,
           use_forces=False,
+          stuff='bonds',
           size=23,
           as_eig=True,
           bohr_ru=False,
+          training_size=1500,
+          test_size=None,
           sigma=1000.0,
+          identifiers=["CM"],
           show_msgs=True):
     """
     Main function that does the whole ML process.
@@ -120,12 +124,20 @@ def do_ml(db_path='data',
     lj_sigma: sigma value.
     lj_epsilon: epsilon value.
     use_forces: if the use of forces instead of k_cx should be used.
+    stuff: elements of the bag, by default the known bag of bonds.
     size: compound size.
     as_eig: if the representation should be as the eigenvalues.
     bohr_ru: if radius units should be in bohr's radius units.
+    training_size: size of the training set to use.
+    test_size: size of the test set to use. If no size is given,
+        the last remaining molecules are used.
     sigma: depth of the kernel.
+    identifiers: list of names (strings) of descriptors to use.
     show_msgs: if debug messages should be shown.
     """
+    if type(identifiers) != list:
+        raise TypeError('\'identifiers\' is not a list.')
+
     init_time = time.perf_counter()
 
     # Data reading.
@@ -141,24 +153,34 @@ def do_ml(db_path='data',
     # Matrices calculation.
     tic = time.perf_counter()
     for compound in compounds:
-        compound.gen_cm(size=size,
-                        as_eig=as_eig,
-                        bohr_ru=bohr_ru)
-        compound.gen_ljm(diag_value=diag_value,
-                         sigma=lj_sigma,
-                         epsilon=lj_epsilon,
-                         size=size,
-                         as_eig=as_eig,
-                         bohr_ru=bohr_ru)
-        compound.gen_am(use_forces=use_forces,
-                        size=size,
-                        bohr_ru=bohr_ru)
+        if 'CM' in identifiers:
+            compound.gen_cm(size=size,
+                            as_eig=as_eig,
+                            bohr_ru=bohr_ru)
+        if 'LJM' in identifiers:
+            compound.gen_ljm(diag_value=diag_value,
+                             sigma=lj_sigma,
+                             epsilon=lj_epsilon,
+                             size=size,
+                             as_eig=as_eig,
+                             bohr_ru=bohr_ru)
+        if 'AM' in identifiers:
+            compound.gen_am(use_forces=use_forces,
+                            size=size,
+                            bohr_ru=bohr_ru)
+        if 'BOS' in identifiers:
+            compound.gen_bos(size=size,
+                             stuff=stuff)
 
     # Create a numpy array for the descriptors.
-    cm_data = np.array([compound.cm for compound in compounds], dtype=float)
-    ljm_data = np.array([compound.ljm for compound in compounds], dtype=float)
-    am_data = np.array([compound.cm for compound in compounds], dtype=float)
-    print(cm_data.shape, ljm_data.shape, am_data.shape)
+    if 'CM' in identifiers:
+        cm_data = np.array([comp.cm for comp in compounds], dtype=float)
+    if 'LJM' in identifiers:
+        ljm_data = np.array([comp.ljm for comp in compounds], dtype=float)
+    if 'AM' in identifiers:
+        am_data = np.array([comp.cm for comp in compounds], dtype=float)
+    if 'BOS' in identifiers:
+        bos_data = np.array([comp.bos for comp in compounds], dtype=float)
 
     toc = time.perf_counter()
     tictoc = toc - tic
@@ -166,15 +188,38 @@ def do_ml(db_path='data',
         printc(f'Matrices calculation took {tictoc:.4f} seconds.', 'CYAN')
 
     # ML calculation.
-    # CM
-    cm_mae, cm_tictoc = simple_ml(cm_data,
-                                  energy_pbe0,
-                                  training_size=5000,
-                                  test_size=1500,
-                                  sigma=1000.0,
-                                  identifier='CM',
-                                  show_msgs=show_msgs)
-    print(cm_mae, cm_tictoc)
+    if 'CM' in identifiers:
+        cm_mae, cm_tictoc = simple_ml(cm_data,
+                                      energy_pbe0,
+                                      training_size=training_size,
+                                      test_size=test_size,
+                                      sigma=sigma,
+                                      identifier='CM',
+                                      show_msgs=show_msgs)
+    if 'LJM' in identifiers:
+        ljm_mae, ljm_tictoc = simple_ml(ljm_data,
+                                        energy_pbe0,
+                                        training_size=training_size,
+                                        test_size=test_size,
+                                        sigma=sigma,
+                                        identifier='LJM',
+                                        show_msgs=show_msgs)
+    if 'AM' in identifiers:
+        am_mae, am_tictoc = simple_ml(am_data,
+                                      energy_pbe0,
+                                      training_size=training_size,
+                                      test_size=test_size,
+                                      sigma=sigma,
+                                      identifier='CM',
+                                      show_msgs=show_msgs)
+    if 'BOS' in identifiers:
+        bos_mae, bos_tictoc = simple_ml(bos_data,
+                                        energy_pbe0,
+                                        training_size=training_size,
+                                        test_size=test_size,
+                                        sigma=sigma,
+                                        identifier='CM',
+                                        show_msgs=show_msgs)
 
     # End of program
     end_time = time.perf_counter()
-- 
cgit v1.2.3-70-g09d2