summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Luevano <55825613+luevano@users.noreply.github.com>2019-12-28 11:37:22 -0700
committerDavid Luevano <55825613+luevano@users.noreply.github.com>2019-12-28 11:37:22 -0700
commit4704314c9b4d1066383da5c3d6ca87bba9067c8d (patch)
tree47efddb979957029945a473fde6ed2cde2c2b196
parente4f9e15588ec796f73c000a683cc9152454a913c (diff)
Refactor code
-rw-r--r--lj_matrix/__main__.py1
-rw-r--r--lj_matrix/do_ml.py5
-rw-r--r--lj_matrix/lj_matrix.py2
-rw-r--r--lj_matrix/read_qm7_data.py7
4 files changed, 10 insertions, 5 deletions
diff --git a/lj_matrix/__main__.py b/lj_matrix/__main__.py
index 811024ff0..688e5adcc 100644
--- a/lj_matrix/__main__.py
+++ b/lj_matrix/__main__.py
@@ -31,6 +31,7 @@ if __name__ == '__main__':
ljm_diag_value=None,
ljm_sigma=1.0,
ljm_epsilon=1.0,
+ r_seed=111,
save_benchmarks=False,
show_msgs=True)
# plot_benchmarks()
diff --git a/lj_matrix/do_ml.py b/lj_matrix/do_ml.py
index da9386bf7..25a55e823 100644
--- a/lj_matrix/do_ml.py
+++ b/lj_matrix/do_ml.py
@@ -118,6 +118,7 @@ def do_ml(min_training_size,
ljm_diag_value=None,
ljm_sigma=1.0,
ljm_epsilon=1.0,
+ r_seed=111,
save_benchmarks=False,
max_len=25,
as_eig=True,
@@ -134,6 +135,7 @@ def do_ml(min_training_size,
ljm_diag_value: if a special diagonal value should be used in lj matrix.
ljm_sigma: sigma value for lj matrix.
ljm_epsilon: epsilon value for lj matrix.
+ r_seed: random seed to use for the shuffling.
save_benchmarks: if benchmarks should be saved.
max_len: maximum amount of atoms in molecule.
as_eig: if data should be returned as matrix or array of eigenvalues.
@@ -147,7 +149,8 @@ def do_ml(min_training_size,
max_training_size = min_training_size + training_increment_size
# Data reading.
- molecules, nuclear_charge, energy_pbe0, energy_delta = read_qm7_data()
+ molecules, nuclear_charge, energy_pbe0, energy_delta =\
+ read_qm7_data(r_seed)
# Matrices calculation.
cm_data, ljm_data = parallel_create_matrices(molecules,
diff --git a/lj_matrix/lj_matrix.py b/lj_matrix/lj_matrix.py
index c3b61becb..6739ae283 100644
--- a/lj_matrix/lj_matrix.py
+++ b/lj_matrix/lj_matrix.py
@@ -88,7 +88,7 @@ def lj_matrix(mol_data,
z = (z_i-z_j)**2
if i == j:
- if not diag_value:
+ if diag_value is None:
lj[i, j] = (0.5*Z_i**2.4)
else:
lj[i, j] = diag_value
diff --git a/lj_matrix/read_qm7_data.py b/lj_matrix/read_qm7_data.py
index 9bb7629ca..4401ca1c0 100644
--- a/lj_matrix/read_qm7_data.py
+++ b/lj_matrix/read_qm7_data.py
@@ -59,7 +59,7 @@ def read_db_data(zi_data,
its contents as usable variables.
zi_data: dictionary containing nuclear charge data.
data_path: path to the data directory.
- r_seed: random seed.
+ r_seed: random seed to use for the shuffling.
"""
os.chdir(data_path)
@@ -122,9 +122,10 @@ def read_db_data(zi_data,
return molecules, nuclear_charge, energy_pbe0, energy_delta
-def read_qm7_data():
+def read_qm7_data(r_seed=111):
"""
Reads all the qm7 data.
+ r_seed: random seed to use for the shuffling.
"""
tic = time.perf_counter()
printc('Data reading started.', 'CYAN')
@@ -135,7 +136,7 @@ def read_qm7_data():
zi_data = read_nc_data(data_path)
molecules, nuclear_charge, energy_pbe0, energy_delta = \
- read_db_data(zi_data, data_path)
+ read_db_data(zi_data, data_path, r_seed)
os.chdir(init_path)
toc = time.perf_counter()