summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
author David Luevano Alvarado <david@luevano.xyz> 2023-02-24 21:37:37 -0600
committer David Luevano Alvarado <david@luevano.xyz> 2023-02-24 21:37:37 -0600
commit eae1a3a5f602e9c79e07f4b6c1b133dfd12e1d5c (patch)
tree 5ac9e4e20524b83b2020f985ca415ddbb341f097 /src
parent a609b1cb2b43fd17e03efa62314f679b47ae6cb5 (diff)
add database_entry tests, change type for tags attr
Diffstat (limited to 'src')
-rw-r--r-- src/pyssg/database.py 31
-rw-r--r-- src/pyssg/database_entry.py 38
2 files changed, 44 insertions, 25 deletions
diff --git a/src/pyssg/database.py b/src/pyssg/database.py
index 5d7d71d..ae0a8d4 100644
--- a/src/pyssg/database.py
+++ b/src/pyssg/database.py
@@ -20,8 +20,9 @@ class Database:
self.e: dict[str, DatabaseEntry] = dict()
def update_tags(self, file_name: str,
- new_tags: list[str]) -> None:
- # technically, I should ensure this function can only run if self.e is populated
+ new_tags: set[str]) -> None:
+ # technically, I should ensure this function can only run
+ # if self.e is populated
if file_name in self.e:
log.debug('updating tags for entry "%s"', file_name)
log.debug('entry "%s" old tags: %s',
@@ -39,7 +40,7 @@ class Database:
remove: str = '') -> None:
log.debug('updating entry for file "%s"', file_name)
f: str = file_name
- tags: list[str] = []
+ tags: set[str] = set()
if remove != '':
f = file_name.replace(remove, '')
log.debug('removed "%s" from "%s": "%s"', remove, file_name, f)
@@ -49,8 +50,8 @@ class Database:
log.debug('time for "%s": %s', file_name, time)
# calculate current checksum, also needs actual file name
- checksum: str = get_checksum(file_name)
- log.debug('checksum for "%s": "%s"', file_name, checksum)
+ cksm: str = get_checksum(file_name)
+ log.debug('checksum for "%s": "%s"', file_name, cksm)
# three cases, 1) entry didn't exist,
# 2) entry has been mod and,
@@ -58,7 +59,7 @@ class Database:
# 1)
if f not in self.e:
log.debug('entry "%s" didn\'t exist, adding with defaults', f)
- self.e[f] = DatabaseEntry([f, time, 0.0, checksum, tags])
+ self.e[f] = DatabaseEntry((f, time, 0.0, cksm, tags))
return
# old_e is old entity
@@ -66,10 +67,10 @@ class Database:
log.debug('entry "%s" old content: %s', f, old_e)
# 2)
- if checksum != old_e.checksum:
+ if cksm != old_e.checksum:
log.debug('entry "%s" has been modified, updating', f)
- self.e[f] = DatabaseEntry([f, old_e.ctimestamp, time, checksum, tags])
- log.debug('entry "%s" new content: (%s, %s, %s, (%s))', f, self.e[f])
+ self.e[f] = DatabaseEntry((f, old_e.ctimestamp, time, cksm, tags))
+ log.debug('entry "%s" new content: %s', f, self.e[f])
return
# 3)
else:
@@ -96,7 +97,7 @@ class Database:
sys.exit(1)
return True
- def _get_csv_rows(self) -> list[list[str]]:
+ def _get_raw_csv_rows(self) -> list[list[str]]:
rows: list[list[str]]
with open(self.db_path, 'r') as f:
csv_reader = csv.reader(f, delimiter=self.__COLUMN_DELIMITER)
@@ -110,7 +111,7 @@ class Database:
if not self._db_path_exists():
return
- rows: list[list[str]] = self._get_csv_rows()
+ rows: list[list[str]] = self._get_raw_csv_rows()
# l=list of values in entry
log.debug('parsing rows from db')
for it, row in enumerate(rows):
@@ -122,5 +123,11 @@ class Database:
' columns: "%s"',
i, self.__COLUMN_NUM, col_num, row)
sys.exit(1)
- entry: DatabaseEntry = DatabaseEntry(row)
+ # actual value types
+ r: tuple[str, float, float, str, str] = (str(row[0]),
+ float(row[1]),
+ float(row[2]),
+ str(row[3]),
+ str(row[4]))
+ entry: DatabaseEntry = DatabaseEntry(r)
self.e[entry.fname] = entry
diff --git a/src/pyssg/database_entry.py b/src/pyssg/database_entry.py
index 5c3e659..58e9884 100644
--- a/src/pyssg/database_entry.py
+++ b/src/pyssg/database_entry.py
@@ -1,42 +1,54 @@
+import sys
from logging import Logger, getLogger
log: Logger = getLogger(__name__)
class DatabaseEntry:
- # not specifying the type of "list" as it could be only str
- # or the actual values
- def __init__(self, entry: list) -> None:
- self.fname: str = entry[0]
+ # return type omitted: it makes the line too long and is unnecessary anyway
+ def __init__(self, entry: tuple[str, float, float, str, str | set[str]]):
+ self.fname: str = str(entry[0])
self.ctimestamp: float = float(entry[1])
self.mtimestamp: float = float(entry[2])
- self.checksum: str = entry[3]
- self.tags: list[str] = []
+ self.checksum: str = str(entry[3])
+ self.tags: set[str] = set()
- if isinstance(entry[4], list):
+ if isinstance(entry[4], set):
self.tags = entry[4]
- else:
+ self.__remove_invalid()
+ elif isinstance(entry[4], str):
if entry[4] != '-':
- self.tags = entry[4].split(',')
+ self.tags = set(e.strip() for e in str(entry[4]).split(','))
+ self.__remove_invalid()
+ # this should be unreachable, as the type has to be str or set[str],
+ # but it is kept just in case, to guard against bugs
+ else:
+ log.error('tags has to be either a set or string (comma separated)')
+ sys.exit(1)
log.debug('"%s" tags: %s', self.fname, self.tags)
def __str__(self) -> str:
- _return_str: str = '[{}, {}, {}, {}, {}]'\
+ _return_str: str = "['{}', {}, {}, '{}', {}]"\
.format(self.fname,
self.ctimestamp,
self.mtimestamp,
self.checksum,
- self.tags)
+ sorted(self.tags))
return _return_str
+ def __remove_invalid(self) -> None:
+ if '-' in self.tags:
+ self.tags.remove('-')
+
# used for csv writing
def get_raw_entry(self) -> list[str]:
return [self.fname,
str(self.ctimestamp),
str(self.mtimestamp),
self.checksum,
- ','.join(self.tags) if self.tags else '-']
+ ','.join(sorted(self.tags)) if self.tags else '-']
- def update_tags(self, new_tags: list[str]) -> None:
+ def update_tags(self, new_tags: set[str]) -> None:
self.tags = new_tags
+ self.__remove_invalid()