From eae1a3a5f602e9c79e07f4b6c1b133dfd12e1d5c Mon Sep 17 00:00:00 2001
From: David Luevano Alvarado <david@luevano.xyz>
Date: Fri, 24 Feb 2023 21:37:37 -0600
Subject: add database_entry tests, change type for tags attr

---
 setup.cfg                    |  1 +
 src/pyssg/database.py        | 31 ++++++++++-------
 src/pyssg/database_entry.py  | 38 ++++++++++++++-------
 tests/test_database_entry.py | 81 ++++++++++++++++++++++++++++++++++++++++++++
 tests/test_utils.py          |  2 +-
 5 files changed, 127 insertions(+), 26 deletions(-)
 create mode 100644 tests/test_database_entry.py

diff --git a/setup.cfg b/setup.cfg
index d61839e..1fb6efd 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -62,6 +62,7 @@ per-file-ignores =
 	__init__.py: W292
 	arg_parser.py: E501
 	custom_logger.py: E501
+	test_database_entry.py: E501
 
 [pbr]
 skip_authors = True
diff --git a/src/pyssg/database.py b/src/pyssg/database.py
index 5d7d71d..ae0a8d4 100644
--- a/src/pyssg/database.py
+++ b/src/pyssg/database.py
@@ -20,8 +20,9 @@ class Database:
         self.e: dict[str, DatabaseEntry] = dict()
 
     def update_tags(self, file_name: str,
-                    new_tags: list[str]) -> None:
-        # technically, I should ensure this function can only run if self.e is populated
+                    new_tags: set[str]) -> None:
+        # technically, I should ensure this function can only run
+        #   if self.e is populated
         if file_name in self.e:
             log.debug('updating tags for entry "%s"', file_name)
             log.debug('entry "%s" old tags: %s',
@@ -39,7 +40,7 @@ class Database:
                remove: str = '') -> None:
         log.debug('updating entry for file "%s"', file_name)
         f: str = file_name
-        tags: list[str] = []
+        tags: set[str] = set()
         if remove != '':
             f = file_name.replace(remove, '')
             log.debug('removed "%s" from "%s": "%s"', remove, file_name, f)
@@ -49,8 +50,8 @@ class Database:
         log.debug('time for "%s": %s', file_name, time)
 
         # calculate current checksum, also needs actual file name
-        checksum: str = get_checksum(file_name)
-        log.debug('checksum for "%s": "%s"', file_name, checksum)
+        cksm: str = get_checksum(file_name)
+        log.debug('checksum for "%s": "%s"', file_name, cksm)
 
         # three cases, 1) entry didn't exist,
         # 2) entry has been mod and,
@@ -58,7 +59,7 @@ class Database:
         # 1)
         if f not in self.e:
             log.debug('entry "%s" didn\'t exist, adding with defaults', f)
-            self.e[f] = DatabaseEntry([f, time, 0.0, checksum, tags])
+            self.e[f] = DatabaseEntry((f, time, 0.0, cksm, tags))
             return
 
         # old_e is old entity
@@ -66,10 +67,10 @@ class Database:
         log.debug('entry "%s" old content: %s', f, old_e)
 
         # 2)
-        if checksum != old_e.checksum:
+        if cksm != old_e.checksum:
             log.debug('entry "%s" has been modified, updating', f)
-            self.e[f] = DatabaseEntry([f, old_e.ctimestamp, time, checksum, tags])
-            log.debug('entry "%s" new content: (%s, %s, %s, (%s))', f, self.e[f])
+            self.e[f] = DatabaseEntry((f, old_e.ctimestamp, time, cksm, tags))
+            log.debug('entry "%s" new content: %s', f, self.e[f])
             return
         # 3)
         else:
@@ -96,7 +97,7 @@ class Database:
             sys.exit(1)
         return True
 
-    def _get_csv_rows(self) -> list[list[str]]:
+    def _get_raw_csv_rows(self) -> list[list[str]]:
         rows: list[list[str]]
         with open(self.db_path, 'r') as f:
             csv_reader = csv.reader(f, delimiter=self.__COLUMN_DELIMITER)
@@ -110,7 +111,7 @@ class Database:
         if not self._db_path_exists():
             return
 
-        rows: list[list[str]] = self._get_csv_rows()
+        rows: list[list[str]] = self._get_raw_csv_rows()
         # l=list of values in entry
         log.debug('parsing rows from db')
         for it, row in enumerate(rows):
@@ -122,5 +123,11 @@ class Database:
                              ' columns: "%s"',
                              i, self.__COLUMN_NUM, col_num, row)
                 sys.exit(1)
-            entry: DatabaseEntry = DatabaseEntry(row)
+            # actual value types
+            r: tuple[str, float, float, str, str] = (str(row[0]),
+                                                     float(row[1]),
+                                                     float(row[2]),
+                                                     str(row[3]),
+                                                     str(row[4]))
+            entry: DatabaseEntry = DatabaseEntry(r)
             self.e[entry.fname] = entry
diff --git a/src/pyssg/database_entry.py b/src/pyssg/database_entry.py
index 5c3e659..58e9884 100644
--- a/src/pyssg/database_entry.py
+++ b/src/pyssg/database_entry.py
@@ -1,42 +1,54 @@
+import sys
 from logging import Logger, getLogger
 
 log: Logger = getLogger(__name__)
 
 
 class DatabaseEntry:
-    # not specifying the type of "list" as it could be only str
-    #   or the actual values
-    def __init__(self, entry: list) -> None:
-        self.fname: str = entry[0]
+    # ignoring return type as it makes the line too long, unnecessary, too
+    def __init__(self, entry: tuple[str, float, float, str, str | set[str]]):
+        self.fname: str = str(entry[0])
         self.ctimestamp: float = float(entry[1])
         self.mtimestamp: float = float(entry[2])
-        self.checksum: str = entry[3]
-        self.tags: list[str] = []
+        self.checksum: str = str(entry[3])
+        self.tags: set[str] = set()
 
-        if isinstance(entry[4], list):
+        if isinstance(entry[4], set):
             self.tags = entry[4]
-        else:
+            self.__remove_invalid()
+        elif isinstance(entry[4], str):
             if entry[4] != '-':
-                self.tags = entry[4].split(',')
+                self.tags = set(e.strip() for e in str(entry[4]).split(','))
+                self.__remove_invalid()
+        # this should be unreachable as the type has to be str or set[str],
+        #   but I have just in case to evade bugs
+        else:
+            log.error('tags has to be either a set or string (comma separated)')
+            sys.exit(1)
 
         log.debug('"%s" tags: %s', self.fname, self.tags)
 
     def __str__(self) -> str:
-        _return_str: str = '[{}, {}, {}, {}, {}]'\
+        _return_str: str = "['{}', {}, {}, '{}', {}]"\
             .format(self.fname,
                     self.ctimestamp,
                     self.mtimestamp,
                     self.checksum,
-                    self.tags)
+                    sorted(self.tags))
         return _return_str
 
+    def __remove_invalid(self) -> None:
+        if '-' in self.tags:
+            self.tags.remove('-')
+
     # used for csv writing
     def get_raw_entry(self) -> list[str]:
         return [self.fname,
                 str(self.ctimestamp),
                 str(self.mtimestamp),
                 self.checksum,
-                ','.join(self.tags) if self.tags else '-']
+                ','.join(sorted(self.tags)) if self.tags else '-']
 
-    def update_tags(self, new_tags: list[str]) -> None:
+    def update_tags(self, new_tags: set[str]) -> None:
         self.tags = new_tags
+        self.__remove_invalid()
diff --git a/tests/test_database_entry.py b/tests/test_database_entry.py
new file mode 100644
index 0000000..1320577
--- /dev/null
+++ b/tests/test_database_entry.py
@@ -0,0 +1,81 @@
+import pytest
+from typing import Any
+from logging import ERROR
+from pytest import LogCaptureFixture
+from pyssg.database_entry import DatabaseEntry
+
+
+@pytest.mark.parametrize('entry, exp_str', [
+    (('t', 0.0, 0.0, '1', set()), "['t', 0.0, 0.0, '1', []]"),
+    (('t', 0, 1, '1', set()), "['t', 0.0, 1.0, '1', []]"),
+    (('t', 0.0, 0.0, '1', '-'), "['t', 0.0, 0.0, '1', []]"),
+    (('t', 0.0, 0.0, 1, '-'), "['t', 0.0, 0.0, '1', []]"),
+    (('t', 0.0, 0.0, '1', {'-', 'tag'}), "['t', 0.0, 0.0, '1', ['tag']]"),
+    (('t', 0.0, 0.0, '1', '-,tag'), "['t', 0.0, 0.0, '1', ['tag']]"),
+    (('t', 0.0, 0.0, '1', 'tag,-,-'), "['t', 0.0, 0.0, '1', ['tag']]"),
+    (('t', 0.0, 0.0, '1', 'tag1,tag2'), "['t', 0.0, 0.0, '1', ['tag1', 'tag2']]"),
+    (('t', 0.0, 0.0, '1', {'tag1', 'tag2'}), "['t', 0.0, 0.0, '1', ['tag1', 'tag2']]"),
+    (('t', 0.0, 0.0, '1', ' tag1 , tag2,tag3'), "['t', 0.0, 0.0, '1', ['tag1', 'tag2', 'tag3']]"),
+    (('t', 0.0, 0.0, '1', 'tag3,tag2,tag1'), "['t', 0.0, 0.0, '1', ['tag1', 'tag2', 'tag3']]"),
+    (('t', 0.0, 0.0, '1', 'tag2,tag3,tag1'), "['t', 0.0, 0.0, '1', ['tag1', 'tag2', 'tag3']]")
+])
+def test_db_entry_obj(entry: tuple[str, float, float, str, str | set[str]],
+                      exp_str: str) -> None:
+    db_entry: DatabaseEntry = DatabaseEntry(entry)
+    assert str(db_entry) == exp_str
+
+
+@pytest.mark.parametrize('entry, exp_str', [
+    (('t', 0.0, 0.0, '1', set()), ['t', '0.0', '0.0', '1', '-']),
+    (('t', 0, 1, '1', set()), ['t', '0.0', '1.0', '1', '-']),
+    (('t', 0.0, 0.0, '1', '-'), ['t', '0.0', '0.0', '1', '-']),
+    (('t', 0.0, 0.0, 1, '-'), ['t', '0.0', '0.0', '1', '-']),
+    (('t', 0.0, 0.0, '1', '-,tag'), ['t', '0.0', '0.0', '1', 'tag']),
+    (('t', 0.0, 0.0, '1', {'-', 'tag'}), ['t', '0.0', '0.0', '1', 'tag']),
+    (('t', 0.0, 0.0, '1', 'tag,-,-'), ['t', '0.0', '0.0', '1', 'tag']),
+    (('t', 0.0, 0.0, '1', 'tag1,tag2'), ['t', '0.0', '0.0', '1', 'tag1,tag2']),
+    (('t', 0.0, 0.0, '1', {'tag1', 'tag2'}), ['t', '0.0', '0.0', '1', 'tag1,tag2']),
+    (('t', 0.0, 0.0, '1', ' tag1 , tag2,tag3'), ['t', '0.0', '0.0', '1', 'tag1,tag2,tag3']),
+    (('t', 0.0, 0.0, '1', 'tag3,tag2,tag1'), ['t', '0.0', '0.0', '1', 'tag1,tag2,tag3']),
+    (('t', 0.0, 0.0, '1', 'tag2,tag3,tag1'), ['t', '0.0', '0.0', '1', 'tag1,tag2,tag3'])
+])
+def test_db_entry_get_raw(entry: tuple[str, float, float, str, str | set[str]],
+                          exp_str: list[str]) -> None:
+    db_entry: DatabaseEntry = DatabaseEntry(entry)
+    db_entry_raw: list[str] = db_entry.get_raw_entry()
+    assert db_entry_raw == exp_str
+
+
+# not sure if this is enough to test tag updating,
+#   it's a bit redundant as the set functionality is doing all the work
+@pytest.mark.parametrize('new_tags', [
+    ({'tag'}),
+    ({'tag1', 'tag2'}),
+    ({'tag1', 'tag2', 'tag3'}),
+    ({'-'}),
+    ({'-', '-'}),
+    (set()),
+    ({'-', '-'}),
+])
+def test_db_entry_update_tags(new_tags: set[str]) -> None:
+    db_entry: DatabaseEntry = DatabaseEntry(('t', 0.0, 0.0, '1', {'just', 'something'}))
+    db_entry.update_tags(new_tags)
+    assert db_entry.tags == new_tags
+
+
+# just a few random tests for things that are not str or set
+@pytest.mark.parametrize('tags', [
+    ({}),
+    (tuple()),
+    (1),
+    (1.0),
+])
+def test_db_entry_bad_tags(tags: Any, caplog: LogCaptureFixture) -> None:
+    err: tuple[str, int, str] = ('pyssg.database_entry',
+                                 ERROR,
+                                 'tags has to be either a set or string (comma separated)')
+    with pytest.raises(SystemExit) as system_exit:
+        DatabaseEntry(('t', 0.0, 0.0, '1', tags))
+    assert system_exit.type == SystemExit
+    assert system_exit.value.code == 1
+    assert caplog.record_tuples[-1] == err
diff --git a/tests/test_utils.py b/tests/test_utils.py
index b7c9754..86242c2 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -151,7 +151,7 @@ def test_dir_structure(tmp_dir_structure: Path,
     (('md',), ['first'], ['f0.md', 'second/f4.md', 'second/s1/f5.md']),
     (('md',), ['first', 's1'], ['f0.md', 'second/f4.md']),
     (('md',), ['f2', 's1'], ['f0.md', 'second/f4.md',
-                             'first/f1.md', 'first/f1/f2.md',]),
+                             'first/f1.md', 'first/f1/f2.md',])
 ])
 def test_file_list(tmp_dir_structure: Path,
                    exts: tuple[str],
-- 
cgit v1.2.3-70-g09d2