Source code for smif.data_layer.file.file_config_store

"""File-backed config store
"""
import copy
import os
from logging import getLogger

from ruamel.yaml import YAML  # type: ignore
from smif.data_layer.abstract_config_store import ConfigStore
from smif.data_layer.validate import (
    validate_sos_model_config,
    validate_sos_model_format,
)
from smif.exception import (
    SmifDataExistsError,
    SmifDataMismatchError,
    SmifDataNotFoundError,
)


class YamlConfigStore(ConfigStore):
    """Config backend saving to YAML configuration files.

    Arguments
    ---------
    base_folder: str
        The path to the configuration and data files
    validation: bool, optional
        When true, system-of-systems model configs are checked with
        `validate_sos_model_format` on read and `validate_sos_model_config`
        on update. Defaults to False.

    Raises
    ------
    SmifDataNotFoundError
        If any expected folder under ``<base_folder>/config`` does not exist
    """

    def __init__(self, base_folder, validation=False):
        super().__init__()

        self.logger = getLogger(__name__)
        self.validation = validation
        self.base_folder = str(base_folder)
        self.config_folder = str(os.path.join(self.base_folder, "config"))
        self.config_folders = {}

        config_folders = [
            "dimensions",
            "model_runs",
            "scenarios",
            "sector_models",
            "sos_models",
        ]
        for folder in config_folders:
            dirname = os.path.join(self.config_folder, folder)
            # ensure each directory exists
            if not os.path.exists(dirname):
                msg = "Expected configuration folder at '{}' but it does not exist"
                abs_path = os.path.abspath(dirname)
                raise SmifDataNotFoundError(msg.format(abs_path))
            self.config_folders[folder] = dirname

        # cache results of reading project_config (invalidate on write)
        self._project_config_cache_invalid = True
        # MUST ONLY access through self.read_project_config()
        self._project_config_cache = None

        # ensure project config file exists
        try:
            self.read_project_config()
        except FileNotFoundError:
            # write empty config if none found
            self._write_project_config({})

    def read_project_config(self):
        """Read the project configuration

        The file is only re-read when the cache has been invalidated by a
        write; callers always receive a deep copy, so mutating the returned
        data never alters the cache.

        Returns
        -------
        dict
            The project configuration
        """
        if self._project_config_cache_invalid:
            self._project_config_cache = _read_yaml_file(self.base_folder, "project")
            self._project_config_cache_invalid = False
        return copy.deepcopy(self._project_config_cache)

    def _write_project_config(self, data):
        """Write the project configuration

        Argument
        --------
        data: dict
            The project configuration
        """
        # invalidate before writing so a failed write cannot leave stale data cached
        self._project_config_cache_invalid = True
        self._project_config_cache = None
        _write_yaml_file(self.base_folder, "project", data)

    def _read_config(self, config_type, config_name):
        """Read config item - used by decorators for existence/consistency checks

        Only ``config_type == "scenario"`` is supported; anything else raises
        NotImplementedError.
        """
        if config_type == "scenario":
            return self.read_scenario(config_name)
        else:
            raise NotImplementedError(
                "Cannot read %s:%s through generic method." % (config_type, config_name)
            )

    # region Model runs
    def read_model_runs(self):
        """Read every model run config found in the model_runs folder

        Returns
        -------
        list of dict
            Model run configs, each without its ``strategies`` entry
        """
        names = _read_filenames_in_dir(self.config_folders["model_runs"], ".yml")
        model_runs = [self.read_model_run(name) for name in names]
        return model_runs

    def read_model_run(self, model_run_name):
        """Read a single model run config, excluding its strategies

        Raises
        ------
        SmifDataNotFoundError
            If no file exists for `model_run_name`
        """
        _assert_file_exists(self.config_folders, "model_run", model_run_name)
        modelrun_config = self._read_model_run(model_run_name)
        # strategies are served by read_strategies; a file without the key
        # is tolerated rather than failing with KeyError
        if "strategies" in modelrun_config:
            del modelrun_config["strategies"]
        return modelrun_config

    def _read_model_run(self, model_run_name):
        """Read the raw model run file, strategies included"""
        return _read_yaml_file(self.config_folders["model_runs"], model_run_name)

    def _overwrite_model_run(self, model_run_name, model_run):
        """Write the model run file without any existence check"""
        _write_yaml_file(self.config_folders["model_runs"], model_run_name, model_run)

    def write_model_run(self, model_run):
        """Create a new model run file with an empty strategies list

        Raises
        ------
        SmifDataExistsError
            If a file already exists for ``model_run["name"]``
        """
        _assert_file_not_exists(self.config_folders, "model_run", model_run["name"])
        config = copy.copy(model_run)
        config["strategies"] = []
        _write_yaml_file(self.config_folders["model_runs"], config["name"], config)

    def update_model_run(self, model_run_name, model_run):
        """Overwrite a model run config, keeping the strategies already on file

        Raises
        ------
        SmifDataMismatchError
            If ``model_run["name"]`` differs from `model_run_name`
        SmifDataNotFoundError
            If no file exists for `model_run_name`
        """
        if model_run["name"] != model_run_name:
            raise SmifDataMismatchError(
                "Model run name '%s' must match '%s'"
                % (model_run_name, model_run["name"])
            )
        _assert_file_exists(self.config_folders, "model_run", model_run_name)
        prev = self._read_model_run(model_run_name)
        config = copy.copy(model_run)
        # files without a strategies key are treated as having none
        config["strategies"] = prev.get("strategies", [])
        self._overwrite_model_run(model_run_name, config)

    def delete_model_run(self, model_run_name):
        """Delete the model run file

        Raises
        ------
        SmifDataNotFoundError
            If no file exists for `model_run_name`
        """
        _assert_file_exists(self.config_folders, "model_run", model_run_name)
        os.remove(
            os.path.join(self.config_folders["model_runs"], model_run_name + ".yml")
        )

    # endregion

    # region System-of-system models
    def read_sos_models(self):
        """Read every system-of-systems model config in the sos_models folder"""
        names = _read_filenames_in_dir(self.config_folders["sos_models"], ".yml")
        sos_models = [self.read_sos_model(name) for name in names]
        return sos_models

    def read_sos_model(self, sos_model_name):
        """Read a system-of-systems model config, validating its format if enabled

        Raises
        ------
        SmifDataNotFoundError
            If no file exists for `sos_model_name`
        """
        _assert_file_exists(self.config_folders, "sos_model", sos_model_name)
        data = _read_yaml_file(self.config_folders["sos_models"], sos_model_name)
        if self.validation:
            validate_sos_model_format(data)
        return data

    def write_sos_model(self, sos_model):
        """Create a new system-of-systems model file

        Raises
        ------
        SmifDataExistsError
            If a file already exists for ``sos_model["name"]``
        """
        _assert_file_not_exists(self.config_folders, "sos_model", sos_model["name"])
        _write_yaml_file(
            self.config_folders["sos_models"], sos_model["name"], sos_model
        )

    def update_sos_model(self, sos_model_name, sos_model):
        """Overwrite a system-of-systems model config

        When validation is enabled the config is checked against the stored
        sector models and scenarios before it is written.

        Raises
        ------
        SmifDataMismatchError
            If ``sos_model["name"]`` differs from `sos_model_name`
        SmifDataNotFoundError
            If no file exists for `sos_model_name`
        """
        if sos_model["name"] != sos_model_name:
            raise SmifDataMismatchError(
                "SoSModel name '%s' must match '%s'"
                % (sos_model_name, sos_model["name"])
            )
        _assert_file_exists(self.config_folders, "sos_model", sos_model_name)
        if self.validation:
            validate_sos_model_config(
                sos_model,
                self.read_models(),
                self.read_scenarios(),
            )
        _write_yaml_file(
            self.config_folders["sos_models"], sos_model["name"], sos_model
        )

    def delete_sos_model(self, sos_model_name):
        """Delete the system-of-systems model file

        Raises
        ------
        SmifDataNotFoundError
            If no file exists for `sos_model_name`
        """
        _assert_file_exists(self.config_folders, "sos_model", sos_model_name)
        os.remove(
            os.path.join(self.config_folders["sos_models"], sos_model_name + ".yml")
        )

    # endregion

    # region Models
    def read_models(self):
        """Read every sector model config in the sector_models folder"""
        names = _read_filenames_in_dir(self.config_folders["sector_models"], ".yml")
        models = [self.read_model(name) for name in names]
        return models

    def read_model(self, model_name):
        """Read a sector model config

        Raises
        ------
        SmifDataNotFoundError
            If no file exists for `model_name`
        """
        _assert_file_exists(self.config_folders, "sector_model", model_name)
        model = _read_yaml_file(self.config_folders["sector_models"], model_name)
        return model

    def write_model(self, model):
        """Create a new sector model file

        Any interventions supplied are dropped (with a warning) and coords are
        stripped from the input, output and parameter specs before writing.

        Raises
        ------
        SmifDataExistsError
            If a file already exists for ``model["name"]``
        """
        _assert_file_not_exists(self.config_folders, "sector_model", model["name"])
        model = copy.deepcopy(model)
        # .get: a config without an interventions key has nothing to ignore
        if model.get("interventions"):
            self.logger.warning("Ignoring interventions")
            model["interventions"] = []

        model = _skip_coords(model, ("inputs", "outputs", "parameters"))
        _write_yaml_file(self.config_folders["sector_models"], model["name"], model)

    def update_model(self, model_name, model):
        """Overwrite a sector model config

        Interventions and initial conditions in `model` are not written; the
        values already on file are kept instead (with a warning).

        Raises
        ------
        SmifDataMismatchError
            If ``model["name"]`` differs from `model_name`
        SmifDataNotFoundError
            If no file exists for `model_name`
        """
        if model["name"] != model_name:
            raise SmifDataMismatchError(
                "Model name '%s' must match '%s'" % (model_name, model["name"])
            )
        _assert_file_exists(self.config_folders, "sector_model", model_name)
        model = copy.deepcopy(model)

        # ignore interventions and initial conditions which the app doesn't handle
        has_interventions = bool(model.get("interventions"))
        has_initial_conditions = bool(model.get("initial_conditions"))
        if has_interventions or has_initial_conditions:
            old_model = _read_yaml_file(
                self.config_folders["sector_models"], model["name"]
            )
            if has_interventions:
                self.logger.warning("Ignoring interventions write")
                model["interventions"] = old_model.get("interventions", [])
            if has_initial_conditions:
                self.logger.warning("Ignoring initial conditions write")
                model["initial_conditions"] = old_model.get("initial_conditions", [])

        model = _skip_coords(model, ("inputs", "outputs", "parameters"))
        _write_yaml_file(self.config_folders["sector_models"], model["name"], model)

    def delete_model(self, model_name):
        """Delete the sector model file

        Raises
        ------
        SmifDataNotFoundError
            If no file exists for `model_name`
        """
        _assert_file_exists(self.config_folders, "sector_model", model_name)
        os.remove(
            os.path.join(self.config_folders["sector_models"], model_name + ".yml")
        )

    # endregion

    # region Scenarios
    def read_scenarios(self):
        """Read every scenario config in the scenarios folder"""
        scenario_names = _read_filenames_in_dir(
            self.config_folders["scenarios"], ".yml"
        )
        return [self.read_scenario(name) for name in scenario_names]

    def read_scenario(self, scenario_name):
        """Read a scenario config

        Raises
        ------
        SmifDataNotFoundError
            If no file exists for `scenario_name`
        """
        _assert_file_exists(self.config_folders, "scenario", scenario_name)
        scenario = _read_yaml_file(self.config_folders["scenarios"], scenario_name)
        return scenario

    def write_scenario(self, scenario):
        """Create a new scenario file, stripping coords from its ``provides`` specs

        Raises
        ------
        SmifDataExistsError
            If a file already exists for ``scenario["name"]``
        """
        _assert_file_not_exists(self.config_folders, "scenario", scenario["name"])
        scenario = _skip_coords(scenario, ["provides"])
        _write_yaml_file(self.config_folders["scenarios"], scenario["name"], scenario)

    def update_scenario(self, scenario_name, scenario):
        """Overwrite a scenario config, stripping coords from its ``provides`` specs

        Raises
        ------
        SmifDataMismatchError
            If ``scenario["name"]`` differs from `scenario_name`
        SmifDataNotFoundError
            If no file exists for `scenario_name`
        """
        # same guard as the other update_* methods: without it a mismatched
        # name would silently write a second file named after scenario["name"]
        if scenario["name"] != scenario_name:
            raise SmifDataMismatchError(
                "Scenario name '%s' must match '%s'"
                % (scenario_name, scenario["name"])
            )
        _assert_file_exists(self.config_folders, "scenario", scenario_name)
        scenario = _skip_coords(scenario, ["provides"])
        _write_yaml_file(self.config_folders["scenarios"], scenario["name"], scenario)

    def delete_scenario(self, scenario_name):
        """Delete the scenario file

        Raises
        ------
        SmifDataNotFoundError
            If no file exists for `scenario_name`
        """
        _assert_file_exists(self.config_folders, "scenario", scenario_name)
        os.remove(
            os.path.join(
                self.config_folders["scenarios"], "{}.yml".format(scenario_name)
            )
        )

    # endregion

    # region Scenario Variants
    def _variant_index(self, scenario, scenario_name, variant_name):
        """Return the position of `variant_name` in ``scenario["variants"]``

        Raises
        ------
        SmifDataNotFoundError
            If the scenario has no variant with that name
        """
        v_idx = _idx_in_list(scenario["variants"], variant_name)
        if v_idx is None:
            msg = "Variant '{}' not found in scenario '{}'"
            raise SmifDataNotFoundError(msg.format(variant_name, scenario_name))
        return v_idx

    def read_scenario_variants(self, scenario_name):
        """Read the list of variants stored on a scenario"""
        scenario = self.read_scenario(scenario_name)
        return scenario["variants"]

    def read_scenario_variant(self, scenario_name, variant_name):
        """Read one scenario variant by name

        Returns
        -------
        dict or None
            The variant, or None if the scenario has no variant of that name
        """
        variants = self.read_scenario_variants(scenario_name)
        return _pick_from_list(variants, variant_name)

    def write_scenario_variant(self, scenario_name, variant):
        """Append a variant to a scenario and write the scenario back"""
        scenario = self.read_scenario(scenario_name)
        scenario["variants"].append(variant)
        self.update_scenario(scenario_name, scenario)

    def update_scenario_variant(self, scenario_name, variant_name, variant):
        """Replace the named variant of a scenario and write the scenario back

        Raises
        ------
        SmifDataNotFoundError
            If the scenario or the named variant does not exist
        """
        scenario = self.read_scenario(scenario_name)
        v_idx = self._variant_index(scenario, scenario_name, variant_name)
        scenario["variants"][v_idx] = variant
        self.update_scenario(scenario_name, scenario)

    def delete_scenario_variant(self, scenario_name, variant_name):
        """Remove the named variant from a scenario and write the scenario back

        Raises
        ------
        SmifDataNotFoundError
            If the scenario or the named variant does not exist
        """
        scenario = self.read_scenario(scenario_name)
        v_idx = self._variant_index(scenario, scenario_name, variant_name)
        del scenario["variants"][v_idx]
        self.update_scenario(scenario_name, scenario)

    # endregion

    # region Narratives
    def read_narrative(self, sos_model_name, narrative_name):
        """Read one narrative from a system-of-systems model config

        Raises
        ------
        SmifDataNotFoundError
            If the model file or the named narrative does not exist
        """
        sos_model = self.read_sos_model(sos_model_name)
        # a missing or empty narratives entry means "no narratives", so the
        # not-found error below is raised rather than KeyError/TypeError
        narratives = sos_model.get("narratives") or []
        narrative = _pick_from_list(narratives, narrative_name)
        if not narrative:
            msg = "Narrative '{}' not found in '{}'"
            raise SmifDataNotFoundError(msg.format(narrative_name, sos_model_name))
        return narrative

    # endregion

    # region Strategies
    def read_strategies(self, modelrun_name):
        """Read the strategies stored on a model run (empty list if none stored)"""
        model_run_config = self._read_model_run(modelrun_name)
        return model_run_config.get("strategies", [])

    def write_strategies(self, modelrun_name, strategies):
        """Replace the strategies stored on a model run and write it back"""
        model_run = self._read_model_run(modelrun_name)
        model_run["strategies"] = strategies
        self._overwrite_model_run(modelrun_name, model_run)
# endregion


def _read_yaml_file(directory, name):
    """Read yaml config file into plain data (lists, dicts and simple values)

    Parameters
    ----------
    directory : str
        Path to directory
    name : str
        Name of config item (filename without .yml extension)

    Returns
    -------
    The data loaded from ``<directory>/<name>.yml``
    """
    path = os.path.join(directory, "{}.yml".format(name))
    # explicit encoding so reads do not depend on the platform locale
    with open(path, "r", encoding="utf-8") as file_handle:
        return YAML().load(file_handle)


def _write_yaml_file(directory, name, data):
    """Write plain data to a file as yaml

    Arguments
    ---------
    directory: str
        Path to directory
    name: str
        Name of config item (filename without .yml extension)
    data
        Data to be written to the file
    """
    path = os.path.join(directory, "{}.yml".format(name))
    # unicode output is enabled below, so the file must be opened as UTF-8
    # rather than with the platform default encoding
    with open(path, "w", encoding="utf-8") as file_handle:
        yaml = YAML()
        yaml.default_flow_style = False
        yaml.allow_unicode = True
        return yaml.dump(data, file_handle)


def _assert_file_exists(file_dir, dtype, name):
    """Raise SmifDataNotFoundError unless the config file for `name` exists"""
    if not _file_exists(file_dir, dtype, name):
        raise SmifDataNotFoundError("%s '%s' not found" % (dtype.capitalize(), name))


def _assert_file_not_exists(file_dir, dtype, name):
    """Raise SmifDataExistsError if the config file for `name` already exists"""
    if _file_exists(file_dir, dtype, name):
        raise SmifDataExistsError("%s '%s' already exists" % (dtype.capitalize(), name))


def _file_exists(file_dir, dtype, name):
    """Check whether ``<name>.yml`` exists in the folder registered for `dtype`

    Arguments
    ---------
    file_dir: dict
        Map of plural folder key (e.g. 'scenarios') to directory path
    dtype: str
        Singular config type (e.g. 'scenario'); an 's' is appended to find the key
    name: str
        Name of config item (filename without .yml extension)

    Raises
    ------
    SmifDataNotFoundError
        If the path cannot be built because of a TypeError (for example when
        `name` is not a string)
    """
    dir_key = "%ss" % dtype
    try:
        return os.path.exists(os.path.join(file_dir[dir_key], name + ".yml"))
    except TypeError as ex:
        msg = "Could not parse file name {} and dtype {}"
        raise SmifDataNotFoundError(msg.format(name, dtype)) from ex


def _read_filenames_in_dir(path, extension):
    """Returns the name of the Yaml files in a certain directory

    Arguments
    ---------
    path: str
        Path to directory
    extension: str
        Extension of files (such as: '.yml' or '.csv')

    Returns
    -------
    list
        The names (without extension) of the files in `path` ending with
        `extension`, in sorted order
    """
    # os.listdir gives no ordering guarantee; sort so listings are deterministic
    return sorted(
        os.path.splitext(filename)[0]
        for filename in os.listdir(path)
        if filename.endswith(extension)
    )


def _skip_coords(config, keys):
    """Given a config dict and list of top-level keys for lists of specs,
    delete coords from each spec in each list.

    The input is not modified; a deep copy is returned. Keys that are absent
    from `config`, or whose value is empty/None, are skipped.
    """
    config = copy.deepcopy(config)
    for key in keys:
        for spec in config.get(key) or ():
            try:
                del spec["coords"]
            except KeyError:
                pass
    return config


def _pick_from_list(list_of_dicts, name):
    """Return the first item whose 'name' entry equals `name`, else None"""
    for item in list_of_dicts:
        if "name" in item and item["name"] == name:
            return item
    return None


def _idx_in_list(list_of_dicts, name):
    """Return the index of the first item whose 'name' entry equals `name`, else None"""
    for i, item in enumerate(list_of_dicts):
        if "name" in item and item["name"] == name:
            return i
    return None