summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChuanqi Xu <yedeng.yd@linux.alibaba.com>2023-02-10 10:11:33 +0800
committerTobias Hieta <tobias@hieta.se>2023-02-13 11:27:45 +0100
commit96faba7ee45b401e7c3e49649e81b662916c5b1a (patch)
tree883e1e55a3dd7eb72f4db3864be75d69f392472c
parentd0f4bebc40307399c4a498e055517f9594ac50af (diff)
downloadllvm-96faba7ee45b401e7c3e49649e81b662916c5b1a.tar.gz
[C++20] [Modules] [NFC] Add Preprocessor methods for named modules - for ClangScanDeps (1/4)
This patch prepares the necessary interfaces in the preprocessor part for D137527 since we need to recognize if we're in a module unit, the module kinds and the module declaration and the module we're importing in the preprocessor. Differential Revision: https://reviews.llvm.org/D137526
-rw-r--r--clang/include/clang/Lex/Preprocessor.h171
-rw-r--r--clang/lib/Lex/Preprocessor.cpp39
-rw-r--r--clang/unittests/Lex/CMakeLists.txt2
-rw-r--r--clang/unittests/Lex/ModuleDeclStateTest.cpp348
4 files changed, 555 insertions, 5 deletions
diff --git a/clang/include/clang/Lex/Preprocessor.h b/clang/include/clang/Lex/Preprocessor.h
index f383a2e5b530..0cb3769d33f4 100644
--- a/clang/include/clang/Lex/Preprocessor.h
+++ b/clang/include/clang/Lex/Preprocessor.h
@@ -313,6 +313,9 @@ private:
/// The import path for named module that we're currently processing.
SmallVector<std::pair<IdentifierInfo *, SourceLocation>, 2> NamedModuleImportPath;
+ /// Whether the import is an `@import` or a standard c++ modules import.
+ bool IsAtImport = false;
+
/// Whether the last token we lexed was an '@'.
bool LastTokenWasAt = false;
@@ -456,6 +459,144 @@ private:
TrackGMF TrackGMFState = TrackGMF::BeforeGMFIntroducer;
+ /// Track the status of the c++20 module decl.
+ ///
+ /// module-declaration:
+ /// 'export'[opt] 'module' module-name module-partition[opt]
+ /// attribute-specifier-seq[opt] ';'
+ ///
+ /// module-name:
+ /// module-name-qualifier[opt] identifier
+ ///
+ /// module-partition:
+ /// ':' module-name-qualifier[opt] identifier
+ ///
+ /// module-name-qualifier:
+ /// identifier '.'
+ /// module-name-qualifier identifier '.'
+ ///
+ /// Transition state:
+ ///
+ /// NotAModuleDecl --- export ---> FoundExport
+ /// NotAModuleDecl --- module ---> ImplementationCandidate
+ /// FoundExport --- module ---> InterfaceCandidate
+ /// ImplementationCandidate --- Identifier ---> ImplementationCandidate
+ /// ImplementationCandidate --- period ---> ImplementationCandidate
+ /// ImplementationCandidate --- colon ---> ImplementationCandidate
+ /// InterfaceCandidate --- Identifier ---> InterfaceCandidate
+ /// InterfaceCandidate --- period ---> InterfaceCandidate
+ /// InterfaceCandidate --- colon ---> InterfaceCandidate
+ /// ImplementationCandidate --- Semi ---> NamedModuleImplementation
+ /// NamedModuleInterface --- Semi ---> NamedModuleInterface
+ /// NamedModuleImplementation --- Anything ---> NamedModuleImplementation
+ /// NamedModuleInterface --- Anything ---> NamedModuleInterface
+ ///
+ /// FIXME: We haven't handle attribute-specifier-seq here. It may not be bad
+ /// soon since we don't support any module attributes yet.
+ class ModuleDeclSeq {
+ enum ModuleDeclState : int {
+ NotAModuleDecl,
+ FoundExport,
+ InterfaceCandidate,
+ ImplementationCandidate,
+ NamedModuleInterface,
+ NamedModuleImplementation,
+ };
+
+ public:
+ ModuleDeclSeq() : State(NotAModuleDecl) {}
+
+ void handleExport() {
+ if (State == NotAModuleDecl)
+ State = FoundExport;
+ else if (!isNamedModule())
+ reset();
+ }
+
+ void handleModule() {
+ if (State == FoundExport)
+ State = InterfaceCandidate;
+ else if (State == NotAModuleDecl)
+ State = ImplementationCandidate;
+ else if (!isNamedModule())
+ reset();
+ }
+
+ void handleIdentifier(IdentifierInfo *Identifier) {
+ if (isModuleCandidate() && Identifier)
+ Name += Identifier->getName().str();
+ else if (!isNamedModule())
+ reset();
+ }
+
+ void handleColon() {
+ if (isModuleCandidate())
+ Name += ":";
+ else if (!isNamedModule())
+ reset();
+ }
+
+ void handlePeriod() {
+ if (isModuleCandidate())
+ Name += ".";
+ else if (!isNamedModule())
+ reset();
+ }
+
+ void handleSemi() {
+ if (!Name.empty() && isModuleCandidate()) {
+ if (State == InterfaceCandidate)
+ State = NamedModuleInterface;
+ else if (State == ImplementationCandidate)
+ State = NamedModuleImplementation;
+ else
+ llvm_unreachable("Unimaged ModuleDeclState.");
+ } else if (!isNamedModule())
+ reset();
+ }
+
+ void handleMisc() {
+ if (!isNamedModule())
+ reset();
+ }
+
+ bool isModuleCandidate() const {
+ return State == InterfaceCandidate || State == ImplementationCandidate;
+ }
+
+ bool isNamedModule() const {
+ return State == NamedModuleInterface ||
+ State == NamedModuleImplementation;
+ }
+
+ bool isNamedInterface() const { return State == NamedModuleInterface; }
+
+ bool isImplementationUnit() const {
+ return State == NamedModuleImplementation && !getName().contains(':');
+ }
+
+ StringRef getName() const {
+ assert(isNamedModule() && "Can't get name from a non named module");
+ return Name;
+ }
+
+ StringRef getPrimaryName() const {
+ assert(isNamedModule() && "Can't get name from a non named module");
+ return getName().split(':').first;
+ }
+
+ void reset() {
+ Name.clear();
+ State = NotAModuleDecl;
+ }
+
+ private:
+ ModuleDeclState State;
+ std::string Name;
+ };
+
+ ModuleDeclSeq ModuleDeclState;
+
/// Whether the module import expects an identifier next. Otherwise,
/// it expects a '.' or ';'.
bool ModuleImportExpectsIdentifier = false;
@@ -2225,6 +2366,36 @@ public:
/// Retrieves the module whose implementation we're current compiling, if any.
Module *getCurrentModuleImplementation();
+ /// If we are preprocessing a named module.
+ bool isInNamedModule() const { return ModuleDeclState.isNamedModule(); }
+
+ /// If we are proprocessing a named interface unit.
+ /// Note that a module implementation partition is not considered as an
+ /// named interface unit here although it is importable
+ /// to ease the parsing.
+ bool isInNamedInterfaceUnit() const {
+ return ModuleDeclState.isNamedInterface();
+ }
+
+ /// Get the named module name we're preprocessing.
+ /// Requires we're preprocessing a named module.
+ StringRef getNamedModuleName() const { return ModuleDeclState.getName(); }
+
+ /// If we are implementing an implementation module unit.
+ /// Note that the module implementation partition is not considered as an
+ /// implementation unit.
+ bool isInImplementationUnit() const {
+ return ModuleDeclState.isImplementationUnit();
+ }
+
+ /// If we're importing a standard C++20 Named Modules.
+ bool isInImportingCXXNamedModules() const {
+ // NamedModuleImportPath will be non-empty only if we're importing
+ // Standard C++ named modules.
+ return !NamedModuleImportPath.empty() && getLangOpts().CPlusPlusModules &&
+ !IsAtImport;
+ }
+
/// Allocate a new MacroInfo object with the provided SourceLocation.
MacroInfo *AllocateMacroInfo(SourceLocation L);
diff --git a/clang/lib/Lex/Preprocessor.cpp b/clang/lib/Lex/Preprocessor.cpp
index fe9adb5685e3..d9a51b7e9da6 100644
--- a/clang/lib/Lex/Preprocessor.cpp
+++ b/clang/lib/Lex/Preprocessor.cpp
@@ -873,6 +873,7 @@ bool Preprocessor::HandleIdentifier(Token &Identifier) {
CurLexerKind != CLK_CachingLexer) {
ModuleImportLoc = Identifier.getLocation();
NamedModuleImportPath.clear();
+ IsAtImport = true;
ModuleImportExpectsIdentifier = true;
CurLexerKind = CLK_LexAfterModuleImport;
}
@@ -940,6 +941,7 @@ void Preprocessor::Lex(Token &Result) {
case tok::semi:
TrackGMFState.handleSemi();
StdCXXImportSeqState.handleSemi();
+ ModuleDeclState.handleSemi();
break;
case tok::header_name:
case tok::annot_header_unit:
@@ -948,6 +950,13 @@ void Preprocessor::Lex(Token &Result) {
case tok::kw_export:
TrackGMFState.handleExport();
StdCXXImportSeqState.handleExport();
+ ModuleDeclState.handleExport();
+ break;
+ case tok::colon:
+ ModuleDeclState.handleColon();
+ break;
+ case tok::period:
+ ModuleDeclState.handlePeriod();
break;
case tok::identifier:
if (Result.getIdentifierInfo()->isModulesImport()) {
@@ -956,18 +965,25 @@ void Preprocessor::Lex(Token &Result) {
if (StdCXXImportSeqState.afterImportSeq()) {
ModuleImportLoc = Result.getLocation();
NamedModuleImportPath.clear();
+ IsAtImport = false;
ModuleImportExpectsIdentifier = true;
CurLexerKind = CLK_LexAfterModuleImport;
}
break;
} else if (Result.getIdentifierInfo() == getIdentifierInfo("module")) {
TrackGMFState.handleModule(StdCXXImportSeqState.afterTopLevelSeq());
+ ModuleDeclState.handleModule();
break;
+ } else {
+ ModuleDeclState.handleIdentifier(Result.getIdentifierInfo());
+ if (ModuleDeclState.isModuleCandidate())
+ break;
}
[[fallthrough]];
default:
TrackGMFState.handleMisc();
StdCXXImportSeqState.handleMisc();
+ ModuleDeclState.handleMisc();
break;
}
}
@@ -1151,6 +1167,15 @@ bool Preprocessor::LexAfterModuleImport(Token &Result) {
if (NamedModuleImportPath.empty() && getLangOpts().CPlusPlusModules) {
if (LexHeaderName(Result))
return true;
+
+ if (Result.is(tok::colon) && ModuleDeclState.isNamedModule()) {
+ std::string Name = ModuleDeclState.getPrimaryName().str();
+ Name += ":";
+ NamedModuleImportPath.push_back(
+ {getIdentifierInfo(Name), Result.getLocation()});
+ CurLexerKind = CLK_LexAfterModuleImport;
+ return true;
+ }
} else {
Lex(Result);
}
@@ -1164,9 +1189,10 @@ bool Preprocessor::LexAfterModuleImport(Token &Result) {
/*DisableMacroExpansion*/ true, /*IsReinject*/ false);
};
+ bool ImportingHeader = Result.is(tok::header_name);
// Check for a header-name.
SmallVector<Token, 32> Suffix;
- if (Result.is(tok::header_name)) {
+ if (ImportingHeader) {
// Enter the header-name token into the token stream; a Lex action cannot
// both return a token and cache tokens (doing so would corrupt the token
// cache if the call to Lex comes from CachingLex / PeekAhead).
@@ -1244,8 +1270,8 @@ bool Preprocessor::LexAfterModuleImport(Token &Result) {
if (ModuleImportExpectsIdentifier && Result.getKind() == tok::identifier) {
// We expected to see an identifier here, and we did; continue handling
// identifiers.
- NamedModuleImportPath.push_back(std::make_pair(Result.getIdentifierInfo(),
- Result.getLocation()));
+ NamedModuleImportPath.push_back(
+ std::make_pair(Result.getIdentifierInfo(), Result.getLocation()));
ModuleImportExpectsIdentifier = false;
CurLexerKind = CLK_LexAfterModuleImport;
return true;
@@ -1285,7 +1311,8 @@ bool Preprocessor::LexAfterModuleImport(Token &Result) {
std::string FlatModuleName;
if (getLangOpts().ModulesTS || getLangOpts().CPlusPlusModules) {
for (auto &Piece : NamedModuleImportPath) {
- if (!FlatModuleName.empty())
+ // If the FlatModuleName ends with colon, it implies it is a partition.
+ if (!FlatModuleName.empty() && FlatModuleName.back() != ':')
FlatModuleName += ".";
FlatModuleName += Piece.first->getName();
}
@@ -1296,7 +1323,8 @@ bool Preprocessor::LexAfterModuleImport(Token &Result) {
}
Module *Imported = nullptr;
- if (getLangOpts().Modules) {
+ // We don't/shouldn't load the standard c++20 modules when preprocessing.
+ if (getLangOpts().Modules && !isInImportingCXXNamedModules()) {
Imported = TheModuleLoader.loadModule(ModuleImportLoc,
NamedModuleImportPath,
Module::Hidden,
@@ -1304,6 +1332,7 @@ bool Preprocessor::LexAfterModuleImport(Token &Result) {
if (Imported)
makeModuleVisible(Imported, SemiLoc);
}
+
if (Callbacks)
Callbacks->moduleImport(ModuleImportLoc, NamedModuleImportPath, Imported);
diff --git a/clang/unittests/Lex/CMakeLists.txt b/clang/unittests/Lex/CMakeLists.txt
index bed5fd9186f2..64ff794fb360 100644
--- a/clang/unittests/Lex/CMakeLists.txt
+++ b/clang/unittests/Lex/CMakeLists.txt
@@ -7,6 +7,7 @@ add_clang_unittest(LexTests
HeaderMapTest.cpp
HeaderSearchTest.cpp
LexerTest.cpp
+ ModuleDeclStateTest.cpp
PPCallbacksTest.cpp
PPConditionalDirectiveRecordTest.cpp
PPDependencyDirectivesTest.cpp
@@ -17,6 +18,7 @@ clang_target_link_libraries(LexTests
PRIVATE
clangAST
clangBasic
+ clangFrontend
clangLex
clangParse
clangSema
diff --git a/clang/unittests/Lex/ModuleDeclStateTest.cpp b/clang/unittests/Lex/ModuleDeclStateTest.cpp
new file mode 100644
index 000000000000..a8695391556f
--- /dev/null
+++ b/clang/unittests/Lex/ModuleDeclStateTest.cpp
@@ -0,0 +1,348 @@
+//===- unittests/Lex/ModuleDeclStateTest.cpp - PPCallbacks tests ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===--------------------------------------------------------------===//
+
+#include "clang/Basic/Diagnostic.h"
+#include "clang/Basic/DiagnosticOptions.h"
+#include "clang/Basic/FileManager.h"
+#include "clang/Basic/LangOptions.h"
+#include "clang/Basic/SourceManager.h"
+#include "clang/Basic/TargetInfo.h"
+#include "clang/Basic/TargetOptions.h"
+#include "clang/Frontend/CompilerInvocation.h"
+#include "clang/Lex/HeaderSearch.h"
+#include "clang/Lex/HeaderSearchOptions.h"
+#include "clang/Lex/ModuleLoader.h"
+#include "clang/Lex/Preprocessor.h"
+#include "clang/Lex/PreprocessorOptions.h"
+#include "gtest/gtest.h"
+#include <cstddef>
+#include <initializer_list>
+
+using namespace clang;
+
+namespace {
+
+class CheckNamedModuleImportingCB : public PPCallbacks {
+ Preprocessor &PP;
+ std::vector<bool> IsImportingNamedModulesAssertions;
+ std::size_t NextCheckingIndex;
+
+public:
+ CheckNamedModuleImportingCB(Preprocessor &PP,
+ std::initializer_list<bool> lists)
+ : PP(PP), IsImportingNamedModulesAssertions(lists), NextCheckingIndex(0) {
+ }
+
+ void moduleImport(SourceLocation ImportLoc, ModuleIdPath Path,
+ const Module *Imported) override {
+ ASSERT_TRUE(NextCheckingIndex < IsImportingNamedModulesAssertions.size());
+ EXPECT_EQ(PP.isInImportingCXXNamedModules(),
+ IsImportingNamedModulesAssertions[NextCheckingIndex]);
+ NextCheckingIndex++;
+
+ ASSERT_EQ(Imported, nullptr);
+ }
+
+ // Currently, only the named module will be handled by `moduleImport`
+ // callback.
+ std::size_t importNamedModuleNum() { return NextCheckingIndex; }
+};
+class ModuleDeclStateTest : public ::testing::Test {
+protected:
+ ModuleDeclStateTest()
+ : FileMgr(FileMgrOpts), DiagID(new DiagnosticIDs()),
+ Diags(DiagID, new DiagnosticOptions, new IgnoringDiagConsumer()),
+ SourceMgr(Diags, FileMgr), TargetOpts(new TargetOptions), Invocation() {
+ TargetOpts->Triple = "x86_64-unknown-linux-gnu";
+ Target = TargetInfo::CreateTargetInfo(Diags, TargetOpts);
+ }
+
+ LangOptions &getLangOpts(ArrayRef<const char *> CommandLineArgs) {
+ CompilerInvocation::CreateFromArgs(Invocation, CommandLineArgs, Diags);
+ return *Invocation.getLangOpts();
+ }
+
+ std::unique_ptr<Preprocessor>
+ getPreprocessor(const char *source, ArrayRef<const char *> CommandLineArgs) {
+ std::unique_ptr<llvm::MemoryBuffer> Buf =
+ llvm::MemoryBuffer::getMemBuffer(source);
+ SourceMgr.setMainFileID(SourceMgr.createFileID(std::move(Buf)));
+
+ LangOptions &LangOpts = getLangOpts(CommandLineArgs);
+ HeaderInfo.emplace(std::make_shared<HeaderSearchOptions>(), SourceMgr,
+ Diags, LangOpts, Target.get());
+
+ return std::make_unique<Preprocessor>(
+ std::make_shared<PreprocessorOptions>(), Diags, LangOpts, SourceMgr,
+ *HeaderInfo, ModLoader,
+ /*IILookup =*/nullptr,
+ /*OwnsHeaderSearch =*/false);
+ }
+
+ void preprocess(Preprocessor &PP, std::unique_ptr<PPCallbacks> C) {
+ PP.Initialize(*Target);
+ PP.addPPCallbacks(std::move(C));
+ PP.EnterMainSourceFile();
+
+ while (1) {
+ Token tok;
+ PP.Lex(tok);
+ if (tok.is(tok::eof))
+ break;
+ }
+ }
+
+ FileSystemOptions FileMgrOpts;
+ FileManager FileMgr;
+ IntrusiveRefCntPtr<DiagnosticIDs> DiagID;
+ DiagnosticsEngine Diags;
+ SourceManager SourceMgr;
+ std::shared_ptr<TargetOptions> TargetOpts;
+ IntrusiveRefCntPtr<TargetInfo> Target;
+ CompilerInvocation Invocation;
+ TrivialModuleLoader ModLoader;
+ std::optional<HeaderSearch> HeaderInfo;
+};
+
+TEST_F(ModuleDeclStateTest, NamedModuleInterface) {
+ const char *source = R"(
+export module foo;
+ )";
+ std::unique_ptr<Preprocessor> PP = getPreprocessor(source, "-std=c++20");
+
+ std::initializer_list<bool> ImportKinds = {};
+ preprocess(*PP,
+ std::make_unique<CheckNamedModuleImportingCB>(*PP, ImportKinds));
+
+ auto *Callback =
+ static_cast<CheckNamedModuleImportingCB *>(PP->getPPCallbacks());
+ EXPECT_EQ(Callback->importNamedModuleNum(), 0);
+ EXPECT_TRUE(PP->isInNamedModule());
+ EXPECT_TRUE(PP->isInNamedInterfaceUnit());
+ EXPECT_FALSE(PP->isInImplementationUnit());
+ EXPECT_EQ(PP->getNamedModuleName(), "foo");
+}
+
+TEST_F(ModuleDeclStateTest, NamedModuleImplementation) {
+ const char *source = R"(
+module foo;
+ )";
+ std::unique_ptr<Preprocessor> PP = getPreprocessor(source, "-std=c++20");
+
+ std::initializer_list<bool> ImportKinds = {};
+ preprocess(*PP,
+ std::make_unique<CheckNamedModuleImportingCB>(*PP, ImportKinds));
+
+ auto *Callback =
+ static_cast<CheckNamedModuleImportingCB *>(PP->getPPCallbacks());
+ EXPECT_EQ(Callback->importNamedModuleNum(), 0);
+ EXPECT_TRUE(PP->isInNamedModule());
+ EXPECT_FALSE(PP->isInNamedInterfaceUnit());
+ EXPECT_TRUE(PP->isInImplementationUnit());
+ EXPECT_EQ(PP->getNamedModuleName(), "foo");
+}
+
+TEST_F(ModuleDeclStateTest, ModuleImplementationPartition) {
+ const char *source = R"(
+module foo:part;
+ )";
+ std::unique_ptr<Preprocessor> PP = getPreprocessor(source, "-std=c++20");
+
+ std::initializer_list<bool> ImportKinds = {};
+ preprocess(*PP,
+ std::make_unique<CheckNamedModuleImportingCB>(*PP, ImportKinds));
+
+ auto *Callback =
+ static_cast<CheckNamedModuleImportingCB *>(PP->getPPCallbacks());
+ EXPECT_EQ(Callback->importNamedModuleNum(), 0);
+ EXPECT_TRUE(PP->isInNamedModule());
+ EXPECT_FALSE(PP->isInNamedInterfaceUnit());
+ EXPECT_FALSE(PP->isInImplementationUnit());
+ EXPECT_EQ(PP->getNamedModuleName(), "foo:part");
+}
+
+TEST_F(ModuleDeclStateTest, ModuleInterfacePartition) {
+ const char *source = R"(
+export module foo:part;
+ )";
+ std::unique_ptr<Preprocessor> PP = getPreprocessor(source, "-std=c++20");
+
+ std::initializer_list<bool> ImportKinds = {};
+ preprocess(*PP,
+ std::make_unique<CheckNamedModuleImportingCB>(*PP, ImportKinds));
+
+ auto *Callback =
+ static_cast<CheckNamedModuleImportingCB *>(PP->getPPCallbacks());
+ EXPECT_EQ(Callback->importNamedModuleNum(), 0);
+ EXPECT_TRUE(PP->isInNamedModule());
+ EXPECT_TRUE(PP->isInNamedInterfaceUnit());
+ EXPECT_FALSE(PP->isInImplementationUnit());
+ EXPECT_EQ(PP->getNamedModuleName(), "foo:part");
+}
+
+TEST_F(ModuleDeclStateTest, ModuleNameWithDot) {
+ const char *source = R"(
+export module foo.dot:part.dot;
+ )";
+ std::unique_ptr<Preprocessor> PP = getPreprocessor(source, "-std=c++20");
+
+ std::initializer_list<bool> ImportKinds = {};
+ preprocess(*PP,
+ std::make_unique<CheckNamedModuleImportingCB>(*PP, ImportKinds));
+
+ auto *Callback =
+ static_cast<CheckNamedModuleImportingCB *>(PP->getPPCallbacks());
+ EXPECT_EQ(Callback->importNamedModuleNum(), 0);
+ EXPECT_TRUE(PP->isInNamedModule());
+ EXPECT_TRUE(PP->isInNamedInterfaceUnit());
+ EXPECT_FALSE(PP->isInImplementationUnit());
+ EXPECT_EQ(PP->getNamedModuleName(), "foo.dot:part.dot");
+}
+
+TEST_F(ModuleDeclStateTest, NotModule) {
+ const char *source = R"(
+// export module foo:part;
+ )";
+ std::unique_ptr<Preprocessor> PP = getPreprocessor(source, "-std=c++20");
+
+ std::initializer_list<bool> ImportKinds = {};
+ preprocess(*PP,
+ std::make_unique<CheckNamedModuleImportingCB>(*PP, ImportKinds));
+
+ auto *Callback =
+ static_cast<CheckNamedModuleImportingCB *>(PP->getPPCallbacks());
+ EXPECT_EQ(Callback->importNamedModuleNum(), 0);
+ EXPECT_FALSE(PP->isInNamedModule());
+ EXPECT_FALSE(PP->isInNamedInterfaceUnit());
+ EXPECT_FALSE(PP->isInImplementationUnit());
+}
+
+TEST_F(ModuleDeclStateTest, ModuleWithGMF) {
+ const char *source = R"(
+module;
+#include "bar.h"
+#include <zoo.h>
+import "bar";
+import <zoo>;
+export module foo:part;
+import "HU";
+import M;
+import :another;
+ )";
+ std::unique_ptr<Preprocessor> PP = getPreprocessor(source, "-std=c++20");
+
+ std::initializer_list<bool> ImportKinds = {true, true};
+ preprocess(*PP,
+ std::make_unique<CheckNamedModuleImportingCB>(*PP, ImportKinds));
+
+ auto *Callback =
+ static_cast<CheckNamedModuleImportingCB *>(PP->getPPCallbacks());
+ EXPECT_EQ(Callback->importNamedModuleNum(), 2);
+ EXPECT_TRUE(PP->isInNamedModule());
+ EXPECT_TRUE(PP->isInNamedInterfaceUnit());
+ EXPECT_FALSE(PP->isInImplementationUnit());
+ EXPECT_EQ(PP->getNamedModuleName(), "foo:part");
+}
+
+TEST_F(ModuleDeclStateTest, ModuleWithGMFWithClangNamedModule) {
+ const char *source = R"(
+module;
+#include "bar.h"
+#include <zoo.h>
+import "bar";
+import <zoo>;
+export module foo:part;
+import "HU";
+import M;
+import :another;
+ )";
+ std::unique_ptr<Preprocessor> PP = getPreprocessor(source, "-std=c++20");
+
+ std::initializer_list<bool> ImportKinds = {true, true};
+ preprocess(*PP,
+ std::make_unique<CheckNamedModuleImportingCB>(*PP, ImportKinds));
+
+ auto *Callback =
+ static_cast<CheckNamedModuleImportingCB *>(PP->getPPCallbacks());
+ EXPECT_EQ(Callback->importNamedModuleNum(), 2);
+ EXPECT_TRUE(PP->isInNamedModule());
+ EXPECT_TRUE(PP->isInNamedInterfaceUnit());
+ EXPECT_FALSE(PP->isInImplementationUnit());
+ EXPECT_EQ(PP->getNamedModuleName(), "foo:part");
+}
+
+TEST_F(ModuleDeclStateTest, ImportsInNormalTU) {
+ const char *source = R"(
+#include "bar.h"
+#include <zoo.h>
+import "bar";
+import <zoo>;
+import "HU";
+import M;
+// We can't import a partition in non-module TU.
+import :another;
+ )";
+ std::unique_ptr<Preprocessor> PP = getPreprocessor(source, "-std=c++20");
+
+ std::initializer_list<bool> ImportKinds = {true};
+ preprocess(*PP,
+ std::make_unique<CheckNamedModuleImportingCB>(*PP, ImportKinds));
+
+ auto *Callback =
+ static_cast<CheckNamedModuleImportingCB *>(PP->getPPCallbacks());
+ EXPECT_EQ(Callback->importNamedModuleNum(), 1);
+ EXPECT_FALSE(PP->isInNamedModule());
+ EXPECT_FALSE(PP->isInNamedInterfaceUnit());
+ EXPECT_FALSE(PP->isInImplementationUnit());
+}
+
+TEST_F(ModuleDeclStateTest, ImportAClangNamedModule) {
+ const char *source = R"(
+@import anything;
+ )";
+ std::unique_ptr<Preprocessor> PP =
+ getPreprocessor(source, {"-fmodules", "-fimplicit-module-maps", "-x",
+ "objective-c++", "-std=c++20"});
+
+ std::initializer_list<bool> ImportKinds = {false};
+ preprocess(*PP,
+ std::make_unique<CheckNamedModuleImportingCB>(*PP, ImportKinds));
+
+ auto *Callback =
+ static_cast<CheckNamedModuleImportingCB *>(PP->getPPCallbacks());
+ EXPECT_EQ(Callback->importNamedModuleNum(), 1);
+ EXPECT_FALSE(PP->isInNamedModule());
+ EXPECT_FALSE(PP->isInNamedInterfaceUnit());
+ EXPECT_FALSE(PP->isInImplementationUnit());
+}
+
+TEST_F(ModuleDeclStateTest, ImportWixedForm) {
+ const char *source = R"(
+import "HU";
+@import anything;
+import M;
+@import another;
+import M2;
+ )";
+ std::unique_ptr<Preprocessor> PP =
+ getPreprocessor(source, {"-fmodules", "-fimplicit-module-maps", "-x",
+ "objective-c++", "-std=c++20"});
+
+ std::initializer_list<bool> ImportKinds = {false, true, false, true};
+ preprocess(*PP,
+ std::make_unique<CheckNamedModuleImportingCB>(*PP, ImportKinds));
+
+ auto *Callback =
+ static_cast<CheckNamedModuleImportingCB *>(PP->getPPCallbacks());
+ EXPECT_EQ(Callback->importNamedModuleNum(), 4);
+ EXPECT_FALSE(PP->isInNamedModule());
+ EXPECT_FALSE(PP->isInNamedInterfaceUnit());
+ EXPECT_FALSE(PP->isInImplementationUnit());
+}
+
+} // namespace