summaryrefslogtreecommitdiff
path: root/lj_matrix
diff options
context:
space:
mode:
Diffstat (limited to 'lj_matrix')
-rw-r--r--lj_matrix/do_ml.py36
1 files changed, 31 insertions, 5 deletions
diff --git a/lj_matrix/do_ml.py b/lj_matrix/do_ml.py
index ac044cfb3..12323780a 100644
--- a/lj_matrix/do_ml.py
+++ b/lj_matrix/do_ml.py
@@ -111,13 +111,17 @@ def ml(desc_data,
return mae, tictoc
-# Test
def do_ml(min_training_size,
max_training_size=None,
training_increment_size=None,
ljm_sigma=1.0,
ljm_epsilon=1.0,
- save_benchmarks=False):
+ save_benchmarks=False,
+ max_len=25,
+ as_eig=True,
+ bohr_radius_units=False,
+ sigma=1000.0,
+ show_msgs=True):
"""
Main function that does the whole ML process.
min_training_size: minimum training size.
@@ -126,6 +130,11 @@ def do_ml(min_training_size,
ljm_sigma: sigma value for lj matrix.
ljm_epsilon: epsilon value for lj matrix.
save_benchmarks: if benchmarks should be saved.
+ max_len: maximum amount of atoms in molecule.
+ as_eig: if data should be returned as matrix or array of eigenvalues.
+ bohr_radius_units: if units should be in bohr's radius units.
+ sigma: depth of the kernel.
+ show_msgs: Show debug messages or not.
"""
# Initialization time.
init_time = time.perf_counter()
@@ -137,7 +146,10 @@ def do_ml(min_training_size,
cm_data, ljm_data = parallel_create_matrices(molecules,
nuclear_charge,
ljm_sigma,
- ljm_epsilon)
+ ljm_epsilon,
+ max_len,
+ as_eig,
+ bohr_radius_units)
# ML calculation.
procs = []
@@ -148,14 +160,28 @@ def do_ml(min_training_size,
training_increment_size):
cm_recv, cm_send = Pipe(False)
p1 = Process(target=ml,
- args=(cm_data, energy_pbe0, i, 'CM', cm_send))
+ args=(cm_data,
+ energy_pbe0,
+ i,
+ 'CM',
+ cm_send,
+ max_training_size,
+ sigma,
+ show_msgs))
procs.append(p1)
cm_pipes.append(cm_recv)
p1.start()
ljm_recv, ljm_send = Pipe(False)
p2 = Process(target=ml,
- args=(ljm_data, energy_pbe0, i, 'L-JM', ljm_send))
+ args=(ljm_data,
+ energy_pbe0,
+ i,
+ 'L-JM',
+ ljm_send,
+ max_training_size,
+ sigma,
+ show_msgs))
procs.append(p2)
ljm_pipes.append(ljm_recv)
p2.start()