From e28fbb181851ca16bc0ade9c371628f86c25adbc Mon Sep 17 00:00:00 2001
From: David Luevano Alvarado <david@luevano.xyz>
Date: Sun, 19 Feb 2023 22:59:32 -0600
Subject: refactor custom logger and add tests

---
 src/pyssg/__init__.py            | 16 +++-----------
 src/pyssg/custom_logger.py       | 45 ++++++++++++++++++++++++++++++++++++++++
 src/pyssg/per_level_formatter.py | 29 --------------------------
 tests/conftest.py                | 23 ++++++++++++++++++++
 tests/io_files/__init__.py       |  0
 tests/test_arg_parser.py         |  1 -
 tests/test_custom_logger.py      | 20 ++++++++++++++++++
 7 files changed, 91 insertions(+), 43 deletions(-)
 create mode 100644 src/pyssg/custom_logger.py
 delete mode 100644 src/pyssg/per_level_formatter.py
 create mode 100644 tests/io_files/__init__.py
 create mode 100644 tests/test_custom_logger.py

diff --git a/src/pyssg/__init__.py b/src/pyssg/__init__.py
index a4e5857..5d112f1 100644
--- a/src/pyssg/__init__.py
+++ b/src/pyssg/__init__.py
@@ -1,17 +1,7 @@
-from logging import Logger, StreamHandler, getLogger, INFO
-
 from .pyssg import main
-from .per_level_formatter import PerLevelFormatter
-
+from .custom_logger import setup_logger
 
-# since this is the root package, setup the logger here
-__LOG_LEVEL: int = INFO
-log: Logger = getLogger(__name__)
-log.setLevel(__LOG_LEVEL)
-ch: StreamHandler = StreamHandler()
-ch.setLevel(__LOG_LEVEL)
-ch.setFormatter(PerLevelFormatter())
-log.addHandler(ch)
 
+setup_logger()
 # not meant to be used as a package, so just give main
-__all__ = ['main']
\ No newline at end of file
+__all__ = ['main']
diff --git a/src/pyssg/custom_logger.py b/src/pyssg/custom_logger.py
new file mode 100644
index 0000000..55f3e2d
--- /dev/null
+++ b/src/pyssg/custom_logger.py
@@ -0,0 +1,45 @@
+import sys
+from logging import (Logger, StreamHandler, Formatter, LogRecord,
+                     DEBUG, INFO, WARNING, ERROR, CRITICAL,
+                     getLogger)
+
+LOG_LEVEL: int = INFO
+# 'pyssg' es the name of the root logger
+LOGGER_NAME: str = 'pyssg'
+
+
+# only reason for this class is to get info formatting as normal text
+#   and everything else with more info and with colors
+class PerLevelFormatter(Formatter):
+    # colors for the terminal in ansi
+    __YELLOW: str = '\x1b[33m'
+    __RED: str = '\x1b[31m'
+    __BOLD_RED: str = '\x1b[31;1m'
+    __RESET: str = '\x1b[0m'
+
+    __DATE_FMT: str = '%Y-%m-%d %H:%M:%S'
+    __COMMON_FMT: str = '[%(levelname)s] [%(module)s:%(funcName)s:%(lineno)d]: %(message)s'
+    __FORMATS: dict[int, str] = {
+        DEBUG: __COMMON_FMT,
+        INFO: '%(message)s',
+        WARNING: f'{__YELLOW}{__COMMON_FMT}{__RESET}',
+        ERROR: f'{__RED}{__COMMON_FMT}{__RESET}',
+        CRITICAL: f'{__BOLD_RED}{__COMMON_FMT}{__RESET}'
+    }
+
+    def format(self, record: LogRecord) -> str:
+        # this should never fail, as __FORMATS is defined above,
+        #   so no issue of just converting to str
+        fmt: str = str(self.__FORMATS.get(record.levelno))
+        formatter: Formatter = Formatter(
+            fmt=fmt, datefmt=self.__DATE_FMT, style='%')
+        return formatter.format(record)
+
+
+def setup_logger(name: str = LOGGER_NAME, level: int = LOG_LEVEL) -> None:
+    logger: Logger = getLogger(name)
+    handler: StreamHandler = StreamHandler(sys.stdout)
+    logger.setLevel(level)
+    handler.setLevel(level)
+    handler.setFormatter(PerLevelFormatter())
+    logger.addHandler(handler)
diff --git a/src/pyssg/per_level_formatter.py b/src/pyssg/per_level_formatter.py
deleted file mode 100644
index e3b6977..0000000
--- a/src/pyssg/per_level_formatter.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from logging import Formatter, LogRecord, DEBUG, INFO, WARNING, ERROR, CRITICAL
-
-
-# only reason for this class is to get info formatting as normal text
-#   and everything else with more info and with colors
-class PerLevelFormatter(Formatter):
-    # colors for the terminal in ansi
-    __YELLOW: str = '\x1b[33m'
-    __RED: str = '\x1b[31m'
-    __BOLD_RED: str = '\x1b[31;1m'
-    __RESET: str = '\x1b[0m'
-
-    __DATE_FMT: str = '%Y-%m-%d %H:%M:%S'
-    __COMMON_FMT: str = '[%(levelname)s] [%(module)s:%(funcName)s:%(lineno)d]: %(message)s'
-    __FORMATS: dict[int, str] = {
-        DEBUG: __COMMON_FMT,
-        INFO: '%(message)s',
-        WARNING: f'{__YELLOW}{__COMMON_FMT}{__RESET}',
-        ERROR: f'{__RED}{__COMMON_FMT}{__RESET}',
-        CRITICAL: f'{__BOLD_RED}{__COMMON_FMT}{__RESET}'
-    }
-
-    def format(self, record: LogRecord) -> str:
-        # this should never fail, as __FORMATS is defined above,
-        #   so no issue of just converting to str
-        fmt: str = str(self.__FORMATS.get(record.levelno))
-        formatter: Formatter = Formatter(
-            fmt=fmt, datefmt=self.__DATE_FMT, style='%')
-        return formatter.format(record)
diff --git a/tests/conftest.py b/tests/conftest.py
index fcf4189..58416ea 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,7 +1,30 @@
+import sys
+from typing import Callable
 import pytest
+from logging import getLogger, DEBUG
+
 from pyssg.arg_parser import get_parser
+from pyssg.custom_logger import setup_logger
 
 
 @pytest.fixture(scope='session')
 def arg_parser():
     return get_parser()
+
+
+@pytest.fixture(scope='session')
+def logger():
+    setup_logger(__name__, DEBUG)
+    return getLogger(__name__)
+
+
+@pytest.fixture
+def capture_stdout(monkeypatch: Callable) -> dict[str, str | int]:
+    buffer: dict[str, str | int] = {'stdout': '', 'write_calls': 0}
+
+    def fake_writer(s):
+        buffer['stdout'] += s
+        buffer['write_calls'] += 1  # type: ignore
+
+    monkeypatch.setattr(sys.stdout, 'write', fake_writer)
+    return buffer
diff --git a/tests/io_files/__init__.py b/tests/io_files/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_arg_parser.py b/tests/test_arg_parser.py
index 1a5b48c..e233e93 100644
--- a/tests/test_arg_parser.py
+++ b/tests/test_arg_parser.py
@@ -31,7 +31,6 @@ def test_valid_args(args: list[str],
 @pytest.mark.parametrize('args', [
     (['--something-random']),
     (['-z']),
-    (['hello']),
     (['help']),
     (['h'])
 ])
diff --git a/tests/test_custom_logger.py b/tests/test_custom_logger.py
new file mode 100644
index 0000000..1062e41
--- /dev/null
+++ b/tests/test_custom_logger.py
@@ -0,0 +1,20 @@
+import pytest
+from logging import DEBUG, INFO, WARNING, ERROR, CRITICAL
+from typing import Callable
+
+
+@pytest.mark.parametrize('log_level, starts_with, message', [
+    (DEBUG, '[DEBUG]', 'first message'),
+    (INFO, 'second message', 'second message'),
+    (WARNING, '\x1b[33m[WARNING]', 'third message'),
+    (ERROR, '\x1b[31m[ERROR]', 'fourth message'),
+    (CRITICAL, '\x1b[31;1m[CRITICAL]', 'fifth message'),
+])
+def test_log_levels(log_level: int,
+                    starts_with: str,
+                    message: str,
+                    logger: Callable,
+                    capture_stdout: dict[str, str | int]) -> None:
+    logger.log(log_level, message)
+    assert str(capture_stdout['stdout']).startswith(starts_with)
+    assert message in str(capture_stdout['stdout'])
-- 
cgit v1.2.3-70-g09d2