diff --git a/ymmsl/v0_2/configuration.py b/ymmsl/v0_2/configuration.py index 9638a12..0e36431 100644 --- a/ymmsl/v0_2/configuration.py +++ b/ymmsl/v0_2/configuration.py @@ -305,9 +305,11 @@ def _component_paths(self) -> Dict[Reference, Component]: result = dict() queue: List[Tuple[Model, Reference, List[Tuple[Reference, Reference]]]] = \ [(m, Reference([]), []) for m in self._root_models()] + _logger.debug(f'cmp_paths: initial queue: {[t[0].name for t in queue]}') while queue: model, prefix, seen = queue.pop(0) + _logger.debug(f'cmp_paths: {model.name} {prefix} {seen}') for component in model.components.values(): path = prefix + component.name impl = self.custom_implementations.get(path, component.implementation) @@ -539,7 +541,9 @@ def _check_resources( """ errors = list() for path, component in component_paths.items(): + _logger.debug(f'Checking resources for {path} {component.name}') impl_ref = self.custom_implementations.get(path, component.implementation) + _logger.debug(f'Implementation: {impl_ref}') if impl_ref is None or impl_ref not in self.programs: continue diff --git a/ymmsl/v0_2/resolver.py b/ymmsl/v0_2/resolver.py index dbb93ed..4695b81 100644 --- a/ymmsl/v0_2/resolver.py +++ b/ymmsl/v0_2/resolver.py @@ -108,16 +108,8 @@ def resolve(module: Reference, config: Configuration) -> None: RuntimeError: if an error occurs due to an invalid configuration. This will leave config in a broken state, so reload it if you want to try again. """ - overwritten_implementations = do_resolve( - Path('
'), module, config, ResolutionContext()) - - used_implementations = { - c.implementation for m in config.models.values() - for c in m.components.values()} - - for m in list(config.models.keys()): - if m in overwritten_implementations and m not in used_implementations: - del config.models[m] + overwritten = do_resolve(Path('
'), module, config, ResolutionContext()) + remove_overwritten_implementations(config, overwritten) def do_resolve( @@ -134,9 +126,11 @@ def do_resolve( module: The module corresponding to this configuration config: The configuration to resolve """ + _logger.debug(f'Resolving {module}') ctx.push_module(file, module) overwritten_implementations = resolve_impls(module, config, ctx) ctx.pop_module() + _logger.debug(f'Done resolving {module}') return overwritten_implementations @@ -158,10 +152,10 @@ def resolve_impls( rename_local_impls(config.programs, module, ylocals) rename_local_impls(config.models, module, ylocals) resolve_impl_imports(config, ylocals, ctx) + update_local_implementations(config, ylocals) config.imports = [i for i in config.imports if i.kind != ImportKind.IMPLEMENTATION] overwritten_impls = apply_custom_implementations(config, module, ylocals, ctx) - update_local_implementations(config, ylocals) return overwritten_impls @@ -200,6 +194,7 @@ def resolve_impl_imports( """ for imp_st in config.imports: ctx.push_import(imp_st) + _logger.debug(f'Processing import {imp_st.module} implementation {imp_st.name}') if imp_st.kind == ImportKind.IMPLEMENTATION: imp_cfg, loaded_file = load_resolve_module( imp_st.module, imp_st.module_path(), ctx) @@ -221,6 +216,7 @@ def resolve_impl_imports( raise RuntimeError(msg) ylocals[Reference([imp_st.name])] = imp_st.full_name() + _logger.debug(f'Imported {imp_st.full_name()} as {imp_st.name}') ctx.pop_module() ctx.pop_import() @@ -259,6 +255,19 @@ def impl_hint_msg( overwritten_implementations = set() copied_paths = set() + def set_overwritten(config: Configuration, implementation: Reference) -> None: + seen = set() + queue = [implementation] + while queue: + impl = queue.pop(0) + seen.add(impl) + overwritten_implementations.add(impl) + if impl in config.models: + queue.extend([ + c.implementation + for c in config.models[impl].components.values() + if c.implementation is not None and c.implementation not in seen]) + # Pre-copy any models that will be updated, if they were imported and we therefore # cannot modify them in place without interfering with other uses of the same model. for key, value in config.custom_implementations.items(): @@ -288,13 +297,15 @@ def impl_hint_msg( config.models[new_name] = m ylocals[base_model_name] = m.name - overwritten_implementations.add(orig_name) + set_overwritten(config, orig_name) for key, value in config.custom_implementations.items(): base_model_name = Reference([key[0]]) path = key[1:] new_impl = ylocals[value] if value is not None else None + _logger.debug( + f'Processing custom implementation {base_model_name} {path} {new_impl}') m = config.models[ylocals[base_model_name]] # Now we can walk down the components and copy-and-rename the models along the @@ -335,7 +346,8 @@ def impl_hint_msg( config.models[new_name] = new_submodel m.components[component] = copy(m.components[component]) m.components[component].implementation = new_name - overwritten_implementations.add(orig_impl) + _logger.debug(f'Set {new_name} as impl of {m.name} {component}') + set_overwritten(config, orig_impl) m = new_submodel copied_paths.add(base_model_name + path[:i+1]) else: @@ -359,10 +371,12 @@ def impl_hint_msg( component = Reference([path[-1]]) new_component = copy(m.components[component]) + _logger.debug(f'In {m.name} replacing {component} with {new_impl}') if new_component.implementation is not None: if new_component.implementation in config.models: - overwritten_implementations.add(new_component.implementation) - new_component.implementation = new_impl + set_overwritten(config, new_component.implementation) + # copy to avoid YAML reference when serialising + new_component.implementation = copy(new_impl) m.components[component] = new_component config.custom_implementations.clear() @@ -372,11 +386,53 @@ def impl_hint_msg( def update_local_implementations( config: Configuration, ylocals: Dict[Reference, Reference]) -> None: """Updates names of local implementations to their full names.""" + _logger.debug('Updating local implementations') for model in config.models.values(): + _logger.debug(f'Updating impls in model {model.name}') for cmp in model.components.values(): if cmp.implementation: if cmp.implementation in ylocals: - cmp.implementation = ylocals[cmp.implementation] + # copy to avoid YAML reference when serialising + cmp.implementation = copy(ylocals[cmp.implementation]) + + +def remove_overwritten_implementations( + config: Configuration, overwritten: Set[Reference]) -> None: + """Remove implementations that are no longer needed. + + This can happen when a custom_implementation overwrites the implementation of a + component, resulting in the original implementation and anything it depends on being + unused, unless of course they are used somewhere else in the model. + + Args: + config: Configuration to update + overwritten: Set of references to implementations that were overwritten by a new + reference to a different implementation. + """ + referred_to = { + c.implementation for m in config.models.values() + for c in m.components.values()} + + roots = set(config.models.keys()) - referred_to - overwritten + used = set() + queue = list(roots) + seen = set() + while queue: + cur_impl = queue.pop(0) + seen.add(cur_impl) + used.add(cur_impl) + if cur_impl in config.models: + queue.extend([ + c.implementation + for c in config.models[cur_impl].components.values() + if c.implementation is not None + and c.implementation not in seen + and c.implementation in config.models + ]) + + for m in list(config.models.keys()): + if m in overwritten and m not in used: + del config.models[m] def find_impls( diff --git a/ymmsl/v0_2/tests/test_resolver.py b/ymmsl/v0_2/tests/test_resolver.py index 2b2bcd7..c371e18 100644 --- a/ymmsl/v0_2/tests/test_resolver.py +++ b/ymmsl/v0_2/tests/test_resolver.py @@ -332,6 +332,42 @@ def test_apply_custom_implementations_everything_localised() -> None: assert b.components[Ref('macro')].implementation == 'el.p' +def test_apply_custom_implementations_cut_branch(env_ymmsl_path: None) -> None: + ymmsl = ( + 'ymmsl_version: v0.2\n' + 'description: |\n' + ' Testing that lopped-off branches are cleaned up properly.\n' + 'imports:\n' + '- from a.g import implementation test_macro_micro\n' + 'models:\n' + ' test_deeply_nested:\n' + ' description: Models within models within models...\n' + ' components:\n' + ' c1:\n' + ' ports:\n' + ' o_i: out\n' + ' s: in\n' + ' description: Macro model\n' + ' implementation: test_macro_micro\n' + 'programs:\n' + ' program3:\n' + ' ports:\n' + ' o_i: out\n' + ' s: in\n' + ' description: Alternative implementation\n' + ' executable: python3\n' + ' args: /home/user/program3.py\n' + 'custom_implementations:\n' + ' test_deeply_nested.c1: program3\n' + ) + + config = load(ymmsl) + assert isinstance(config, Configuration) + + resolve(Reference('nested'), config) + assert len(config.models) == 1 + + def test_apply_custom_implementations_errors(env_ymmsl_path: None) -> None: ymmsl = ( 'ymmsl_version: v0.2\n'