diff options
-rw-r--r-- | ml_exp/bob.py | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/ml_exp/bob.py b/ml_exp/bob.py index 20085d957..86efecdb4 100644 --- a/ml_exp/bob.py +++ b/ml_exp/bob.py @@ -21,6 +21,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ from numpy import array, zeros +from collections import Counter def check_bond(bags, @@ -87,11 +88,13 @@ def bob(c_matrix, bags[checker[1]].append(c_matrix[i, j]) # Create the actual bond list ordered. + atom_counter = Counter(atoms) atom_list = sorted(list(set(atoms))) bonds = [] for i, a_i in enumerate(atom_list): - for a_j in atom_list[i:]: - bonds.append(''.join(sorted([a_i, a_j]))) + if atom_counter[a_i] > 1: + for a_j in atom_list[i:]: + bonds.append(''.join(sorted([a_i, a_j]))) bonds = atom_list + bonds # Create the final vector for the bob. @@ -100,13 +103,13 @@ def bob(c_matrix, for i, bond in enumerate(bonds): checker = check_bond(bags, bond) if checker[0]: - for j, num in enumerate(bags[checker[1]][1:]): + for j, num in enumerate(sorted(bags[checker[1]][1:])[::-1]): # Use c_i as the index for bob if the zero padding should # be at the end of the vector instead of between each bond. bob[i*max_n + j] = num c_i += 1 # This is set to false because this was a debugging measure. - elif False: + else: print(''.join([f'Error. Bond {bond} from bond list coudn\'t', ' be found in the bags list. This could be', ' a case where the atom is only present once', |