summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ml_exp/bob.py11
1 files changed, 7 insertions, 4 deletions
diff --git a/ml_exp/bob.py b/ml_exp/bob.py
index 20085d957..86efecdb4 100644
--- a/ml_exp/bob.py
+++ b/ml_exp/bob.py
@@ -21,6 +21,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from numpy import array, zeros
+from collections import Counter
def check_bond(bags,
@@ -87,11 +88,13 @@ def bob(c_matrix,
bags[checker[1]].append(c_matrix[i, j])
# Create the actual bond list ordered.
+ atom_counter = Counter(atoms)
atom_list = sorted(list(set(atoms)))
bonds = []
for i, a_i in enumerate(atom_list):
- for a_j in atom_list[i:]:
- bonds.append(''.join(sorted([a_i, a_j])))
+ if atom_counter[a_i] > 1:
+ for a_j in atom_list[i:]:
+ bonds.append(''.join(sorted([a_i, a_j])))
bonds = atom_list + bonds
# Create the final vector for the bob.
@@ -100,13 +103,13 @@ def bob(c_matrix,
for i, bond in enumerate(bonds):
checker = check_bond(bags, bond)
if checker[0]:
- for j, num in enumerate(bags[checker[1]][1:]):
+ for j, num in enumerate(sorted(bags[checker[1]][1:])[::-1]):
# Use c_i as the index for bob if the zero padding should
# be at the end of the vector instead of between each bond.
bob[i*max_n + j] = num
c_i += 1
# This is set to false because this was a debugging measure.
- elif False:
+ else:
print(''.join([f'Error. Bond {bond} from bond list coudn\'t',
' be found in the bags list. This could be',
' a case where the atom is only present once',