From dfc3e6db921815416b8edc5892b2a7adfc677a25 Mon Sep 17 00:00:00 2001
From: David Luevano Alvarado <david@luevano.xyz>
Date: Sat, 23 Apr 2022 20:47:19 -0600
Subject: add checksum checking for mod files instead of timestamp

---
 src/pyssg/arg_parser.py |   3 +
 src/pyssg/builder.py    |   3 +-
 src/pyssg/database.py   | 165 +++++++++++++++++++++++++++++++++---------------
 src/pyssg/pyssg.py      |  13 +++-
 src/pyssg/utils.py      |  18 +++++-
 5 files changed, 145 insertions(+), 57 deletions(-)

(limited to 'src')

diff --git a/src/pyssg/arg_parser.py b/src/pyssg/arg_parser.py
index ec150fb..2fc6853 100644
--- a/src/pyssg/arg_parser.py
+++ b/src/pyssg/arg_parser.py
@@ -37,6 +37,9 @@ def get_parsed_arguments() -> Namespace:
     parser.add_argument('--debug',
                         action='store_true',
                         help='''change logging level from info to debug''')
+    parser.add_argument('--add-checksum-to-db',
+                        action='store_true',
+                        help='''add checksum column to db entries''')
     # really not needed, too much bloat and case scenarios to check for,
     #   instead, just read from config file or default config file
     """
diff --git a/src/pyssg/builder.py b/src/pyssg/builder.py
index 35502b0..6d65187 100644
--- a/src/pyssg/builder.py
+++ b/src/pyssg/builder.py
@@ -83,7 +83,8 @@ class Builder:
         dir_path: str = None
         for d in self.dirs:
             dir_path = os.path.join(self.config.get('path', 'dst'), d)
-            create_dir(dir_path, True)
+            # using silent=True to not print the info create dir msgs for this
+            create_dir(dir_path, True, True)
 
 
     def __copy_html_files(self) -> None:
diff --git a/src/pyssg/database.py b/src/pyssg/database.py
index 66c7087..290ba51 100644
--- a/src/pyssg/database.py
+++ b/src/pyssg/database.py
@@ -2,7 +2,6 @@ import os
 import sys
 from logging import Logger, getLogger
 from configparser import ConfigParser
-from tabnanny import check
 
 from .utils import get_checksum
 
@@ -11,14 +10,15 @@ log: Logger = getLogger(__name__)
 
 # db class that works for both html and md files
 class Database:
-    __COLUMN_NUM: int = 4
+    __OLD_COLUMN_NUM: int = 4
+    __COLUMN_NUM: int = 5
 
     def __init__(self, db_path: str,
                  config: ConfigParser):
         log.debug('initializing the page db on path "%s"', db_path)
         self.db_path: str = db_path
         self.config: ConfigParser = config
-        self.e: dict[str, tuple[float, float, list[str]]] = dict()
+        self.e: dict[str, tuple[float, float, str, list[str]]] = dict()
 
 
     # updates the tags for a specific entry (file)
@@ -27,12 +27,12 @@ class Database:
                     tags: list[str]) -> None:
         if file_name in self.e:
             log.debug('updating tags for entry "%s"', file_name)
-            cts, mts, old_tags = self.e[file_name]
-            log.debug('entry "%s" old content: (%s, %s, (%s))',
-                      file_name, cts, mts, ', '.join(old_tags))
-            self.e[file_name] = (cts, mts, tags)
-            log.debug('entry "%s" new content: (%s, %s, (%s))',
-                      file_name, cts, mts, ', '.join(tags))
+            cts, mts, checksum, old_tags = self.e[file_name]
+            log.debug('entry "%s" old content: (%s, %s, %s, (%s))',
+                      file_name, cts, mts, checksum, ', '.join(old_tags))
+            self.e[file_name] = (cts, mts, checksum, tags)
+            log.debug('entry "%s" new content: (%s, %s, %s, (%s))',
+                      file_name, cts, mts, checksum, ', '.join(tags))
         else:
             log.error('can\'t update tags for entry "%s",'
                       ' as it is not present in db', file_name)
@@ -51,44 +51,42 @@ class Database:
             f = file_name.replace(remove, '')
             log.debug('removed "%s" from "%s": "%s"', remove, file_name, f)
 
-
         # get current time, needs actual file name
         time: float = os.stat(file_name).st_mtime
         log.debug('modified time for "%s": %s', file_name, time)
 
-        # three cases, 1) entry didn't exist,
-        # 2) entry hasn't been mod and,
-        # 3) entry has been mod
+        # calculate current checksum, also needs actual file name
+        checksum: str = get_checksum(file_name)
+        log.debug('current checksum for "%s": "%s"', file_name, checksum)
+
+        # three cases, 1) entry didn't exist,
+        # 2) entry has been mod and,
+        # 3) entry hasn't been mod
         #1)
         if f not in self.e:
             log.debug('entry "%s" didn\'t exist, adding with defaults', f)
-            self.e[f] = (time, 0.0, tags)
+            self.e[f] = (time, 0.0, checksum, tags)
             return True
 
-        old_time, old_mod_time, tags = self.e[f]
-        log.debug('entry "%s" old content: (%s, %s, (%s))',
-                  f, old_time, old_mod_time, ', '.join(tags))
+        old_time, old_mod_time, old_checksum, tags = self.e[f]
+        log.debug('entry "%s" old content: (%s, %s, %s, (%s))',
+                  f, old_time, old_mod_time, old_checksum, ', '.join(tags))
 
         # 2)
-        if old_mod_time == 0.0:
-            if time > old_time:
+        if checksum != old_checksum:
+            if old_mod_time == 0.0:
                 log.debug('entry "%s" has been modified for the first'
                           ' time, updating', f)
-                self.e[f] = (old_time, time, tags)
-                log.debug('entry "%s" new content: (%s, %s, (%s))',
-                          f, old_time, time, ', '.join(tags))
-                return True
+            else:
+                log.debug('entry "%s" has been modified, updating', f)
+            self.e[f] = (old_time, time, checksum, tags)
+            log.debug('entry "%s" new content: (%s, %s, %s, (%s))',
+                      f, old_time, time, checksum, ', '.join(tags))
+            return True
         # 3)
         else:
-            if time > old_mod_time:
-                log.debug('entry "%s" has been modified, updating', f)
-                self.e[f] = (old_time, time, tags)
-                log.debug('entry "%s" new content: (%s, %s, (%s))',
-                          f, old_time, time, ', '.join(tags))
-                return True
-
-        log.debug('entry "%s" hasn\'t been modified', f)
-        return False
+            log.debug('entry "%s" hasn\'t been modified', f)
+            return False
 
 
     def write(self) -> None:
@@ -98,54 +96,117 @@ class Database:
                 log.debug('parsing row for page "%s"', k)
                 t: str = None
                 row: str = None
-                if len(v[2]) == 0:
+                if len(v[3]) == 0:
                     t = '-'
                 else:
-                    t = ','.join(v[2])
+                    t = ','.join(v[3])
 
-                row = f'{k} {v[0]} {v[1]} {t}'
+                row = f'{k} {v[0]} {v[1]} {v[2]} {t}'
 
                 log.debug('writing row: "%s\\n"', row)
                 file.write(f'{row}\n')
 
 
-    def read(self) -> None:
-        log.debug('reading db')
+    def _db_path_exists(self) -> bool:
+        log.debug('checking that "%s" exists or is a file', self.db_path)
         if not os.path.exists(self.db_path):
             log.warning('"%s" doesn\'t exist, will be'
                         ' created once process finishes,'
                         ' ignore if it\'s the first run', self.db_path)
-            return
+            return False
 
-        if os.path.exists(self.db_path) and not os.path.isfile(self.db_path):
+        if not os.path.isfile(self.db_path):
             log.error('"%s" is not a file"', self.db_path)
             sys.exit(1)
 
+        return True
+
+
+    def _read_raw(self) -> list[str]:
         rows: list[str] = None
         with open(self.db_path, 'r') as file:
             rows = file.readlines()
-        log.info('db contains %d rows', len(rows))
+        log.debug('db contains %d rows', len(rows))
+
+        return rows
+
+
+    def read_old(self) -> None:
+        log.debug('reading db with old schema (%d columns)', self.__OLD_COLUMN_NUM)
+        if not self._db_path_exists():
+            log.error('db path "%s" desn\'t exist, --add-checksum-to-db should'
+                      'only be used when updating the old db schema', self.db_path)
+            sys.exit(1)
+
+        rows: list[str] = self._read_raw()
+        cols: list[str] = None
+        # cols=list of values in entry
+        log.debug('parsing rows from db')
+        for it, row in enumerate(rows):
+            i: int = it + 1
+            r: str = row.strip()
+            log.debug('row %d content: "%s"', i, r)
+            # (file_name, ctimestamp, mtimestamp, [tags])
+            cols: tuple[str, float, float, list[str]] = tuple(r.split())
+            col_num: int = len(cols)
+            if col_num != self.__OLD_COLUMN_NUM:
+                log.critical('row %d doesn\'t contain %s columns, contains %d'
+                             ' columns: "%s"',
+                             i, self.__OLD_COLUMN_NUM, col_num, r)
+                sys.exit(1)
+
+            t: list[str] = None
+            if cols[3] == '-':
+                t = []
+            else:
+                t = cols[3].split(',')
+            log.debug('tag content: (%s)', ', '.join(t))
+            file_path: str = os.path.join(self.config.get('path', 'src'), cols[0])
+            checksum: str = get_checksum(file_path)
+            log.debug('checksum for "%s": "%s"', file_path, checksum)
 
-        # parse each entry and populate accordingly
-        l: list[str] = None
+            self.e[cols[0]] = (float(cols[1]), float(cols[2]), checksum, t)
+
+
+
+    def read(self) -> None:
+        log.debug('reading db')
+        if not self._db_path_exists():
+            return
+
+        rows: list[str] = self._read_raw()
+        cols: list[str] = None
         # l=list of values in entry
         log.debug('parsing rows from db')
         for it, row in enumerate(rows):
-            i = it + 1
-            r = row.strip()
+            i: int = it + 1
+            r: str = row.strip()
             log.debug('row %d content: "%s"', i, r)
-            l = tuple(r.split())
-            if len(l) != self.__COLUMN_NUM:
-                log.critical('row %d doesn\'t contain %s columns,'
-                             ' contains %d elements; row %d content: "%s"',
-                             i, self.__COLUMN_NUM, len(l), i, r)
+            # (file_name, ctimestamp, mtimestamp, checksum, [tags])
+            cols: tuple[str, float, float, str, list[str]] = tuple(r.split())
+            col_num: int = len(cols)
+            if col_num == self.__OLD_COLUMN_NUM:
+                log.error('row %d contains %d columns: "%s"; this is probably'
+                          ' because of missing checksum column, which is used'
+                          ' now to also check if a file has changed. Rerun'
+                          ' with flag --add-checksum-to-db to add the checksum'
+                          ' column to the current db; if you did any changes'
+                          ' since last timestamp in db, it won\'t update'
+                          ' modification timestamp',
+                          i, self.__OLD_COLUMN_NUM, r)
+                sys.exit(1)
+
+            if col_num != self.__COLUMN_NUM:
+                log.critical('row %d doesn\'t contain %s columns, contains %d'
+                             ' columns: "%s"',
+                             i, self.__COLUMN_NUM, col_num, r)
                 sys.exit(1)
 
             t: list[str] = None
-            if l[3] == '-':
+            if cols[4] == '-':
                 t = []
             else:
-                t = l[3].split(',')
+                t = cols[4].split(',')
             log.debug('tag content: (%s)', ', '.join(t))
 
-            self.e[l[0]] = (float(l[1]), float(l[2]), t)
+            self.e[cols[0]] = (float(cols[1]), float(cols[2]), cols[3], t)
diff --git a/src/pyssg/pyssg.py b/src/pyssg/pyssg.py
index af7b166..598bf41 100644
--- a/src/pyssg/pyssg.py
+++ b/src/pyssg/pyssg.py
@@ -56,6 +56,7 @@ def main() -> None:
         sys.exit(1)
 
     config: ConfigParser = get_parsed_config(config_path)
+    config.set('info', 'debug', str(args['debug']))
 
     if args['init']:
         log.info('initializing the directory structure and copying over templates')
@@ -74,8 +75,18 @@ def main() -> None:
                 copy_file(p, plt_file)
         sys.exit(0)
 
+    if args['add_checksum_to_db']:
+        log.info('adding checksum column to existing db')
+        db_path: str = os.path.join(config.get('path', 'src'), '.files')
+        db: Database = Database(db_path, config)
+        # needs to be read_old instead of read
+        db.read_old()
+        db.write()
+
+        sys.exit(0)
+
     if args['build']:
-        log.debug('building the html files')
+        log.info('building the html files')
         db_path: str = os.path.join(config.get('path', 'src'), '.files')
         db: Database = Database(db_path, config)
         db.read()
diff --git a/src/pyssg/utils.py b/src/pyssg/utils.py
index ffaf8ba..a41249a 100644
--- a/src/pyssg/utils.py
+++ b/src/pyssg/utils.py
@@ -1,6 +1,7 @@
 import os
 import sys
 import shutil
+from hashlib import md5
 from logging import Logger, getLogger
 
 log: Logger = getLogger(__name__)
@@ -54,15 +55,15 @@ def get_dir_structure(path: str,
     return [o.replace(path, '')[1:] for o in out]
 
 
-def create_dir(path: str, p: bool=False) -> None:
+def create_dir(path: str, p: bool=False, silent=False) -> None:
     try:
         if p:
             os.makedirs(path)
         else:
             os.mkdir(path)
-        log.info('created directory "%s"', path)
+        if not silent: log.info('created directory "%s"', path)
     except FileExistsError:
-        log.info('directory "%s" already exists, ignoring', path)
+        if not silent: log.info('directory "%s" already exists, ignoring', path)
 
 
 def copy_file(src: str, dst: str) -> None:
@@ -78,3 +79,14 @@ def sanity_check_path(path: str) -> None:
         log.error('"$" character found in path "%s";'
                   ' could be due to non-existant env var.', path)
         sys.exit(1)
+
+
+# as seen in SO: https://stackoverflow.com/a/1131238
+def get_checksum(path: str) -> str:
+    log.debug('calculating md5 checksum for "%s"', path)
+    file_hash = md5()
+    with open(path, "rb") as f:
+        while chunk := f.read(4096):
+            file_hash.update(chunk)
+
+    return file_hash.hexdigest()
\ No newline at end of file
-- 
cgit v1.2.3-70-g09d2