"""File-backed config store
"""
import copy
import io
import os
from logging import getLogger

from ruamel.yaml import YAML  # type: ignore

from smif.data_layer.abstract_config_store import ConfigStore
from smif.data_layer.validate import (
    validate_sos_model_config,
    validate_sos_model_format,
)
from smif.exception import (
    SmifDataExistsError,
    SmifDataMismatchError,
    SmifDataNotFoundError,
)
class YamlConfigStore(ConfigStore):
    """Config backend saving to YAML configuration files.

    Each configuration item is stored as ``<name>.yml`` in a per-type folder
    under ``<base_folder>/config`` (dimensions, model_runs, scenarios,
    sector_models, sos_models). The project configuration is stored as
    ``<base_folder>/project.yml``.

    Arguments
    ---------
    base_folder: str
        The path to the configuration and data files
    validation: bool, default=False
        If True, check system-of-systems model configurations: their format
        when read, and their consistency with the stored models and scenarios
        when updated.

    Raises
    ------
    SmifDataNotFoundError
        If any of the expected configuration folders does not exist
    """

    def __init__(self, base_folder, validation=False):
        super().__init__()
        self.logger = getLogger(__name__)
        self.validation = validation
        self.base_folder = str(base_folder)
        self.config_folder = str(os.path.join(self.base_folder, "config"))
        self.config_folders = {}
        config_folders = [
            "dimensions",
            "model_runs",
            "scenarios",
            "sector_models",
            "sos_models",
        ]
        for folder in config_folders:
            dirname = os.path.join(self.config_folder, folder)
            # ensure each directory exists
            if not os.path.exists(dirname):
                msg = "Expected configuration folder at '{}' but it does not exist"
                abs_path = os.path.abspath(dirname)
                raise SmifDataNotFoundError(msg.format(abs_path))
            self.config_folders[folder] = dirname

        # cache results of reading project_config (invalidate on write)
        self._project_config_cache_invalid = True
        # MUST ONLY access through self.read_project_config()
        self._project_config_cache = None

        # ensure project config file exists
        try:
            self.read_project_config()
        except FileNotFoundError:
            # write empty config if none found
            self._write_project_config({})

    def read_project_config(self):
        """Read the project configuration

        The parsed file is cached until the next write. Callers receive a deep
        copy, so mutating the result does not affect the cache.

        Returns
        -------
        dict
            The project configuration
        """
        if self._project_config_cache_invalid:
            self._project_config_cache = _read_yaml_file(self.base_folder, "project")
            self._project_config_cache_invalid = False
        return copy.deepcopy(self._project_config_cache)

    def _write_project_config(self, data):
        """Write the project configuration

        Argument
        --------
        data: dict
            The project configuration
        """
        # drop the cache so the next read reloads what was written
        self._project_config_cache_invalid = True
        self._project_config_cache = None
        _write_yaml_file(self.base_folder, "project", data)

    def _read_config(self, config_type, config_name):
        """Read config item - used by decorators for existence/consistency checks

        Only ``"scenario"`` is supported; any other type raises
        NotImplementedError.
        """
        if config_type == "scenario":
            return self.read_scenario(config_name)
        raise NotImplementedError(
            "Cannot read %s:%s through generic method." % (config_type, config_name)
        )

    # region Model runs
    def read_model_runs(self):
        """Read all model run configurations (each without its strategies)

        Returns
        -------
        list of dict
        """
        names = _read_filenames_in_dir(self.config_folders["model_runs"], ".yml")
        return [self.read_model_run(name) for name in names]

    def read_model_run(self, model_run_name):
        """Read a model run configuration

        Strategies are stored in the same file but are exposed separately
        through read_strategies, so they are removed from the returned dict.

        Raises
        ------
        SmifDataNotFoundError
            If the model run does not exist
        """
        _assert_file_exists(self.config_folders, "model_run", model_run_name)
        modelrun_config = self._read_model_run(model_run_name)
        # tolerate files (e.g. hand-written) with no strategies entry
        modelrun_config.pop("strategies", None)
        return modelrun_config

    def _read_model_run(self, model_run_name):
        """Read a model run file as stored, including its strategies"""
        return _read_yaml_file(self.config_folders["model_runs"], model_run_name)

    def _overwrite_model_run(self, model_run_name, model_run):
        """Write a model run file as given, without existence checks"""
        _write_yaml_file(self.config_folders["model_runs"], model_run_name, model_run)

    def write_model_run(self, model_run):
        """Write a new model run configuration with an empty strategy list

        Raises
        ------
        SmifDataExistsError
            If a model run with the same name already exists
        """
        _assert_file_not_exists(self.config_folders, "model_run", model_run["name"])
        # a shallow copy suffices: only a top-level key is replaced
        config = copy.copy(model_run)
        config["strategies"] = []
        _write_yaml_file(self.config_folders["model_runs"], config["name"], config)

    def update_model_run(self, model_run_name, model_run):
        """Overwrite an existing model run, keeping its stored strategies

        Raises
        ------
        SmifDataMismatchError
            If model_run["name"] differs from model_run_name
        SmifDataNotFoundError
            If the model run does not exist
        """
        if model_run["name"] != model_run_name:
            raise SmifDataMismatchError(
                "Model run name '%s' must match '%s'"
                % (model_run_name, model_run["name"])
            )
        _assert_file_exists(self.config_folders, "model_run", model_run_name)
        prev = self._read_model_run(model_run_name)
        config = copy.copy(model_run)
        # strategies are only changed through write_strategies
        config["strategies"] = prev.get("strategies", [])
        self._overwrite_model_run(model_run_name, config)

    def delete_model_run(self, model_run_name):
        """Delete a model run configuration file

        Raises
        ------
        SmifDataNotFoundError
            If the model run does not exist
        """
        _assert_file_exists(self.config_folders, "model_run", model_run_name)
        os.remove(
            os.path.join(self.config_folders["model_runs"], model_run_name + ".yml")
        )

    # endregion

    # region System-of-system models
    def read_sos_models(self):
        """Read all system-of-systems model configurations

        Returns
        -------
        list of dict
        """
        names = _read_filenames_in_dir(self.config_folders["sos_models"], ".yml")
        return [self.read_sos_model(name) for name in names]

    def read_sos_model(self, sos_model_name):
        """Read a system-of-systems model configuration

        If validation is enabled, the format of the data is checked with
        validate_sos_model_format before it is returned.

        Raises
        ------
        SmifDataNotFoundError
            If the sos model does not exist
        """
        _assert_file_exists(self.config_folders, "sos_model", sos_model_name)
        data = _read_yaml_file(self.config_folders["sos_models"], sos_model_name)
        if self.validation:
            validate_sos_model_format(data)
        return data

    def write_sos_model(self, sos_model):
        """Write a new system-of-systems model configuration

        Raises
        ------
        SmifDataExistsError
            If a sos model with the same name already exists
        """
        _assert_file_not_exists(self.config_folders, "sos_model", sos_model["name"])
        _write_yaml_file(
            self.config_folders["sos_models"], sos_model["name"], sos_model
        )

    def update_sos_model(self, sos_model_name, sos_model):
        """Overwrite an existing system-of-systems model configuration

        If validation is enabled, the configuration is checked against all
        stored sector models and scenarios before it is written.

        Raises
        ------
        SmifDataMismatchError
            If sos_model["name"] differs from sos_model_name
        SmifDataNotFoundError
            If the sos model does not exist
        """
        if sos_model["name"] != sos_model_name:
            raise SmifDataMismatchError(
                "SoSModel name '%s' must match '%s'"
                % (sos_model_name, sos_model["name"])
            )
        _assert_file_exists(self.config_folders, "sos_model", sos_model_name)
        if self.validation:
            validate_sos_model_config(
                sos_model,
                self.read_models(),
                self.read_scenarios(),
            )
        _write_yaml_file(
            self.config_folders["sos_models"], sos_model["name"], sos_model
        )

    def delete_sos_model(self, sos_model_name):
        """Delete a system-of-systems model configuration file

        Raises
        ------
        SmifDataNotFoundError
            If the sos model does not exist
        """
        _assert_file_exists(self.config_folders, "sos_model", sos_model_name)
        os.remove(
            os.path.join(self.config_folders["sos_models"], sos_model_name + ".yml")
        )

    # endregion

    # region Models
    def read_models(self):
        """Read all sector model configurations

        Returns
        -------
        list of dict
        """
        names = _read_filenames_in_dir(self.config_folders["sector_models"], ".yml")
        return [self.read_model(name) for name in names]

    def read_model(self, model_name):
        """Read a sector model configuration

        Raises
        ------
        SmifDataNotFoundError
            If the model does not exist
        """
        _assert_file_exists(self.config_folders, "sector_model", model_name)
        return _read_yaml_file(self.config_folders["sector_models"], model_name)

    def write_model(self, model):
        """Write a new sector model configuration

        Any interventions are discarded (with a warning), and "coords" are
        dropped from the inputs, outputs and parameters specs.

        Raises
        ------
        SmifDataExistsError
            If a model with the same name already exists
        """
        _assert_file_not_exists(self.config_folders, "sector_model", model["name"])
        model = copy.deepcopy(model)
        if model.get("interventions"):
            self.logger.warning("Ignoring interventions")
            model["interventions"] = []
        model = _skip_coords(model, ("inputs", "outputs", "parameters"))
        _write_yaml_file(self.config_folders["sector_models"], model["name"], model)

    def update_model(self, model_name, model):
        """Overwrite an existing sector model configuration

        Non-empty interventions or initial conditions in `model` are replaced
        (with a warning) by the values currently stored on disk. "coords" are
        dropped from the inputs, outputs and parameters specs.

        Raises
        ------
        SmifDataMismatchError
            If model["name"] differs from model_name
        SmifDataNotFoundError
            If the model does not exist
        """
        if model["name"] != model_name:
            raise SmifDataMismatchError(
                "Model name '%s' must match '%s'" % (model_name, model["name"])
            )
        _assert_file_exists(self.config_folders, "sector_model", model_name)
        model = copy.deepcopy(model)
        # ignore interventions and initial conditions which the app doesn't handle
        has_interventions = bool(model.get("interventions"))
        has_initial_conditions = bool(model.get("initial_conditions"))
        if has_interventions or has_initial_conditions:
            old_model = _read_yaml_file(self.config_folders["sector_models"], model_name)
            if has_interventions:
                self.logger.warning("Ignoring interventions write")
                model["interventions"] = old_model.get("interventions", [])
            if has_initial_conditions:
                self.logger.warning("Ignoring initial conditions write")
                model["initial_conditions"] = old_model.get("initial_conditions", [])
        model = _skip_coords(model, ("inputs", "outputs", "parameters"))
        _write_yaml_file(self.config_folders["sector_models"], model["name"], model)

    def delete_model(self, model_name):
        """Delete a sector model configuration file

        Raises
        ------
        SmifDataNotFoundError
            If the model does not exist
        """
        _assert_file_exists(self.config_folders, "sector_model", model_name)
        os.remove(
            os.path.join(self.config_folders["sector_models"], model_name + ".yml")
        )

    # endregion

    # region Scenarios
    def read_scenarios(self):
        """Read all scenario configurations

        Returns
        -------
        list of dict
        """
        scenario_names = _read_filenames_in_dir(
            self.config_folders["scenarios"], ".yml"
        )
        return [self.read_scenario(name) for name in scenario_names]

    def read_scenario(self, scenario_name):
        """Read a scenario configuration

        Raises
        ------
        SmifDataNotFoundError
            If the scenario does not exist
        """
        _assert_file_exists(self.config_folders, "scenario", scenario_name)
        return _read_yaml_file(self.config_folders["scenarios"], scenario_name)

    def write_scenario(self, scenario):
        """Write a new scenario configuration ("coords" dropped from provides)

        Raises
        ------
        SmifDataExistsError
            If a scenario with the same name already exists
        """
        _assert_file_not_exists(self.config_folders, "scenario", scenario["name"])
        scenario = _skip_coords(scenario, ["provides"])
        _write_yaml_file(self.config_folders["scenarios"], scenario["name"], scenario)

    def update_scenario(self, scenario_name, scenario):
        """Overwrite an existing scenario configuration ("coords" dropped)

        Raises
        ------
        SmifDataMismatchError
            If scenario["name"] differs from scenario_name. Without this check
            the write would go to a different file, which could overwrite
            another scenario and leave the original stale.
        SmifDataNotFoundError
            If the scenario does not exist
        """
        if scenario["name"] != scenario_name:
            raise SmifDataMismatchError(
                "Scenario name '%s' must match '%s'"
                % (scenario_name, scenario["name"])
            )
        _assert_file_exists(self.config_folders, "scenario", scenario_name)
        scenario = _skip_coords(scenario, ["provides"])
        _write_yaml_file(self.config_folders["scenarios"], scenario_name, scenario)

    def delete_scenario(self, scenario_name):
        """Delete a scenario configuration file

        Raises
        ------
        SmifDataNotFoundError
            If the scenario does not exist
        """
        _assert_file_exists(self.config_folders, "scenario", scenario_name)
        os.remove(
            os.path.join(
                self.config_folders["scenarios"], "{}.yml".format(scenario_name)
            )
        )

    # endregion

    # region Scenario Variants
    @staticmethod
    def _variant_index(scenario, scenario_name, variant_name):
        """Return the list index of a named variant within a scenario dict

        Raises SmifDataNotFoundError if the scenario has no such variant.
        """
        v_idx = _idx_in_list(scenario.get("variants", []), variant_name)
        if v_idx is None:
            msg = "Variant '{}' not found in scenario '{}'"
            raise SmifDataNotFoundError(msg.format(variant_name, scenario_name))
        return v_idx

    def read_scenario_variants(self, scenario_name):
        """Read the variants of a scenario (empty list if none are stored)"""
        scenario = self.read_scenario(scenario_name)
        return scenario.get("variants", [])

    def read_scenario_variant(self, scenario_name, variant_name):
        """Read a single named variant of a scenario

        Raises
        ------
        SmifDataNotFoundError
            If the scenario or the variant does not exist
        """
        variants = self.read_scenario_variants(scenario_name)
        variant = _pick_from_list(variants, variant_name)
        if variant is None:
            msg = "Variant '{}' not found in scenario '{}'"
            raise SmifDataNotFoundError(msg.format(variant_name, scenario_name))
        return variant

    def write_scenario_variant(self, scenario_name, variant):
        """Add a new variant to a scenario

        Raises
        ------
        SmifDataExistsError
            If the scenario already has a variant with the same name
        SmifDataNotFoundError
            If the scenario does not exist
        """
        scenario = self.read_scenario(scenario_name)
        variants = scenario.setdefault("variants", [])
        if _pick_from_list(variants, variant.get("name")) is not None:
            msg = "Variant '{}' already exists in scenario '{}'"
            raise SmifDataExistsError(msg.format(variant.get("name"), scenario_name))
        variants.append(variant)
        self.update_scenario(scenario_name, scenario)

    def update_scenario_variant(self, scenario_name, variant_name, variant):
        """Replace a named variant of a scenario

        Raises
        ------
        SmifDataNotFoundError
            If the scenario or the variant does not exist
        """
        scenario = self.read_scenario(scenario_name)
        v_idx = self._variant_index(scenario, scenario_name, variant_name)
        scenario["variants"][v_idx] = variant
        self.update_scenario(scenario_name, scenario)

    def delete_scenario_variant(self, scenario_name, variant_name):
        """Remove a named variant from a scenario

        Raises
        ------
        SmifDataNotFoundError
            If the scenario or the variant does not exist
        """
        scenario = self.read_scenario(scenario_name)
        v_idx = self._variant_index(scenario, scenario_name, variant_name)
        del scenario["variants"][v_idx]
        self.update_scenario(scenario_name, scenario)

    # endregion

    # region Narratives
    def read_narrative(self, sos_model_name, narrative_name):
        """Read a named narrative from a system-of-systems model

        Raises
        ------
        SmifDataNotFoundError
            If the sos model or the narrative does not exist
        """
        sos_model = self.read_sos_model(sos_model_name)
        narrative = _pick_from_list(sos_model.get("narratives", []), narrative_name)
        if narrative is None:
            msg = "Narrative '{}' not found in '{}'"
            raise SmifDataNotFoundError(msg.format(narrative_name, sos_model_name))
        return narrative

    # endregion

    # region Strategies
    def read_strategies(self, modelrun_name):
        """Read the strategies stored with a model run

        Returns
        -------
        list
            The stored strategies; empty if the file has no strategies entry

        Raises
        ------
        SmifDataNotFoundError
            If the model run does not exist
        """
        _assert_file_exists(self.config_folders, "model_run", modelrun_name)
        model_run_config = self._read_model_run(modelrun_name)
        return model_run_config.get("strategies", [])

    def write_strategies(self, modelrun_name, strategies):
        """Replace the strategies stored with a model run

        Raises
        ------
        SmifDataNotFoundError
            If the model run does not exist
        """
        _assert_file_exists(self.config_folders, "model_run", modelrun_name)
        model_run = self._read_model_run(modelrun_name)
        model_run["strategies"] = strategies
        self._overwrite_model_run(modelrun_name, model_run)

    # endregion
def _read_yaml_file(directory, name):
    """Read yaml config file into data (lists, dicts and simple values)

    Parameters
    ----------
    directory : str
        Path to directory
    name : str
        Name of config item (filename without .yml extension)

    Returns
    -------
    The document parsed by ruamel.yaml's ``YAML().load``

    Raises
    ------
    FileNotFoundError
        If ``<directory>/<name>.yml`` does not exist
    """
    path = os.path.join(directory, "{}.yml".format(name))
    # YAML text is UTF-8; do not depend on the platform's locale encoding
    # (e.g. cp1252 on Windows), which would mis-decode non-ASCII content
    with open(path, "r", encoding="utf-8") as file_handle:
        return YAML().load(file_handle)
def _write_yaml_file(directory, name, data):
    """Write plain data to a file as yaml

    The data is serialised in memory before the file is opened. An error
    while dumping (e.g. an unrepresentable value) therefore leaves any
    existing file untouched instead of truncated.

    Arguments
    ---------
    directory: str
        Path to directory
    name: str
        Name of config item (filename without .yml extension)
    data
        Data to be written to the file
    """
    yaml = YAML()
    yaml.default_flow_style = False
    yaml.allow_unicode = True
    buffer = io.StringIO()
    yaml.dump(data, buffer)

    path = os.path.join(directory, "{}.yml".format(name))
    # allow_unicode emits non-ASCII characters verbatim, so the file must be
    # written as UTF-8 rather than in the platform's default encoding
    with open(path, "w", encoding="utf-8") as file_handle:
        file_handle.write(buffer.getvalue())
def _assert_file_exists(file_dir, dtype, name):
    """Raise SmifDataNotFoundError unless the config file for `name` exists.

    `file_dir` maps plural type names (e.g. "scenarios") to folders; `dtype`
    is the singular type name (e.g. "scenario").
    """
    if _file_exists(file_dir, dtype, name):
        return
    label = dtype.capitalize()
    raise SmifDataNotFoundError("%s '%s' not found" % (label, name))
def _assert_file_not_exists(file_dir, dtype, name):
    """Raise SmifDataExistsError if a config file for `name` is already present.

    `file_dir` maps plural type names (e.g. "scenarios") to folders; `dtype`
    is the singular type name (e.g. "scenario").
    """
    already_present = _file_exists(file_dir, dtype, name)
    if already_present:
        label = dtype.capitalize()
        raise SmifDataExistsError("%s '%s' already exists" % (label, name))
def _file_exists(file_dir, dtype, name):
dir_key = "%ss" % dtype
try:
return os.path.exists(os.path.join(file_dir[dir_key], name + ".yml"))
except TypeError:
msg = "Could not parse file name {} and dtype {}"
raise SmifDataNotFoundError(msg.format(name, dtype))
def _read_filenames_in_dir(path, extension):
"""Returns the name of the Yaml files in a certain directory
Arguments
---------
path: str
Path to directory
extension: str
Extension of files (such as: '.yml' or '.csv')
Returns
-------
list
The list of files in `path` with extension
"""
files = []
for filename in os.listdir(path):
if filename.endswith(extension):
files.append(os.path.splitext(filename)[0])
return files
def _skip_coords(config, keys):
"""Given a config dict and list of top-level keys for lists of specs,
delete coords from each spec in each list.
"""
config = copy.deepcopy(config)
for key in keys:
for spec in config[key]:
try:
del spec["coords"]
except KeyError:
pass
return config
def _pick_from_list(list_of_dicts, name):
for item in list_of_dicts:
if "name" in item and item["name"] == name:
return item
return None
def _idx_in_list(list_of_dicts, name):
for i, item in enumerate(list_of_dicts):
if "name" in item and item["name"] == name:
return i
return None