From eae1a3a5f602e9c79e07f4b6c1b133dfd12e1d5c Mon Sep 17 00:00:00 2001 From: David Luevano Alvarado Date: Fri, 24 Feb 2023 21:37:37 -0600 Subject: add database_entry tests, change type for tags attr --- setup.cfg | 1 + src/pyssg/database.py | 31 ++++++++++------- src/pyssg/database_entry.py | 38 ++++++++++++++------- tests/test_database_entry.py | 81 ++++++++++++++++++++++++++++++++++++++++++++ tests/test_utils.py | 2 +- 5 files changed, 127 insertions(+), 26 deletions(-) create mode 100644 tests/test_database_entry.py diff --git a/setup.cfg b/setup.cfg index d61839e..1fb6efd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -62,6 +62,7 @@ per-file-ignores = __init__.py: W292 arg_parser.py: E501 custom_logger.py: E501 + test_database_entry.py: E501 [pbr] skip_authors = True diff --git a/src/pyssg/database.py b/src/pyssg/database.py index 5d7d71d..ae0a8d4 100644 --- a/src/pyssg/database.py +++ b/src/pyssg/database.py @@ -20,8 +20,9 @@ class Database: self.e: dict[str, DatabaseEntry] = dict() def update_tags(self, file_name: str, - new_tags: list[str]) -> None: - # technically, I should ensure this function can only run if self.e is populated + new_tags: set[str]) -> None: + # technically, I should ensure this function can only run + # if self.e is populated if file_name in self.e: log.debug('updating tags for entry "%s"', file_name) log.debug('entry "%s" old tags: %s', @@ -39,7 +40,7 @@ class Database: remove: str = '') -> None: log.debug('updating entry for file "%s"', file_name) f: str = file_name - tags: list[str] = [] + tags: set[str] = set() if remove != '': f = file_name.replace(remove, '') log.debug('removed "%s" from "%s": "%s"', remove, file_name, f) @@ -49,8 +50,8 @@ class Database: log.debug('time for "%s": %s', file_name, time) # calculate current checksum, also needs actual file name - checksum: str = get_checksum(file_name) - log.debug('checksum for "%s": "%s"', file_name, checksum) + cksm: str = get_checksum(file_name) + log.debug('checksum for "%s": "%s"', file_name, cksm) # three cases, 1) entry didn't exist, # 2) entry has been mod and, @@ -58,7 +59,7 @@ class Database: # 1) if f not in self.e: log.debug('entry "%s" didn\'t exist, adding with defaults', f) - self.e[f] = DatabaseEntry([f, time, 0.0, checksum, tags]) + self.e[f] = DatabaseEntry((f, time, 0.0, cksm, tags)) return # old_e is old entity @@ -66,10 +67,10 @@ class Database: log.debug('entry "%s" old content: %s', f, old_e) # 2) - if checksum != old_e.checksum: + if cksm != old_e.checksum: log.debug('entry "%s" has been modified, updating', f) - self.e[f] = DatabaseEntry([f, old_e.ctimestamp, time, checksum, tags]) - log.debug('entry "%s" new content: (%s, %s, %s, (%s))', f, self.e[f]) + self.e[f] = DatabaseEntry((f, old_e.ctimestamp, time, cksm, tags)) + log.debug('entry "%s" new content: %s', f, self.e[f]) return # 3) else: @@ -96,7 +97,7 @@ class Database: sys.exit(1) return True - def _get_csv_rows(self) -> list[list[str]]: + def _get_raw_csv_rows(self) -> list[list[str]]: rows: list[list[str]] with open(self.db_path, 'r') as f: csv_reader = csv.reader(f, delimiter=self.__COLUMN_DELIMITER) @@ -110,7 +111,7 @@ class Database: if not self._db_path_exists(): return - rows: list[list[str]] = self._get_csv_rows() + rows: list[list[str]] = self._get_raw_csv_rows() # l=list of values in entry log.debug('parsing rows from db') for it, row in enumerate(rows): @@ -122,5 +123,11 @@ class Database: ' columns: "%s"', i, self.__COLUMN_NUM, col_num, row) sys.exit(1) - entry: DatabaseEntry = DatabaseEntry(row) + # actual value types + r: tuple[str, float, float, str, str] = (str(row[0]), + float(row[1]), + float(row[2]), + str(row[3]), + str(row[4])) + entry: DatabaseEntry = DatabaseEntry(r) self.e[entry.fname] = entry diff --git a/src/pyssg/database_entry.py b/src/pyssg/database_entry.py index 5c3e659..58e9884 100644 --- a/src/pyssg/database_entry.py +++ b/src/pyssg/database_entry.py @@ -1,42 +1,54 @@ +import sys from logging import Logger, getLogger log: Logger = getLogger(__name__) class DatabaseEntry: - # not specifying the type of "list" as it could be only str - # or the actual values - def __init__(self, entry: list) -> None: - self.fname: str = entry[0] + # ignoring return type as it makes the line too long, unnecessary, too + def __init__(self, entry: tuple[str, float, float, str, str | set[str]]): + self.fname: str = str(entry[0]) self.ctimestamp: float = float(entry[1]) self.mtimestamp: float = float(entry[2]) - self.checksum: str = entry[3] - self.tags: list[str] = [] + self.checksum: str = str(entry[3]) + self.tags: set[str] = set() - if isinstance(entry[4], list): + if isinstance(entry[4], set): self.tags = entry[4] - else: + self.__remove_invalid() + elif isinstance(entry[4], str): if entry[4] != '-': - self.tags = entry[4].split(',') + self.tags = set(e.strip() for e in str(entry[4]).split(',')) + self.__remove_invalid() + # this should be unreachable as the type has to be str or set[str], + # but I have just in case to evade bugs + else: + log.error('tags has to be either a set or string (comma separated)') + sys.exit(1) log.debug('"%s" tags: %s', self.fname, self.tags) def __str__(self) -> str: - _return_str: str = '[{}, {}, {}, {}, {}]'\ + _return_str: str = "['{}', {}, {}, '{}', {}]"\ .format(self.fname, self.ctimestamp, self.mtimestamp, self.checksum, - self.tags) + sorted(self.tags)) return _return_str + def __remove_invalid(self) -> None: + if '-' in self.tags: + self.tags.remove('-') + # used for csv writing def get_raw_entry(self) -> list[str]: return [self.fname, str(self.ctimestamp), str(self.mtimestamp), self.checksum, - ','.join(self.tags) if self.tags else '-'] + ','.join(sorted(self.tags)) if self.tags else '-'] - def update_tags(self, new_tags: list[str]) -> None: + def update_tags(self, new_tags: set[str]) -> None: self.tags = new_tags + self.__remove_invalid() diff --git a/tests/test_database_entry.py b/tests/test_database_entry.py new file mode 100644 index 0000000..1320577 --- /dev/null +++ b/tests/test_database_entry.py @@ -0,0 +1,81 @@ +import pytest +from typing import Any +from logging import ERROR +from pytest import LogCaptureFixture +from pyssg.database_entry import DatabaseEntry + + +@pytest.mark.parametrize('entry, exp_str', [ + (('t', 0.0, 0.0, '1', set()), "['t', 0.0, 0.0, '1', []]"), + (('t', 0, 1, '1', set()), "['t', 0.0, 1.0, '1', []]"), + (('t', 0.0, 0.0, '1', '-'), "['t', 0.0, 0.0, '1', []]"), + (('t', 0.0, 0.0, 1, '-'), "['t', 0.0, 0.0, '1', []]"), + (('t', 0.0, 0.0, '1', {'-', 'tag'}), "['t', 0.0, 0.0, '1', ['tag']]"), + (('t', 0.0, 0.0, '1', '-,tag'), "['t', 0.0, 0.0, '1', ['tag']]"), + (('t', 0.0, 0.0, '1', 'tag,-,-'), "['t', 0.0, 0.0, '1', ['tag']]"), + (('t', 0.0, 0.0, '1', 'tag1,tag2'), "['t', 0.0, 0.0, '1', ['tag1', 'tag2']]"), + (('t', 0.0, 0.0, '1', {'tag1', 'tag2'}), "['t', 0.0, 0.0, '1', ['tag1', 'tag2']]"), + (('t', 0.0, 0.0, '1', ' tag1 , tag2,tag3'), "['t', 0.0, 0.0, '1', ['tag1', 'tag2', 'tag3']]"), + (('t', 0.0, 0.0, '1', 'tag3,tag2,tag1'), "['t', 0.0, 0.0, '1', ['tag1', 'tag2', 'tag3']]"), + (('t', 0.0, 0.0, '1', 'tag2,tag3,tag1'), "['t', 0.0, 0.0, '1', ['tag1', 'tag2', 'tag3']]") +]) +def test_db_entry_obj(entry: tuple[str, float, float, str, str | set[str]], + exp_str: str) -> None: + db_entry: DatabaseEntry = DatabaseEntry(entry) + assert str(db_entry) == exp_str + + +@pytest.mark.parametrize('entry, exp_str', [ + (('t', 0.0, 0.0, '1', set()), ['t', '0.0', '0.0', '1', '-']), + (('t', 0, 1, '1', set()), ['t', '0.0', '1.0', '1', '-']), + (('t', 0.0, 0.0, '1', '-'), ['t', '0.0', '0.0', '1', '-']), + (('t', 0.0, 0.0, 1, '-'), ['t', '0.0', '0.0', '1', '-']), + (('t', 0.0, 0.0, '1', '-,tag'), ['t', '0.0', '0.0', '1', 'tag']), + (('t', 0.0, 0.0, '1', {'-', 'tag'}), ['t', '0.0', '0.0', '1', 'tag']), + (('t', 0.0, 0.0, '1', 'tag,-,-'), ['t', '0.0', '0.0', '1', 'tag']), + (('t', 0.0, 0.0, '1', 'tag1,tag2'), ['t', '0.0', '0.0', '1', 'tag1,tag2']), + (('t', 0.0, 0.0, '1', {'tag1', 'tag2'}), ['t', '0.0', '0.0', '1', 'tag1,tag2']), + (('t', 0.0, 0.0, '1', ' tag1 , tag2,tag3'), ['t', '0.0', '0.0', '1', 'tag1,tag2,tag3']), + (('t', 0.0, 0.0, '1', 'tag3,tag2,tag1'), ['t', '0.0', '0.0', '1', 'tag1,tag2,tag3']), + (('t', 0.0, 0.0, '1', 'tag2,tag3,tag1'), ['t', '0.0', '0.0', '1', 'tag1,tag2,tag3']) +]) +def test_db_entry_get_raw(entry: tuple[str, float, float, str, str | set[str]], + exp_str: list[str]) -> None: + db_entry: DatabaseEntry = DatabaseEntry(entry) + db_entry_raw: list[str] = db_entry.get_raw_entry() + assert db_entry_raw == exp_str + + +# not sure if this is enough to test tag updating, +# it's a bit redundant as the set functionality is doing all the work +@pytest.mark.parametrize('new_tags', [ + ({'tag'}), + ({'tag1', 'tag2'}), + ({'tag1', 'tag2', 'tag3'}), + ({'-'}), + ({'-', '-'}), + (set()), + ({'-', '-'}), +]) +def test_db_entry_update_tags(new_tags: set[str]) -> None: + db_entry: DatabaseEntry = DatabaseEntry(('t', 0.0, 0.0, '1', {'just', 'something'})) + db_entry.update_tags(new_tags) + assert db_entry.tags == new_tags + + +# just a few random tests for things that are not str or set +@pytest.mark.parametrize('tags', [ + ({}), + (tuple()), + (1), + (1.0), +]) +def test_db_entry_bad_tags(tags: Any, caplog: LogCaptureFixture) -> None: + err: tuple[str, int, str] = ('pyssg.database_entry', + ERROR, + 'tags has to be either a set or string (comma separated)') + with pytest.raises(SystemExit) as system_exit: + DatabaseEntry(('t', 0.0, 0.0, '1', tags)) + assert system_exit.type == SystemExit + assert system_exit.value.code == 1 + assert caplog.record_tuples[-1] == err diff --git a/tests/test_utils.py b/tests/test_utils.py index b7c9754..86242c2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -151,7 +151,7 @@ def test_dir_structure(tmp_dir_structure: Path, (('md',), ['first'], ['f0.md', 'second/f4.md', 'second/s1/f5.md']), (('md',), ['first', 's1'], ['f0.md', 'second/f4.md']), (('md',), ['f2', 's1'], ['f0.md', 'second/f4.md', - 'first/f1.md', 'first/f1/f2.md',]), + 'first/f1.md', 'first/f1/f2.md',]) ]) def test_file_list(tmp_dir_structure: Path, exts: tuple[str], -- cgit v1.2.3-54-g00ecf