diff options
-rw-r--r-- | src/buildstream/_includes.py | 41 | ||||
-rw-r--r-- | tests/format/include.py | 4 |
2 files changed, 19 insertions, 26 deletions
diff --git a/src/buildstream/_includes.py b/src/buildstream/_includes.py index 0c77e5fa1..e4b144337 100644 --- a/src/buildstream/_includes.py +++ b/src/buildstream/_includes.py @@ -78,39 +78,24 @@ class Includes: if includes_node: if type(includes_node) is ScalarNode: # pylint: disable=unidiomatic-typecheck - includes = [includes_node.as_str()] + includes = [includes_node] else: - includes = includes_node.as_str_list() + includes = includes_node del node["(@)"] for include in reversed(includes): - if only_local and ":" in include: + if only_local and ":" in include.as_str(): continue - try: - include_node, file_path, sub_loader = self._include_file(include, current_loader) - except LoadError as e: - include_provenance = includes_node.get_provenance() - if e.reason == LoadErrorReason.MISSING_FILE: - message = "{}: Include block references a file that could not be found: '{}'.".format( - include_provenance, include - ) - raise LoadError(message, LoadErrorReason.MISSING_FILE) from e - if e.reason == LoadErrorReason.LOADING_DIRECTORY: - message = "{}: Include block references a directory instead of a file: '{}'.".format( - include_provenance, include - ) - raise LoadError(message, LoadErrorReason.LOADING_DIRECTORY) from e - - # Otherwise, we don't know the reason, so just raise - raise + include_node, file_path, sub_loader = self._include_file(include, current_loader) if file_path in included: include_provenance = includes_node.get_provenance() raise LoadError( "{}: trying to recursively include {}".format(include_provenance, file_path), LoadErrorReason.RECURSIVE_INCLUDE, ) + # Because the included node will be modified, we need # to copy it so that we do not modify the toplevel # node of the provenance. @@ -144,14 +129,16 @@ class Includes: # Load include YAML file from with a loader. # # Args: - # include (str): file path relative to loader's project directory. - # Can be prefixed with junctio name. + # include (ScalarNode): file path relative to loader's project directory. + # Can be prefixed with junctio name. # loader (Loader): Loader for the current project. def _include_file(self, include, loader): + provenance = include.get_provenance() + include = include.as_str() shortname = include if ":" in include: junction, include = include.rsplit(":", 1) - current_loader = loader.get_loader(junction) + current_loader = loader.get_loader(junction, provenance=provenance) current_loader.project.ensure_fully_loaded() else: current_loader = loader @@ -160,7 +147,13 @@ class Includes: file_path = os.path.join(directory, include) key = (current_loader, file_path) if key not in self._loaded: - self._loaded[key] = _yaml.load(file_path, shortname=shortname, project=project, copy_tree=self._copy_tree) + try: + self._loaded[key] = _yaml.load( + file_path, shortname=shortname, project=project, copy_tree=self._copy_tree + ) + except LoadError as e: + raise LoadError("{}: {}".format(provenance, e), e.reason, detail=e.detail) from e + return self._loaded[key], file_path, current_loader # _process_value() diff --git a/tests/format/include.py b/tests/format/include.py index 5c273e1a0..d57dd8c19 100644 --- a/tests/format/include.py +++ b/tests/format/include.py @@ -44,7 +44,7 @@ def test_include_missing_file(cli, tmpdir): result = cli.run(project=str(tmpdir), args=["show", str(element.basename)]) result.assert_main_error(ErrorDomain.LOAD, LoadErrorReason.MISSING_FILE) # Make sure the root cause provenance is in the output. - assert "line 4 column 2" in result.stderr + assert "include_missing_file.bst [line 4 column 4]" in result.stderr def test_include_dir(cli, tmpdir): @@ -68,7 +68,7 @@ def test_include_dir(cli, tmpdir): result = cli.run(project=str(tmpdir), args=["show", str(element.basename)]) result.assert_main_error(ErrorDomain.LOAD, LoadErrorReason.LOADING_DIRECTORY) # Make sure the root cause provenance is in the output. - assert "line 4 column 2" in result.stderr + assert "include_dir.bst [line 4 column 4]" in result.stderr @pytest.mark.datafiles(DATA_DIR) |