From 28c2ae9102d4204b3f0a79419eec1e72dbbc529a Mon Sep 17 00:00:00 2001
From: David Luevano Alvarado <david@luevano.xyz>
Date: Tue, 21 Feb 2023 21:02:23 -0600
Subject: add configuration testing, small refactor

---
 src/pyssg/configuration.py | 39 +++++++++++++++++++++++++++------------
 src/pyssg/custom_logger.py |  3 ++-
 src/pyssg/yaml_parser.py   |  7 ++++---
 3 files changed, 33 insertions(+), 16 deletions(-)

(limited to 'src')

diff --git a/src/pyssg/configuration.py b/src/pyssg/configuration.py
index 7b292d5..3cc5430 100644
--- a/src/pyssg/configuration.py
+++ b/src/pyssg/configuration.py
@@ -2,6 +2,7 @@ import sys
 from importlib.metadata import version
 from datetime import datetime, timezone
 from logging import Logger, getLogger
+from typing import Any
 
 from .utils import get_expanded_path
 from .yaml_parser import get_parsed_yaml
@@ -23,15 +24,23 @@ def __check_well_formed_config(config: dict,
             sys.exit(1)
         # checks for dir_paths
         if key == 'dirs':
+            try:
+                config[key].keys()
+            except AttributeError:
+                log.error('config doesn\'t have any dirs configs (dirs.*)')
+                sys.exit(1)
             if '/' not in config[key]:
-                log.error('config doesn\'t have "%s./"', current_key)
                 log.debug('key: %s; config.keys: %s', key, config[key].keys())
+                log.error('config doesn\'t have "%s./"', current_key)
                 sys.exit(1)
-            log.debug('checking "%s" fields for (%s) dir_paths', key, ', '.join(config[key].keys()))
+            log.debug('checking "%s" fields for (%s) dir_paths',
+                      key, ', '.join(config[key].keys()))
             for dkey in config[key].keys():
                 new_current_key: str = f'{current_key}.{dkey}'
                 new_config_base: list[dict] = [config_base[1], config_base[1]]
-                __check_well_formed_config(config[key][dkey], new_config_base, new_current_key)
+                __check_well_formed_config(config[key][dkey],
+                                           new_config_base,
+                                           new_current_key)
             continue
         # the case for elements that don't have nested elements
         if not config_base[0][key]:
@@ -48,12 +57,14 @@ def __expand_all_paths(config: dict) -> None:
 
 
 # not necessary to type deeper than the first dict
-def get_parsed_config(path: str) -> list[dict]:
+def get_parsed_config(path: str,
+                      mc_package: str = 'mandatory_config.yaml',
+                      plt_resource: str = 'pyssg.plt') -> list[dict]:
     log.debug('reading config file "%s"', path)
     config_all: list[dict] = get_parsed_yaml(path)
-    mandatory_config: list[dict] = get_parsed_yaml('mandatory_config.yaml', 'pyssg.plt')
-    log.info('found %s document(s) for configuration "%s"', len(config_all), path)
-    log.debug('checking that config file is well formed (at least contains mandatory fields')
+    mandatory_config: list[dict] = get_parsed_yaml(mc_package, plt_resource)
+    log.info('found %s document(s) for config "%s"', len(config_all), path)
+    log.debug('checking that config file is well formed')
     for config in config_all:
         __check_well_formed_config(config, mandatory_config)
         __expand_all_paths(config)
@@ -62,11 +73,15 @@ def get_parsed_config(path: str) -> list[dict]:
 
 # not necessary to type deeper than the first dict,
 #   static config means config that shouldn't be changed by the user
-def get_static_config() -> dict[str, dict]:
+def get_static_config(sc_package: str = 'static_config.yaml',
+                      plt_resource: str = 'pyssg.plt') -> dict[str, dict]:
     log.debug('reading and setting static config')
-    config: dict = get_parsed_yaml('static_config.yaml', 'pyssg.plt')[0]
-    # do I really need a lambda function...
+    config: dict[str, Any] = get_parsed_yaml(sc_package, plt_resource)[0]
+
+    def __time(fmt: str) -> str:
+        return datetime.now(tz=timezone.utc).strftime(config['fmt'][fmt])
+
     config['info']['version'] = VERSION
-    config['info']['rss_run_date'] = datetime.now(tz=timezone.utc).strftime(config['fmt']['rss_date'])
-    config['info']['sitemap_run_date'] = datetime.now(tz=timezone.utc).strftime(config['fmt']['sitemap_date'])
+    config['info']['rss_run_date'] = __time('rss_date')
+    config['info']['sitemap_run_date'] = __time('sitemap_date')
     return config
diff --git a/src/pyssg/custom_logger.py b/src/pyssg/custom_logger.py
index 55f3e2d..4eebc4c 100644
--- a/src/pyssg/custom_logger.py
+++ b/src/pyssg/custom_logger.py
@@ -2,6 +2,7 @@ import sys
 from logging import (Logger, StreamHandler, Formatter, LogRecord,
                      DEBUG, INFO, WARNING, ERROR, CRITICAL,
                      getLogger)
+from typing import TextIO
 
 LOG_LEVEL: int = INFO
 # 'pyssg' es the name of the root logger
@@ -38,7 +39,7 @@ class PerLevelFormatter(Formatter):
 
 def setup_logger(name: str = LOGGER_NAME, level: int = LOG_LEVEL) -> None:
     logger: Logger = getLogger(name)
-    handler: StreamHandler = StreamHandler(sys.stdout)
+    handler: StreamHandler[TextIO] = StreamHandler(sys.stdout)
     logger.setLevel(level)
     handler.setLevel(level)
     handler.setFormatter(PerLevelFormatter())
diff --git a/src/pyssg/yaml_parser.py b/src/pyssg/yaml_parser.py
index 2e1548b..aeb164e 100644
--- a/src/pyssg/yaml_parser.py
+++ b/src/pyssg/yaml_parser.py
@@ -3,6 +3,7 @@ from yaml import SafeLoader
 from yaml.nodes import SequenceNode
 from importlib.resources import path as rpath
 from logging import Logger, getLogger
+from typing import Any
 
 log: Logger = getLogger(__name__)
 
@@ -17,15 +18,15 @@ def setup_custom_yaml() -> None:
     SafeLoader.add_constructor('!join', __join_constructor)
 
 
-def __read_raw_yaml(path: str) -> list[dict]:
-    all_docs: list[dict] = []
+def __read_raw_yaml(path: str) -> list[dict[str, Any]]:
+    all_docs: list[dict[str, Any]] = []
     with open(path, 'r') as f:
         for doc in yaml.safe_load_all(f):
             all_docs.append(doc)
     return all_docs
 
 
-def get_parsed_yaml(resource: str, package: str = '') -> list[dict]:
+def get_parsed_yaml(resource: str, package: str = '') -> list[dict[str, Any]]:
     if package == '':
         log.debug('parsing yaml; reading "%s"', resource)
         return __read_raw_yaml(resource)
-- 
cgit v1.2.3-70-g09d2