summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Luevano Alvarado <david@luevano.xyz>2023-02-19 22:59:32 -0600
committerDavid Luevano Alvarado <david@luevano.xyz>2023-02-19 22:59:32 -0600
commite28fbb181851ca16bc0ade9c371628f86c25adbc (patch)
tree35db0440058600b36655956c61e080a08cb9d7ee
parentb2fbb532e359985142a71354b5b648ae560a80ac (diff)
refactor custom logger and add tests
-rw-r--r--src/pyssg/__init__.py16
-rw-r--r--src/pyssg/custom_logger.py (renamed from src/pyssg/per_level_formatter.py)18
-rw-r--r--tests/conftest.py23
-rw-r--r--tests/io_files/__init__.py0
-rw-r--r--tests/test_arg_parser.py1
-rw-r--r--tests/test_custom_logger.py20
6 files changed, 63 insertions, 15 deletions
diff --git a/src/pyssg/__init__.py b/src/pyssg/__init__.py
index a4e5857..5d112f1 100644
--- a/src/pyssg/__init__.py
+++ b/src/pyssg/__init__.py
@@ -1,17 +1,7 @@
-from logging import Logger, StreamHandler, getLogger, INFO
-
from .pyssg import main
-from .per_level_formatter import PerLevelFormatter
-
+from .custom_logger import setup_logger
-# since this is the root package, setup the logger here
-__LOG_LEVEL: int = INFO
-log: Logger = getLogger(__name__)
-log.setLevel(__LOG_LEVEL)
-ch: StreamHandler = StreamHandler()
-ch.setLevel(__LOG_LEVEL)
-ch.setFormatter(PerLevelFormatter())
-log.addHandler(ch)
+setup_logger()
# not meant to be used as a package, so just give main
-__all__ = ['main'] \ No newline at end of file
+__all__ = ['main']
diff --git a/src/pyssg/per_level_formatter.py b/src/pyssg/custom_logger.py
index e3b6977..55f3e2d 100644
--- a/src/pyssg/per_level_formatter.py
+++ b/src/pyssg/custom_logger.py
@@ -1,4 +1,11 @@
-from logging import Formatter, LogRecord, DEBUG, INFO, WARNING, ERROR, CRITICAL
+import sys
+from logging import (Logger, StreamHandler, Formatter, LogRecord,
+ DEBUG, INFO, WARNING, ERROR, CRITICAL,
+ getLogger)
+
+LOG_LEVEL: int = INFO
+# 'pyssg' es the name of the root logger
+LOGGER_NAME: str = 'pyssg'
# only reason for this class is to get info formatting as normal text
@@ -27,3 +34,12 @@ class PerLevelFormatter(Formatter):
formatter: Formatter = Formatter(
fmt=fmt, datefmt=self.__DATE_FMT, style='%')
return formatter.format(record)
+
+
+def setup_logger(name: str = LOGGER_NAME, level: int = LOG_LEVEL) -> None:
+ logger: Logger = getLogger(name)
+ handler: StreamHandler = StreamHandler(sys.stdout)
+ logger.setLevel(level)
+ handler.setLevel(level)
+ handler.setFormatter(PerLevelFormatter())
+ logger.addHandler(handler)
diff --git a/tests/conftest.py b/tests/conftest.py
index fcf4189..58416ea 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,7 +1,30 @@
+import sys
+from typing import Callable
import pytest
+from logging import getLogger, DEBUG
+
from pyssg.arg_parser import get_parser
+from pyssg.custom_logger import setup_logger
@pytest.fixture(scope='session')
def arg_parser():
return get_parser()
+
+
+@pytest.fixture(scope='session')
+def logger():
+ setup_logger(__name__, DEBUG)
+ return getLogger(__name__)
+
+
+@pytest.fixture
+def capture_stdout(monkeypatch: Callable) -> dict[str, str | int]:
+ buffer: dict[str, str | int] = {'stdout': '', 'write_calls': 0}
+
+ def fake_writer(s):
+ buffer['stdout'] += s
+ buffer['write_calls'] += 1 # type: ignore
+
+ monkeypatch.setattr(sys.stdout, 'write', fake_writer)
+ return buffer
diff --git a/tests/io_files/__init__.py b/tests/io_files/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/io_files/__init__.py
diff --git a/tests/test_arg_parser.py b/tests/test_arg_parser.py
index 1a5b48c..e233e93 100644
--- a/tests/test_arg_parser.py
+++ b/tests/test_arg_parser.py
@@ -31,7 +31,6 @@ def test_valid_args(args: list[str],
@pytest.mark.parametrize('args', [
(['--something-random']),
(['-z']),
- (['hello']),
(['help']),
(['h'])
])
diff --git a/tests/test_custom_logger.py b/tests/test_custom_logger.py
new file mode 100644
index 0000000..1062e41
--- /dev/null
+++ b/tests/test_custom_logger.py
@@ -0,0 +1,20 @@
+import pytest
+from logging import DEBUG, INFO, WARNING, ERROR, CRITICAL
+from typing import Callable
+
+
+@pytest.mark.parametrize('log_level, starts_with, message', [
+ (DEBUG, '[DEBUG]', 'first message'),
+ (INFO, 'second message', 'second message'),
+ (WARNING, '\x1b[33m[WARNING]', 'third message'),
+ (ERROR, '\x1b[31m[ERROR]', 'fourth message'),
+ (CRITICAL, '\x1b[31;1m[CRITICAL]', 'fifth message'),
+])
+def test_log_levels(log_level: int,
+ starts_with: str,
+ message: str,
+ logger: Callable,
+ capture_stdout: dict[str, str | int]) -> None:
+ logger.log(log_level, message)
+ assert str(capture_stdout['stdout']).startswith(starts_with)
+ assert message in str(capture_stdout['stdout'])