From 9b26afc10f47866ddb176d9e2c8609b9b180105d Mon Sep 17 00:00:00 2001 From: David Luevano Alvarado <55825613+luevano@users.noreply.github.com> Date: Thu, 26 Mar 2020 15:58:22 -0700 Subject: Add flattening to cm and ljm --- ml_exp/representations.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/ml_exp/representations.py b/ml_exp/representations.py index 3e44d626c..b0a5d8553 100644 --- a/ml_exp/representations.py +++ b/ml_exp/representations.py @@ -28,6 +28,7 @@ from ml_exp.data import POSSIBLE_BONDS def coulomb_matrix(coords, nc, size=23, + flatten=True, as_eig=True, bohr_ru=False): """ @@ -35,6 +36,7 @@ def coulomb_matrix(coords, coords: compound coordinates. nc: nuclear charge data. size: compound size. + flatten: if the representation should be 1D. as_eig: if the representation should be as the eigenvalues. bohr_ru: if radius units should be in bohr's radius units. """ @@ -75,7 +77,11 @@ size. Arrays are not of the right shape.') return np.pad(cm_eigs, (0, size - n), 'constant') else: - return np.pad(cm, ((0, size - n), (0, size - n)), 'constant') + if flatten: + return np.pad(cm, ((0, size - n), (0, size - n)), + 'constant').flatten() + else: + return np.pad(cm, ((0, size - n), (0, size - n)), 'constant') def lennard_jones_matrix(coords, @@ -84,6 +90,7 @@ def lennard_jones_matrix(coords, sigma=1.0, epsilon=1.0, size=23, + flatten=True, as_eig=True, bohr_ru=False): """ @@ -94,6 +101,7 @@ def lennard_jones_matrix(coords, sigma: sigma value. epsilon: epsilon value. size: compound size. + flatten: if the representation should be 1D. as_eig: if the representation should be as the eigenvalues. bohr_ru: if radius units should be in bohr's radius units. """ @@ -141,7 +149,11 @@ size. Arrays are not of the right shape.') return np.pad(lj_eigs, (0, size - n), 'constant') else: - return np.pad(lj, ((0, size - n), (0, size - n)), 'constant') + if flatten: + return np.pad(lj, ((0, size - n), (0, size - n)), + 'constant').flatten() + else: + return np.pad(lj, ((0, size - n), (0, size - n)), 'constant') def get_helping_data(coords, -- cgit v1.2.3-54-g00ecf