From 122d5ae2a52cc9ed1f6a0a68311684dac1b93f2e Mon Sep 17 00:00:00 2001 From: Devajith Date: Mon, 7 Nov 2022 23:42:53 +0000 Subject: [PATCH 01/12] [mlir][mlir-query] Introduce mlir-query tool with autocomplete support This commit adds the initial version of the mlir-query tool, which leverages the pre-existing matchers defined in mlir/include/mlir/IR/Matchers.h The tool provides the following set of basic queries: QUERY MATCHER hasOpAttrName(string) -> m_Attr hasOpName(string) -> m_Op isConstantOp() -> m_Constant isNegInfFloat() -> m_NegInfFloat isNegZeroFloat() -> m_NegZeroFloat isNonZero() -> m_NonZero isOne() -> m_One isOneFloat() -> m_OneFloat isPosInfFloat() -> m_PosInfFloat isPosZeroFloat() -> m_PosZeroFloat isZero() -> m_Zero isZeroFloat() -> m_AnyZeroFloat Differential Revision: https://reviews.llvm.org/D155127 --- mlir/include/mlir/Query/Matcher/Diagnostics.h | 152 +++++ mlir/include/mlir/Query/Matcher/Marshallers.h | 195 ++++++ mlir/include/mlir/Query/Matcher/MatchFinder.h | 42 ++ .../mlir/Query/Matcher/MatchersInternal.h | 72 +++ mlir/include/mlir/Query/Matcher/Parser.h | 174 ++++++ mlir/include/mlir/Query/Matcher/Registry.h | 66 +++ .../include/mlir/Query/Matcher/VariantValue.h | 141 +++++ mlir/include/mlir/Query/Query.h | 79 +++ mlir/include/mlir/Query/QueryParser.h | 59 ++ mlir/include/mlir/Query/QuerySession.h | 40 ++ .../mlir/Tools/mlir-query/MlirQueryMain.h | 27 + mlir/lib/CMakeLists.txt | 1 + mlir/lib/Query/CMakeLists.txt | 12 + mlir/lib/Query/Matcher/CMakeLists.txt | 9 + mlir/lib/Query/Matcher/Diagnostics.cpp | 201 +++++++ mlir/lib/Query/Matcher/Parser.cpp | 553 ++++++++++++++++++ mlir/lib/Query/Matcher/Registry.cpp | 171 ++++++ mlir/lib/Query/Matcher/VariantValue.cpp | 139 +++++ mlir/lib/Query/Query.cpp | 59 ++ mlir/lib/Query/QueryParser.cpp | 208 +++++++ mlir/lib/Tools/CMakeLists.txt | 1 + mlir/lib/Tools/mlir-query/CMakeLists.txt | 13 + mlir/lib/Tools/mlir-query/MlirQueryMain.cpp | 110 ++++ mlir/test/CMakeLists.txt | 1 + mlir/test/mlir-query/simple-test.mlir | 16 + mlir/tools/CMakeLists.txt | 1 + mlir/tools/mlir-query/CMakeLists.txt | 20 + mlir/tools/mlir-query/mlir-query.cpp | 37 ++ 28 files changed, 2599 insertions(+) create mode 100644 mlir/include/mlir/Query/Matcher/Diagnostics.h create mode 100644 mlir/include/mlir/Query/Matcher/Marshallers.h create mode 100644 mlir/include/mlir/Query/Matcher/MatchFinder.h create mode 100644 mlir/include/mlir/Query/Matcher/MatchersInternal.h create mode 100644 mlir/include/mlir/Query/Matcher/Parser.h create mode 100644 mlir/include/mlir/Query/Matcher/Registry.h create mode 100644 mlir/include/mlir/Query/Matcher/VariantValue.h create mode 100644 mlir/include/mlir/Query/Query.h create mode 100644 mlir/include/mlir/Query/QueryParser.h create mode 100644 mlir/include/mlir/Query/QuerySession.h create mode 100644 mlir/include/mlir/Tools/mlir-query/MlirQueryMain.h create mode 100644 mlir/lib/Query/CMakeLists.txt create mode 100644 mlir/lib/Query/Matcher/CMakeLists.txt create mode 100644 mlir/lib/Query/Matcher/Diagnostics.cpp create mode 100644 mlir/lib/Query/Matcher/Parser.cpp create mode 100644 mlir/lib/Query/Matcher/Registry.cpp create mode 100644 mlir/lib/Query/Matcher/VariantValue.cpp create mode 100644 mlir/lib/Query/Query.cpp create mode 100644 mlir/lib/Query/QueryParser.cpp create mode 100644 mlir/lib/Tools/mlir-query/CMakeLists.txt create mode 100644 mlir/lib/Tools/mlir-query/MlirQueryMain.cpp create mode 100644 mlir/test/mlir-query/simple-test.mlir create mode 100644 mlir/tools/mlir-query/CMakeLists.txt create mode 100644 mlir/tools/mlir-query/mlir-query.cpp diff --git a/mlir/include/mlir/Query/Matcher/Diagnostics.h b/mlir/include/mlir/Query/Matcher/Diagnostics.h new file mode 100644 index 0000000000000..35f29721b1f82 --- /dev/null +++ b/mlir/include/mlir/Query/Matcher/Diagnostics.h @@ -0,0 +1,152 @@ +//===--- Diagnostics.h - Helper class for error diagnostics -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Diagnostics class to manage error messages. Implementation shares similarity +// to clang-query Diagnostics. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_DIAGNOSTICS_H +#define MLIR_TOOLS_MLIRQUERY_MATCHER_DIAGNOSTICS_H + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/raw_ostream.h" +#include +#include + +namespace mlir::query::matcher { + +// Represents the line and column numbers in a source query. +struct SourceLocation { + unsigned line{}; + unsigned column{}; +}; + +// Represents a range in a source query, defined by its start and end locations. +struct SourceRange { + SourceLocation start{}; + SourceLocation end{}; +}; + +// Diagnostics class to manage error messages. +class Diagnostics { +public: + // Parser context types. + enum ContextType { CT_MatcherArg, CT_MatcherConstruct }; + + // All errors from the system. + enum ErrorType { + ET_None, + + // Parser Errors + ET_ParserFailedToBuildMatcher, + ET_ParserInvalidToken, + ET_ParserNoCloseParen, + ET_ParserNoCode, + ET_ParserNoComma, + ET_ParserNoOpenParen, + ET_ParserNotAMatcher, + ET_ParserOverloadedType, + ET_ParserStringError, + ET_ParserTrailingCode, + + // Registry Errors + ET_RegistryMatcherNotFound, + ET_RegistryValueNotFound, + ET_RegistryWrongArgCount, + ET_RegistryWrongArgType + }; + + // Helper stream class for constructing error messages. + class ArgStream { + public: + ArgStream(std::vector *out) : out(out) {} + template + ArgStream &operator<<(const T &arg) { + return operator<<(llvm::Twine(arg)); + } + ArgStream &operator<<(const llvm::Twine &arg); + + private: + std::vector *out; + }; + + // Context for constructing a matcher or parsing its argument. + struct Context { + enum ConstructMatcherEnum { ConstructMatcher }; + Context(ConstructMatcherEnum, Diagnostics *error, + llvm::StringRef matcherName, SourceRange matcherRange); + enum MatcherArgEnum { MatcherArg }; + Context(MatcherArgEnum, Diagnostics *error, llvm::StringRef matcherName, + SourceRange matcherRange, int argNumber); + ~Context(); + + private: + Diagnostics *const error; + }; + + // Context for managing overloaded matcher construction. + struct OverloadContext { + // Construct an overload context with the given error. + OverloadContext(Diagnostics *error); + ~OverloadContext(); + // Revert all errors that occurred within this context. + void revertErrors(); + + private: + Diagnostics *const error; + unsigned beginIndex{}; + }; + + // Add an error message with the specified range and error type. + // Returns an ArgStream object to allow constructing the error message using + // the << operator. + ArgStream addError(SourceRange range, ErrorType error); + + // Information stored for one frame of the context. + struct ContextFrame { + ContextType type; + SourceRange range; + std::vector args; + }; + + // Information stored for each error found. + struct ErrorContent { + std::vector contextStack; + struct Message { + SourceRange range; + ErrorType type; + std::vector args; + }; + std::vector messages; + }; + + // Get an array reference to the error contents. + llvm::ArrayRef errors() const { return errorValues; } + + // Print all error messages to the specified output stream. + void print(llvm::raw_ostream &OS) const; + + // Print the full error messages, including the context information, to the + // specified output stream. + void printFull(llvm::raw_ostream &OS) const; + +private: + // Push a new context frame onto the context stack with the specified type and + // range. + ArgStream pushContextFrame(ContextType type, SourceRange range); + + std::vector contextStack; + std::vector errorValues; +}; + +} // namespace mlir::query::matcher + +#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_DIAGNOSTICS_H diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h new file mode 100644 index 0000000000000..14f6507041a68 --- /dev/null +++ b/mlir/include/mlir/Query/Matcher/Marshallers.h @@ -0,0 +1,195 @@ +//===--- Marshallers.h - Generic matcher function marshallers ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains function templates and classes to wrap matcher construct +// functions. It provides a collection of template function and classes that +// present a generic marshalling layer on top of matcher construct functions. +// The registry uses these to export all marshaller constructors with a uniform +// interface. This mechanism takes inspiration from clang-query. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H +#define MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H + +#include "Diagnostics.h" +#include "VariantValue.h" + +namespace mlir::query::matcher::internal { + +// Helper template class for jumping from argument type to the correct is/get +// functions in VariantValue. This is used for verifying and extracting the +// matcher arguments. +template +struct ArgTypeTraits; +template +struct ArgTypeTraits : public ArgTypeTraits {}; + +template <> +struct ArgTypeTraits { + + static bool hasCorrectType(const VariantValue &value) { + return value.isString(); + } + + static const StringRef &get(const VariantValue &value) { + return value.getString(); + } + + static ArgKind getKind() { return ArgKind(ArgKind::AK_String); } + + static std::optional getBestGuess(const VariantValue &) { + return std::nullopt; + } +}; + +template <> +struct ArgTypeTraits { + + static bool hasCorrectType(const VariantValue &value) { + return value.isMatcher(); + } + + static DynMatcher get(const VariantValue &value) { + return *value.getMatcher().getDynMatcher(); + } + + static ArgKind getKind() { return ArgKind(ArgKind::AK_Matcher); } + + static std::optional getBestGuess(const VariantValue &) { + return std::nullopt; + } +}; + +// Interface for generic matcher descriptor. +// Offers a create() method that constructs the matcher from the provided +// arguments. +class MatcherDescriptor { +public: + virtual ~MatcherDescriptor() = default; + virtual VariantMatcher create(SourceRange nameRange, + const ArrayRef args, + Diagnostics *error) const = 0; + + // Returns the number of arguments accepted by the matcher. + virtual unsigned getNumArgs() const = 0; + + // Append the set of argument types accepted for argument 'ArgNo' to + // 'ArgKinds'. + virtual void getArgKinds(unsigned argNo, + std::vector &argKinds) const = 0; +}; + +class FixedArgCountMatcherDescriptor : public MatcherDescriptor { +public: + using MarshallerType = VariantMatcher (*)(void (*func)(), + StringRef matcherName, + SourceRange nameRange, + ArrayRef args, + Diagnostics *error); + + // Marshaller Function to unpack the arguments and call Func. Func is the + // Matcher construct function. This is the function that the matcher + // expressions would use to create the matcher. + FixedArgCountMatcherDescriptor(MarshallerType marshaller, void (*func)(), + StringRef matcherName, + ArrayRef argKinds) + : marshaller(marshaller), func(func), matcherName(matcherName), + argKinds(argKinds.begin(), argKinds.end()) {} + + VariantMatcher create(SourceRange nameRange, ArrayRef args, + Diagnostics *error) const override { + return marshaller(func, matcherName, nameRange, args, error); + } + + unsigned getNumArgs() const override { return argKinds.size(); } + + void getArgKinds(unsigned argNo, std::vector &kinds) const override { + kinds.push_back(argKinds[argNo]); + } + +private: + const MarshallerType marshaller; + void (*const func)(); + const StringRef matcherName; + const std::vector argKinds; +}; + +// Helper function to check if argument count matches expected count +inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount, + ArrayRef args, Diagnostics *error) { + if (args.size() != expectedArgCount) { + error->addError(nameRange, error->ET_RegistryWrongArgCount) + << expectedArgCount << args.size(); + return false; + } + return true; +} + +// Helper function for checking argument type +template +inline bool checkArgTypeAtIndex(StringRef matcherName, + ArrayRef args, + Diagnostics *error) { + if (!ArgTypeTraits::hasCorrectType(args[Index].value)) { + error->addError(args[Index].range, error->ET_RegistryWrongArgType) + << matcherName << Index + 1; + return false; + } + return true; +} + +// Marshaller function for fixed number of arguments +template +static VariantMatcher +matcherMarshallFixedImpl(void (*func)(), StringRef matcherName, + SourceRange nameRange, ArrayRef args, + Diagnostics *error, std::index_sequence) { + using FuncType = ReturnType (*)(ArgTypes...); + + // Check if the argument count matches the expected count + if (!checkArgCount(nameRange, sizeof...(ArgTypes), args, error)) { + return VariantMatcher(); + } + + // Check if each argument at the corresponding index has the correct type + if ((... && checkArgTypeAtIndex(matcherName, args, error))) { + ReturnType fnPointer = reinterpret_cast(func)( + ArgTypeTraits::get(args[Is].value)...); + return VariantMatcher::SingleMatcher( + *DynMatcher::constructDynMatcherFromMatcherFn(fnPointer)); + } else { + return VariantMatcher(); + } +} + +template +static VariantMatcher +matcherMarshallFixed(void (*func)(), StringRef matcherName, + SourceRange nameRange, ArrayRef args, + Diagnostics *error) { + return matcherMarshallFixedImpl( + func, matcherName, nameRange, args, error, + std::index_sequence_for{}); +} + +// Fixed number of arguments overload +template +std::unique_ptr +makeMatcherAutoMarshall(ReturnType (*func)(ArgTypes...), + StringRef matcherName) { + // Create a vector of argument kinds + std::vector argKinds = {ArgTypeTraits::getKind()...}; + return std::make_unique( + matcherMarshallFixed, + reinterpret_cast(func), matcherName, argKinds); +} + +} // namespace mlir::query::matcher::internal + +#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h new file mode 100644 index 0000000000000..5a87e45310920 --- /dev/null +++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h @@ -0,0 +1,42 @@ + +//===- MatchFinder.h - Structural query framework ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the MatchFinder class, which is used to find operations +// that match a given matcher. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H +#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H + +#include "MatchersInternal.h" + +namespace mlir::query::matcher { + +// MatchFinder is used to find all operations that match a given matcher. +class MatchFinder { +public: + // Returns all operations that match the given matcher. + static std::vector getMatches(Operation *root, + DynMatcher matcher) { + std::vector matches; + + // Simple match finding with walk. + root->walk([&](Operation *subOp) { + if (matcher.match(subOp)) + matches.push_back(subOp); + }); + + return matches; + } +}; + +} // namespace mlir::query::matcher + +#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H \ No newline at end of file diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h new file mode 100644 index 0000000000000..67455be592393 --- /dev/null +++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h @@ -0,0 +1,72 @@ +//===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implements the base layer of the matcher framework. +// +// Matchers are methods that return a Matcher which provides a method +// match(Operation *op) +// +// The matcher functions are defined in include/mlir/IR/Matchers.h. +// This file contains the wrapper classes needed to construct matchers for +// mlir-query. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H +#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H + +#include "mlir/IR/Matchers.h" +#include "llvm/ADT/IntrusiveRefCntPtr.h" + +namespace mlir::query::matcher { + +// Generic interface for matchers on an MLIR operation. +class MatcherInterface + : public llvm::ThreadSafeRefCountedBase { +public: + virtual ~MatcherInterface() = default; + + virtual bool match(Operation *op) = 0; +}; + +// MatcherFnImpl takes a matcher function object and implements +// MatcherInterface. +template +class MatcherFnImpl : public MatcherInterface { +public: + MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {} + bool match(Operation *op) override { return matcherFn.match(op); } + +private: + MatcherFn matcherFn; +}; + +// Matcher wraps a MatcherInterface implementation and provides a match() +// method that redirects calls to the underlying implementation. +class DynMatcher { +public: + // Takes ownership of the provided implementation pointer. + DynMatcher(MatcherInterface *implementation) + : implementation(implementation) {} + + template + static std::unique_ptr + constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) { + auto impl = std::make_unique>(matcherFn); + return std::make_unique(impl.release()); + } + + bool match(Operation *op) const { return implementation->match(op); } + +private: + llvm::IntrusiveRefCntPtr implementation; +}; + +} // namespace mlir::query::matcher + +#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H diff --git a/mlir/include/mlir/Query/Matcher/Parser.h b/mlir/include/mlir/Query/Matcher/Parser.h new file mode 100644 index 0000000000000..232ab20d52189 --- /dev/null +++ b/mlir/include/mlir/Query/Matcher/Parser.h @@ -0,0 +1,174 @@ +//===--- Parser.h - Matcher expression parser -------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Simple matcher expression parser. +// +// This file contains the Parser class, which is responsible for parsing +// expressions in a specific format: matcherName(Arg0, Arg1, ..., ArgN). The +// parser can also interpret simple types, like strings. +// +// The actual processing of the matchers is handled by a Sema object that is +// provided to the parser. +// +// The grammar for the supported expressions is as follows: +// := | +// := "quoted string" +// := () +// := [a-zA-Z]+ +// := | , +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_PARSER_H +#define MLIR_TOOLS_MLIRQUERY_MATCHER_PARSER_H + +#include "Diagnostics.h" +#include "Registry.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" + +namespace mlir::query::matcher { + +// Matcher expression parser. +class Parser { +public: + // Interface to connect the parser with the registry and more. The parser uses + // the Sema instance passed into parseMatcherExpression() to handle all + // matcher tokens. + class Sema { + public: + virtual ~Sema(); + + // Process a matcher expression. The caller takes ownership of the Matcher + // object returned. + virtual VariantMatcher actOnMatcherExpression(MatcherCtor ctor, + SourceRange nameRange, + ArrayRef args, + Diagnostics *error) = 0; + + // Look up a matcher by name in the matcher name found by the parser. + virtual std::optional + lookupMatcherCtor(llvm::StringRef matcherName) = 0; + + // Compute the list of completion types for Context. + virtual std::vector getAcceptedCompletionTypes( + llvm::ArrayRef> Context); + + // Compute the list of completions that match any of acceptedTypes. + virtual std::vector + getMatcherCompletions(llvm::ArrayRef acceptedTypes); + }; + + // An implementation of the Sema interface that uses the matcher registry to + // process tokens. + class RegistrySema : public Parser::Sema { + public: + ~RegistrySema() override; + + std::optional + lookupMatcherCtor(llvm::StringRef matcherName) override; + + VariantMatcher actOnMatcherExpression(MatcherCtor ctor, + SourceRange nameRange, + ArrayRef args, + Diagnostics *error) override; + + std::vector getAcceptedCompletionTypes( + llvm::ArrayRef> context) override; + + std::vector + getMatcherCompletions(llvm::ArrayRef acceptedTypes) override; + }; + + using NamedValueMap = llvm::StringMap; + + // Methods to parse a matcher expression and return a DynMatcher object, + // transferring ownership to the caller. + static std::optional + parseMatcherExpression(llvm::StringRef &matcherCode, Sema *sema, + const NamedValueMap *namedValues, Diagnostics *error); + static std::optional + parseMatcherExpression(llvm::StringRef &matcherCode, Sema *sema, + Diagnostics *error) { + return parseMatcherExpression(matcherCode, sema, nullptr, error); + } + static std::optional + parseMatcherExpression(llvm::StringRef &matcherCode, Diagnostics *error) { + return parseMatcherExpression(matcherCode, nullptr, error); + } + + // Methods to parse any expression supported by this parser. + static bool parseExpression(llvm::StringRef &code, Sema *sema, + const NamedValueMap *namedValues, + VariantValue *value, Diagnostics *error); + + static bool parseExpression(llvm::StringRef &code, Sema *sema, + VariantValue *value, Diagnostics *error) { + return parseExpression(code, sema, nullptr, value, error); + } + static bool parseExpression(llvm::StringRef &code, VariantValue *value, + Diagnostics *error) { + return parseExpression(code, nullptr, value, error); + } + + // Methods to complete an expression at a given offset. + static std::vector + completeExpression(llvm::StringRef &code, unsigned completionOffset, + Sema *sema, const NamedValueMap *namedValues); + static std::vector + completeExpression(llvm::StringRef &code, unsigned completionOffset, + Sema *sema) { + return completeExpression(code, completionOffset, sema, nullptr); + } + static std::vector + completeExpression(llvm::StringRef &code, unsigned completionOffset) { + return completeExpression(code, completionOffset, nullptr); + } + +private: + class CodeTokenizer; + struct ScopedContextEntry; + struct TokenInfo; + + Parser(CodeTokenizer *tokenizer, Sema *sema, const NamedValueMap *namedValues, + Diagnostics *error); + + bool parseExpressionImpl(VariantValue *value); + + bool parseMatcherArgs(std::vector &args, MatcherCtor ctor, + const TokenInfo &nameToken, TokenInfo &endToken); + + bool parseMatcherExpressionImpl(const TokenInfo &nameToken, + const TokenInfo &openToken, + std::optional ctor, + VariantValue *value); + + bool parseIdentifierPrefixImpl(VariantValue *value); + + void addCompletion(const TokenInfo &compToken, + const MatcherCompletion &completion); + void addExpressionCompletions(); + + std::vector + getNamedValueCompletions(ArrayRef acceptedTypes); + + CodeTokenizer *const tokenizer; + Sema *const sema; + const NamedValueMap *const namedValues; + Diagnostics *const error; + + using ContextStackTy = std::vector>; + + ContextStackTy contextStack; + std::vector completions; +}; + +} // namespace mlir::query::matcher + +#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_PARSER_H diff --git a/mlir/include/mlir/Query/Matcher/Registry.h b/mlir/include/mlir/Query/Matcher/Registry.h new file mode 100644 index 0000000000000..4bfa1a0c1ab83 --- /dev/null +++ b/mlir/include/mlir/Query/Matcher/Registry.h @@ -0,0 +1,66 @@ +//===--- Registry.h - Matcher registry --------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Registry of all known matchers. +// +// The registry provides a generic interface to construct any matcher by name. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_REGISTRY_H +#define MLIR_TOOLS_MLIRQUERY_MATCHER_REGISTRY_H + +#include "Diagnostics.h" +#include "Marshallers.h" +#include "VariantValue.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include + +namespace mlir::query::matcher { + +using MatcherCtor = const internal::MatcherDescriptor *; + +struct MatcherCompletion { + MatcherCompletion() = default; + MatcherCompletion(llvm::StringRef typedText, llvm::StringRef matcherDecl) + : typedText(typedText.str()), matcherDecl(matcherDecl.str()) {} + + bool operator==(const MatcherCompletion &other) const { + return typedText == other.typedText && matcherDecl == other.matcherDecl; + } + + // The text to type to select this matcher. + std::string typedText; + + // The "declaration" of the matcher, with type information. + std::string matcherDecl; +}; + +class Registry { +public: + Registry() = delete; + + static std::optional + lookupMatcherCtor(llvm::StringRef matcherName); + + static std::vector getAcceptedCompletionTypes( + llvm::ArrayRef> context); + + static std::vector + getMatcherCompletions(ArrayRef acceptedTypes); + + static VariantMatcher constructMatcher(MatcherCtor ctor, + SourceRange nameRange, + ArrayRef args, + Diagnostics *error); +}; + +} // namespace mlir::query::matcher + +#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_REGISTRY_H diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h new file mode 100644 index 0000000000000..22182c17319f9 --- /dev/null +++ b/mlir/include/mlir/Query/Matcher/VariantValue.h @@ -0,0 +1,141 @@ +//===--- VariantValue.h -----------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Supports all the types required for dynamic Matcher construction. +// Used by the registry to construct matchers in a generic way. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_VARIANTVALUE_H +#define MLIR_TOOLS_MLIRQUERY_MATCHER_VARIANTVALUE_H + +#include "Diagnostics.h" +#include "MatchersInternal.h" + +namespace mlir::query::matcher { + +// Kind identifier that supports all types that VariantValue can contain. +class ArgKind { +public: + enum Kind { AK_Matcher, AK_String }; + ArgKind(Kind k) : k(k) {} + + Kind getArgKind() const { return k; } + + bool operator<(const ArgKind &other) const { return k < other.k; } + + // String representation of the type. + std::string asString() const; + +private: + Kind k; +}; + +// A variant matcher object to abstract simple and complex matchers into a +// single object type. +class VariantMatcher { + class MatcherOps; + + // Payload interface to be specialized by each matcher type. It follows a + // similar interface as VariantMatcher itself. + class Payload { + public: + virtual ~Payload(); + virtual std::optional getDynMatcher() const = 0; + virtual std::string getTypeAsString() const = 0; + }; + +public: + /// A null matcher. + VariantMatcher(); + + // Clones the provided matcher. + static VariantMatcher SingleMatcher(DynMatcher matcher); + + // Makes the matcher the "null" matcher. + void reset(); + + // Checks if the matcher is null. + bool isNull() const { return !value; } + + /// Returns the matcher + std::optional getDynMatcher() const; + + // String representation of the type of the value. + std::string getTypeAsString() const; + +private: + explicit VariantMatcher(std::shared_ptr value) + : value(std::move(value)) {} + + class SinglePayload; + + std::shared_ptr value; +}; + +// Variant value class with a tagged union with value type semantics. It is used +// by the registry as the return value and argument type for the matcher factory +// methods. It can be constructed from any of the supported types: +// - StringRef +// - VariantMatcher +class VariantValue { +public: + VariantValue() : type(VT_Nothing) {} + + VariantValue(const VariantValue &other); + ~VariantValue(); + VariantValue &operator=(const VariantValue &other); + + // Specific constructors for each supported type. + VariantValue(const StringRef string); + VariantValue(const VariantMatcher &matcher); + + // String value functions. + bool isString() const; + const StringRef &getString() const; + void setString(const StringRef &string); + + // Matcher value functions. + bool isMatcher() const; + const VariantMatcher &getMatcher() const; + void setMatcher(const VariantMatcher &matcher); + + // String representation of the type of the value. + std::string getTypeAsString() const; + +private: + void reset(); + + // All supported value types. + enum ValueType { + VT_Nothing, + VT_String, + VT_Matcher, + }; + + // All supported value types. + union AllValues { + StringRef *String; + VariantMatcher *Matcher; + }; + + ValueType type; + AllValues value; +}; + +// A VariantValue instance annotated with its parser context. +struct ParserValue { + ParserValue() {} + llvm::StringRef text; + SourceRange range; + VariantValue value; +}; + +} // namespace mlir::query::matcher + +#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_VARIANTVALUE_H diff --git a/mlir/include/mlir/Query/Query.h b/mlir/include/mlir/Query/Query.h new file mode 100644 index 0000000000000..77cda9853b69a --- /dev/null +++ b/mlir/include/mlir/Query/Query.h @@ -0,0 +1,79 @@ +//===--- Query.h - mlir-query -----------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRQUERY_QUERY_H +#define MLIR_TOOLS_MLIRQUERY_QUERY_H + +#include "Matcher/VariantValue.h" +#include "llvm/ADT/IntrusiveRefCntPtr.h" +#include "llvm/ADT/Twine.h" +#include + +namespace mlir::query { + +enum QueryKind { QK_Invalid, QK_NoOp, QK_Help, QK_Match }; + +class QuerySession; + +struct Query : llvm::RefCountedBase { + Query(QueryKind kind) : kind(kind) {} + virtual ~Query(); + + // Perform the query on QS and print output to OS. + // Return false if an error occurs, otherwise return true. + virtual bool run(llvm::raw_ostream &OS, QuerySession &QS) const = 0; + + llvm::StringRef remainingContent; + const QueryKind kind; +}; + +typedef llvm::IntrusiveRefCntPtr QueryRef; + +// Any query which resulted in a parse error. The error message is in ErrStr. +struct InvalidQuery : Query { + InvalidQuery(const llvm::Twine &errStr) + : Query(QK_Invalid), errStr(errStr.str()) {} + bool run(llvm::raw_ostream &OS, QuerySession &QS) const override; + + std::string errStr; + + static bool classof(const Query *query) { return query->kind == QK_Invalid; } +}; + +// No-op query (i.e. a blank line). +struct NoOpQuery : Query { + NoOpQuery() : Query(QK_NoOp) {} + bool run(llvm::raw_ostream &OS, QuerySession &QS) const override; + + static bool classof(const Query *query) { return query->kind == QK_NoOp; } +}; + +// Query for "help". +struct HelpQuery : Query { + HelpQuery() : Query(QK_Help) {} + bool run(llvm::raw_ostream &OS, QuerySession &QS) const override; + + static bool classof(const Query *query) { return query->kind == QK_Help; } +}; + +// Query for "match MATCHER". +struct MatchQuery : Query { + MatchQuery(StringRef source, const matcher::DynMatcher &matcher) + : Query(QK_Match), matcher(matcher), source(source) {} + bool run(llvm::raw_ostream &OS, QuerySession &QS) const override; + + const matcher::DynMatcher matcher; + + StringRef source; + + static bool classof(const Query *query) { return query->kind == QK_Match; } +}; + +} // namespace mlir::query + +#endif diff --git a/mlir/include/mlir/Query/QueryParser.h b/mlir/include/mlir/Query/QueryParser.h new file mode 100644 index 0000000000000..84c63080cca7c --- /dev/null +++ b/mlir/include/mlir/Query/QueryParser.h @@ -0,0 +1,59 @@ +//===--- QueryParser.h - mlir-query -----------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRQUERY_QUERYPARSER_H +#define MLIR_TOOLS_MLIRQUERY_QUERYPARSER_H + +#include "Matcher/Parser.h" +#include "Query.h" +#include "QuerySession.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/LineEditor/LineEditor.h" + +namespace mlir::query { + +class QuerySession; + +class QueryParser { +public: + // Parse line as a query and return a QueryRef representing the query, which + // may be an InvalidQuery. + static QueryRef parse(llvm::StringRef line, const QuerySession &QS); + + static std::vector + complete(llvm::StringRef line, size_t pos, const QuerySession &QS); + +private: + QueryParser(llvm::StringRef line, const QuerySession &QS) + : line(line), completionPos(nullptr), QS(QS) {} + + llvm::StringRef lexWord(); + + template + struct LexOrCompleteWord; + + QueryRef completeMatcherExpression(); + + QueryRef endQuery(QueryRef queryRef); + + // Parse [Begin, End) and returns a reference to the parsed query object, + // which may be an InvalidQuery if a parse error occurs. + QueryRef doParse(); + + llvm::StringRef line; + + const char *completionPos; + std::vector completions; + + const QuerySession &QS; +}; + +} // namespace mlir::query + +#endif // MLIR_TOOLS_MLIRQUERY_QUERYPARSER_H diff --git a/mlir/include/mlir/Query/QuerySession.h b/mlir/include/mlir/Query/QuerySession.h new file mode 100644 index 0000000000000..afe3e3b26c7a1 --- /dev/null +++ b/mlir/include/mlir/Query/QuerySession.h @@ -0,0 +1,40 @@ +//===--- QuerySession.h - mlir-query ----------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H +#define MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H + +#include "Query.h" +#include "mlir/Tools/ParseUtilities.h" +#include "llvm/ADT/StringMap.h" + +namespace mlir::query { + +// Represents the state for a particular mlir-query session. +class QuerySession { +public: + QuerySession(Operation *rootOp, + const std::shared_ptr &sourceMgr, + unsigned bufferId) + : rootOp(rootOp), sourceMgr(sourceMgr), bufferId(bufferId), + terminate(false) {} + + const std::shared_ptr &getSourceManager() { + return sourceMgr; + } + + Operation *rootOp; + const std::shared_ptr sourceMgr; + unsigned bufferId; + bool terminate; + llvm::StringMap namedValues; +}; + +} // namespace mlir::query + +#endif // MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H diff --git a/mlir/include/mlir/Tools/mlir-query/MlirQueryMain.h b/mlir/include/mlir/Tools/mlir-query/MlirQueryMain.h new file mode 100644 index 0000000000000..1fa5bc2b78605 --- /dev/null +++ b/mlir/include/mlir/Tools/mlir-query/MlirQueryMain.h @@ -0,0 +1,27 @@ +//===- MlirQueryMain.h - MLIR Query main ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Main entry function for mlir-query for when built as standalone +// binary. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRQUERY_MLIRQUERYMAIN_H +#define MLIR_TOOLS_MLIRQUERY_MLIRQUERYMAIN_H + +#include "mlir/Support/LogicalResult.h" + +namespace mlir { + +class MLIRContext; + +LogicalResult mlirQueryMain(int argc, char **argv, MLIRContext &context); + +} // namespace mlir + +#endif // MLIR_TOOLS_MLIRQUERY_MLIRQUERYMAIN_H diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt index c71664a3f0063..d25c84a3975db 100644 --- a/mlir/lib/CMakeLists.txt +++ b/mlir/lib/CMakeLists.txt @@ -11,6 +11,7 @@ add_subdirectory(IR) add_subdirectory(Interfaces) add_subdirectory(Parser) add_subdirectory(Pass) +add_subdirectory(Query) add_subdirectory(Reducer) add_subdirectory(Rewrite) add_subdirectory(Support) diff --git a/mlir/lib/Query/CMakeLists.txt b/mlir/lib/Query/CMakeLists.txt new file mode 100644 index 0000000000000..817583e94c522 --- /dev/null +++ b/mlir/lib/Query/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_library(MLIRQuery + Query.cpp + QueryParser.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Query + + LINK_LIBS PUBLIC + MLIRQueryMatcher + ) + +add_subdirectory(Matcher) diff --git a/mlir/lib/Query/Matcher/CMakeLists.txt b/mlir/lib/Query/Matcher/CMakeLists.txt new file mode 100644 index 0000000000000..f2a9abeadb5f6 --- /dev/null +++ b/mlir/lib/Query/Matcher/CMakeLists.txt @@ -0,0 +1,9 @@ +add_mlir_library(MLIRQueryMatcher + Parser.cpp + Registry.cpp + VariantValue.cpp + Diagnostics.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Query/Matcher + ) diff --git a/mlir/lib/Query/Matcher/Diagnostics.cpp b/mlir/lib/Query/Matcher/Diagnostics.cpp new file mode 100644 index 0000000000000..aa9685ee1e436 --- /dev/null +++ b/mlir/lib/Query/Matcher/Diagnostics.cpp @@ -0,0 +1,201 @@ +//===- MatcherDiagnostic.cpp ------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Query/Matcher/Diagnostics.h" + +namespace mlir::query::matcher { + +Diagnostics::ArgStream Diagnostics::pushContextFrame(ContextType type, + SourceRange range) { + contextStack.emplace_back(); + ContextFrame &data = contextStack.back(); + data.type = type; + data.range = range; + return ArgStream(&data.args); +} + +Diagnostics::Context::Context(ConstructMatcherEnum, Diagnostics *error, + llvm::StringRef matcherName, + SourceRange matcherRange) + : error(error) { + error->pushContextFrame(CT_MatcherConstruct, matcherRange) << matcherName; +} + +Diagnostics::Context::Context(MatcherArgEnum, Diagnostics *error, + llvm::StringRef matcherName, + SourceRange matcherRange, int argnumber) + : error(error) { + error->pushContextFrame(CT_MatcherArg, matcherRange) + << argnumber << matcherName; +} + +Diagnostics::Context::~Context() { error->contextStack.pop_back(); } + +Diagnostics::OverloadContext::OverloadContext(Diagnostics *error) + : error(error), beginIndex(error->errorValues.size()) {} + +Diagnostics::OverloadContext::~OverloadContext() { + // Merge all errors that happened while in this context. + if (beginIndex < error->errorValues.size()) { + Diagnostics::ErrorContent &dest = error->errorValues[beginIndex]; + for (size_t i = beginIndex + 1, e = error->errorValues.size(); i < e; ++i) { + dest.messages.push_back(error->errorValues[i].messages[0]); + } + error->errorValues.resize(beginIndex + 1); + } +} + +void Diagnostics::OverloadContext::revertErrors() { + // Revert the errors. + error->errorValues.resize(beginIndex); +} + +Diagnostics::ArgStream & +Diagnostics::ArgStream::operator<<(const llvm::Twine &arg) { + out->push_back(arg.str()); + return *this; +} + +Diagnostics::ArgStream Diagnostics::addError(SourceRange range, + ErrorType error) { + errorValues.emplace_back(); + ErrorContent &last = errorValues.back(); + last.contextStack = contextStack; + last.messages.emplace_back(); + last.messages.back().range = range; + last.messages.back().type = error; + return ArgStream(&last.messages.back().args); +} + +static llvm::StringRef +contextTypeToFormatString(Diagnostics::ContextType type) { + switch (type) { + case Diagnostics::CT_MatcherConstruct: + return "Error building matcher $0."; + case Diagnostics::CT_MatcherArg: + return "Error parsing argument $0 for matcher $1."; + } + llvm_unreachable("Unknown ContextType value."); +} + +static llvm::StringRef errorTypeToFormatString(Diagnostics::ErrorType type) { + switch (type) { + case Diagnostics::ET_RegistryMatcherNotFound: + return "Matcher not found: $0"; + case Diagnostics::ET_RegistryWrongArgCount: + return "Incorrect argument count. (Expected = $0) != (Actual = $1)"; + case Diagnostics::ET_RegistryWrongArgType: + return "Incorrect type for arg $0. (Expected = $1) != (Actual = $2)"; + case Diagnostics::ET_RegistryValueNotFound: + return "Value not found: $0"; + + case Diagnostics::ET_ParserStringError: + return "Error parsing string token: <$0>"; + case Diagnostics::ET_ParserNoOpenParen: + return "Error parsing matcher. Found token <$0> while looking for '('."; + case Diagnostics::ET_ParserNoCloseParen: + return "Error parsing matcher. Found end-of-code while looking for ')'."; + case Diagnostics::ET_ParserNoComma: + return "Error parsing matcher. Found token <$0> while looking for ','."; + case Diagnostics::ET_ParserNoCode: + return "End of code found while looking for token."; + case Diagnostics::ET_ParserNotAMatcher: + return "Input value is not a matcher expression."; + case Diagnostics::ET_ParserInvalidToken: + return "Invalid token <$0> found when looking for a value."; + case Diagnostics::ET_ParserTrailingCode: + return "Unexpected end of code."; + case Diagnostics::ET_ParserOverloadedType: + return "Input value has unresolved overloaded type: $0"; + case Diagnostics::ET_ParserFailedToBuildMatcher: + return "Failed to build matcher: $0."; + + case Diagnostics::ET_None: + return ""; + } + llvm_unreachable("Unknown ErrorType value."); +} + +static void formatErrorString(llvm::StringRef formatString, + llvm::ArrayRef args, + llvm::raw_ostream &OS) { + while (!formatString.empty()) { + std::pair pieces = + formatString.split("$"); + OS << pieces.first.str(); + if (pieces.second.empty()) + break; + + const char next = pieces.second.front(); + formatString = pieces.second.drop_front(); + if (next >= '0' && next <= '9') { + const unsigned index = next - '0'; + if (index < args.size()) { + OS << args[index]; + } else { + OS << ""; + } + } + } +} + +static void maybeAddLineAndColumn(SourceRange range, llvm::raw_ostream &OS) { + if (range.start.line > 0 && range.start.column > 0) { + OS << range.start.line << ":" << range.start.column << ": "; + } +} + +static void printContextFrameToStream(const Diagnostics::ContextFrame &frame, + llvm::raw_ostream &OS) { + maybeAddLineAndColumn(frame.range, OS); + formatErrorString(contextTypeToFormatString(frame.type), frame.args, OS); +} + +static void +printMessageToStream(const Diagnostics::ErrorContent::Message &message, + const llvm::Twine Prefix, llvm::raw_ostream &OS) { + maybeAddLineAndColumn(message.range, OS); + OS << Prefix; + formatErrorString(errorTypeToFormatString(message.type), message.args, OS); +} + +static void printErrorContentToStream(const Diagnostics::ErrorContent &content, + llvm::raw_ostream &OS) { + if (content.messages.size() == 1) { + printMessageToStream(content.messages[0], "", OS); + } else { + for (size_t i = 0, e = content.messages.size(); i != e; ++i) { + if (i != 0) + OS << "\n"; + printMessageToStream(content.messages[i], + "Candidate " + llvm::Twine(i + 1) + ": ", OS); + } + } +} + +void Diagnostics::print(llvm::raw_ostream &OS) const { + for (const ErrorContent &error : errorValues) { + if (&error != &errorValues.front()) + OS << "\n"; + printErrorContentToStream(error, OS); + } +} + +void Diagnostics::printFull(llvm::raw_ostream &OS) const { + for (const ErrorContent &error : errorValues) { + if (&error != &errorValues.front()) + OS << "\n"; + for (const ContextFrame &frame : error.contextStack) { + printContextFrameToStream(frame, OS); + OS << "\n"; + } + printErrorContentToStream(error, OS); + } +} + +} // namespace mlir::query::matcher diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp new file mode 100644 index 0000000000000..bd69b746f76db --- /dev/null +++ b/mlir/lib/Query/Matcher/Parser.cpp @@ -0,0 +1,553 @@ +//===- MatcherParser.cpp - Matcher expression parser ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Recursive parser implementation for the matcher expression grammar. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Query/Matcher/Parser.h" + +#include "llvm/Support/ManagedStatic.h" +#include + +namespace mlir::query::matcher { + +// Simple structure to hold information for one token from the parser. +struct Parser::TokenInfo { + // Different possible tokens. + enum TokenKind { + TK_Eof, + TK_NewLine, + TK_OpenParen, + TK_CloseParen, + TK_Comma, + TK_Period, + TK_Literal, + TK_Ident, + TK_InvalidChar, + TK_CodeCompletion, + TK_Error + }; + + TokenInfo() = default; + + // Method to set the kind and text of the token + void set(TokenKind newKind, llvm::StringRef newText) { + kind = newKind; + text = newText; + } + + llvm::StringRef text; + TokenKind kind = TK_Eof; + SourceRange range; + VariantValue value; +}; + +class Parser::CodeTokenizer { +public: + // Constructor with matcherCode and error + explicit CodeTokenizer(llvm::StringRef matcherCode, Diagnostics *error) + : code(matcherCode), startOfLine(matcherCode), line(1), error(error) { + nextToken = getNextToken(); + } + + // Constructor with matcherCode, error, and codeCompletionOffset + CodeTokenizer(llvm::StringRef matcherCode, Diagnostics *error, + unsigned codeCompletionOffset) + : code(matcherCode), startOfLine(matcherCode), error(error), + codeCompletionLocation(matcherCode.data() + codeCompletionOffset) { + nextToken = getNextToken(); + } + + // Peek at next token without consuming it + const TokenInfo &peekNextToken() const { return nextToken; } + + // Consume and return the next token + TokenInfo consumeNextToken() { + TokenInfo thisToken = nextToken; + nextToken = getNextToken(); + return thisToken; + } + + // Skip any newline tokens + TokenInfo skipNewlines() { + while (nextToken.kind == TokenInfo::TK_NewLine) + nextToken = getNextToken(); + return nextToken; + } + + // Consume and return next token, ignoring newlines + TokenInfo consumeNextTokenIgnoreNewlines() { + skipNewlines(); + return nextToken.kind == TokenInfo::TK_Eof ? nextToken : consumeNextToken(); + } + + // Return kind of next token + TokenInfo::TokenKind nextTokenKind() const { return nextToken.kind; } + +private: + // Helper function to get the first character as a new StringRef and drop it + // from the original string + llvm::StringRef firstCharacterAndDrop(llvm::StringRef &str) { + assert(!str.empty()); + llvm::StringRef firstChar = str.substr(0, 1); + str = str.drop_front(); + return firstChar; + } + + // Get next token, consuming whitespaces and handling different token types + TokenInfo getNextToken() { + consumeWhitespace(); + TokenInfo result; + result.range.start = currentLocation(); + + // Code completion case + if (codeCompletionLocation && codeCompletionLocation <= code.data()) { + result.set(TokenInfo::TK_CodeCompletion, + llvm::StringRef(codeCompletionLocation, 0)); + codeCompletionLocation = nullptr; + return result; + } + + // End of file case + if (code.empty()) { + result.set(TokenInfo::TK_Eof, ""); + return result; + } + + // Switch to handle specific characters + switch (code[0]) { + case '#': + code = code.drop_until([](char c) { return c == '\n'; }); + return getNextToken(); + case ',': + result.set(TokenInfo::TK_Comma, firstCharacterAndDrop(code)); + break; + case '.': + result.set(TokenInfo::TK_Period, firstCharacterAndDrop(code)); + break; + case '\n': + ++line; + startOfLine = code.drop_front(); + result.set(TokenInfo::TK_NewLine, firstCharacterAndDrop(code)); + break; + case '(': + result.set(TokenInfo::TK_OpenParen, firstCharacterAndDrop(code)); + break; + case ')': + result.set(TokenInfo::TK_CloseParen, firstCharacterAndDrop(code)); + break; + case '"': + case '\'': + consumeStringLiteral(&result); + break; + default: + parseIdentifierOrInvalid(&result); + break; + } + + result.range.end = currentLocation(); + return result; + } + + // Consume a string literal, handle escape sequences and missing closing + // quote. + void consumeStringLiteral(TokenInfo *result) { + bool inEscape = false; + const char marker = code[0]; + for (size_t length = 1; length < code.size(); ++length) { + if (inEscape) { + inEscape = false; + continue; + } + if (code[length] == '\\') { + inEscape = true; + continue; + } + if (code[length] == marker) { + result->kind = TokenInfo::TK_Literal; + result->text = code.substr(0, length + 1); + result->value = code.substr(1, length - 1); + code = code.drop_front(length + 1); + return; + } + } + llvm::StringRef errorText = code; + code = code.drop_front(code.size()); + SourceRange range; + range.start = result->range.start; + range.end = currentLocation(); + error->addError(range, error->ET_ParserStringError) << errorText; + result->kind = TokenInfo::TK_Error; + } + + void parseIdentifierOrInvalid(TokenInfo *result) { + if (isalnum(code[0])) { + // Parse an identifier + size_t tokenLength = 1; + + while (true) { + // A code completion location in/immediately after an identifier will + // cause the portion of the identifier before the code completion + // location to become a code completion token. + if (codeCompletionLocation == code.data() + tokenLength) { + codeCompletionLocation = nullptr; + result->kind = TokenInfo::TK_CodeCompletion; + result->text = code.substr(0, tokenLength); + code = code.drop_front(tokenLength); + return; + } + if (tokenLength == code.size() || !(isalnum(code[tokenLength]))) + break; + ++tokenLength; + } + result->kind = TokenInfo::TK_Ident; + result->text = code.substr(0, tokenLength); + code = code.drop_front(tokenLength); + } else { + result->kind = TokenInfo::TK_InvalidChar; + result->text = code.substr(0, 1); + code = code.drop_front(1); + } + } + + // Consume all leading whitespace from code, except newlines + void consumeWhitespace() { + code = code.drop_while( + [](char c) { return llvm::StringRef(" \t\v\f\r").contains(c); }); + } + + // Returns the current location in the source code + SourceLocation currentLocation() { + SourceLocation location; + location.line = line; + location.column = code.data() - startOfLine.data() + 1; + return location; + } + + llvm::StringRef code; + llvm::StringRef startOfLine; + unsigned line = 1; + Diagnostics *error; + TokenInfo nextToken; + const char *codeCompletionLocation = nullptr; +}; + +Parser::Sema::~Sema() = default; + +std::vector Parser::Sema::getAcceptedCompletionTypes( + llvm::ArrayRef> context) { + return {}; +} + +std::vector +Parser::Sema::getMatcherCompletions(llvm::ArrayRef acceptedTypes) { + return {}; +} + +// Entry for the scope of a parser +struct Parser::ScopedContextEntry { + Parser *parser; + + ScopedContextEntry(Parser *parser, MatcherCtor c) : parser(parser) { + parser->contextStack.push_back({c, 0u}); + } + + ~ScopedContextEntry() { parser->contextStack.pop_back(); } + + void nextArg() { ++parser->contextStack.back().second; } +}; + +// Parse and validate expressions starting with an identifier. +// This function can parse named values and matchers. In case of failure, it +// will try to determine the user's intent to give an appropriate error message. +bool Parser::parseIdentifierPrefixImpl(VariantValue *value) { + const TokenInfo nameToken = tokenizer->consumeNextToken(); + + if (tokenizer->nextTokenKind() != TokenInfo::TK_OpenParen) { + // Parse as a named value. + auto namedValue = + namedValues ? namedValues->lookup(nameToken.text) : VariantValue(); + + if (!namedValue.isMatcher()) { + error->addError(tokenizer->peekNextToken().range, + error->ET_ParserNotAMatcher); + return false; + } + + if (tokenizer->nextTokenKind() == TokenInfo::TK_NewLine) { + error->addError(tokenizer->peekNextToken().range, + error->ET_ParserNoOpenParen) + << "NewLine"; + return false; + } + + // If the syntax is correct and the name is not a matcher either, report + // an unknown named value. + if ((tokenizer->nextTokenKind() == TokenInfo::TK_Comma || + tokenizer->nextTokenKind() == TokenInfo::TK_CloseParen || + tokenizer->nextTokenKind() == TokenInfo::TK_NewLine || + tokenizer->nextTokenKind() == TokenInfo::TK_Eof) && + !sema->lookupMatcherCtor(nameToken.text)) { + error->addError(nameToken.range, error->ET_RegistryValueNotFound) + << nameToken.text; + return false; + } + // Otherwise, fallback to the matcher parser. + } + + tokenizer->skipNewlines(); + + assert(nameToken.kind == TokenInfo::TK_Ident); + TokenInfo openToken = tokenizer->consumeNextToken(); + if (openToken.kind != TokenInfo::TK_OpenParen) { + error->addError(openToken.range, error->ET_ParserNoOpenParen) + << openToken.text; + return false; + } + + std::optional ctor = sema->lookupMatcherCtor(nameToken.text); + + // Parse as a matcher expression. + return parseMatcherExpressionImpl(nameToken, openToken, ctor, value); +} + +// Parse the arguments of a matcher +bool Parser::parseMatcherArgs(std::vector &args, MatcherCtor ctor, + const TokenInfo &nameToken, TokenInfo &endToken) { + ScopedContextEntry sce(this, ctor); + + while (tokenizer->nextTokenKind() != TokenInfo::TK_Eof) { + if (tokenizer->nextTokenKind() == TokenInfo::TK_CloseParen) { + // end of args. + endToken = tokenizer->consumeNextToken(); + break; + } + + if (!args.empty()) { + // We must find a , token to continue. + TokenInfo commaToken = tokenizer->consumeNextToken(); + if (commaToken.kind != TokenInfo::TK_Comma) { + error->addError(commaToken.range, error->ET_ParserNoComma) + << commaToken.text; + return false; + } + } + + Diagnostics::Context ctx(Diagnostics::Context::MatcherArg, error, + nameToken.text, nameToken.range, args.size() + 1); + ParserValue argValue; + tokenizer->skipNewlines(); + + argValue.text = tokenizer->peekNextToken().text; + argValue.range = tokenizer->peekNextToken().range; + if (!parseExpressionImpl(&argValue.value)) { + return false; + } + + tokenizer->skipNewlines(); + args.push_back(argValue); + sce.nextArg(); + } + + return true; +} + +/// Parse and validate a matcher expression. +bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken, + const TokenInfo &openToken, + std::optional ctor, + VariantValue *value) { + if (!ctor) { + error->addError(nameToken.range, error->ET_RegistryMatcherNotFound) + << nameToken.text; + // Do not return here. We need to continue to give completion suggestions. + } + + std::vector args; + TokenInfo endToken; + + tokenizer->skipNewlines(); + + if (!parseMatcherArgs(args, ctor.value_or(nullptr), nameToken, endToken)) { + return false; + } + + if (!ctor) + return false; + // Merge the start and end infos. + Diagnostics::Context ctx(Diagnostics::Context::ConstructMatcher, error, + nameToken.text, nameToken.range); + SourceRange matcherRange = nameToken.range; + matcherRange.end = endToken.range.end; + VariantMatcher result = + sema->actOnMatcherExpression(*ctor, matcherRange, args, error); + if (result.isNull()) + return false; + *value = result; + return true; +} + +// If the prefix of this completion matches the completion token, add it to +// completions minus the prefix. +void Parser::addCompletion(const TokenInfo &compToken, + const MatcherCompletion &completion) { + if (llvm::StringRef(completion.typedText).startswith(compToken.text)) { + completions.emplace_back(completion.typedText.substr(compToken.text.size()), + completion.matcherDecl); + } +} + +std::vector +Parser::getNamedValueCompletions(ArrayRef acceptedTypes) { + if (!namedValues) + return {}; + + std::vector result; + for (const auto &entry : *namedValues) { + std::string decl = + (entry.getValue().getTypeAsString() + " " + entry.getKey()).str(); + result.emplace_back(entry.getKey(), decl); + } + return result; +} + +void Parser::addExpressionCompletions() { + const TokenInfo compToken = tokenizer->consumeNextTokenIgnoreNewlines(); + assert(compToken.kind == TokenInfo::TK_CodeCompletion); + + // We cannot complete code if there is an invalid element on the context + // stack. + for (const auto &entry : contextStack) { + if (!entry.first) + return; + } + + auto acceptedTypes = sema->getAcceptedCompletionTypes(contextStack); + for (const auto &completion : sema->getMatcherCompletions(acceptedTypes)) { + addCompletion(compToken, completion); + } + + for (const auto &completion : getNamedValueCompletions(acceptedTypes)) { + addCompletion(compToken, completion); + } +} + +// Parse an +bool Parser::parseExpressionImpl(VariantValue *value) { + switch (tokenizer->nextTokenKind()) { + case TokenInfo::TK_Literal: + *value = tokenizer->consumeNextToken().value; + return true; + case TokenInfo::TK_Ident: + return parseIdentifierPrefixImpl(value); + case TokenInfo::TK_CodeCompletion: + addExpressionCompletions(); + return false; + case TokenInfo::TK_Eof: + error->addError(tokenizer->consumeNextToken().range, + error->ET_ParserNoCode); + return false; + + case TokenInfo::TK_Error: + // This error was already reported by the tokenizer. + return false; + case TokenInfo::TK_NewLine: + case TokenInfo::TK_OpenParen: + case TokenInfo::TK_CloseParen: + case TokenInfo::TK_Comma: + case TokenInfo::TK_Period: + case TokenInfo::TK_InvalidChar: + const TokenInfo token = tokenizer->consumeNextToken(); + error->addError(token.range, error->ET_ParserInvalidToken) + << (token.kind == TokenInfo::TK_NewLine ? "NewLine" : token.text); + return false; + } + + llvm_unreachable("Unknown token kind."); +} + +static llvm::ManagedStatic defaultRegistrySema; + +Parser::Parser(CodeTokenizer *tokenizer, Sema *sema, + const NamedValueMap *namedValues, Diagnostics *error) + : tokenizer(tokenizer), sema(sema ? sema : &*defaultRegistrySema), + namedValues(namedValues), error(error) {} + +Parser::RegistrySema::~RegistrySema() = default; + +std::optional +Parser::RegistrySema::lookupMatcherCtor(llvm::StringRef matcherName) { + return Registry::lookupMatcherCtor(matcherName); +} + +VariantMatcher Parser::RegistrySema::actOnMatcherExpression( + MatcherCtor ctor, SourceRange nameRange, ArrayRef args, + Diagnostics *error) { + return Registry::constructMatcher(ctor, nameRange, args, error); +} + +std::vector Parser::RegistrySema::getAcceptedCompletionTypes( + ArrayRef> context) { + return Registry::getAcceptedCompletionTypes(context); +} + +std::vector +Parser::RegistrySema::getMatcherCompletions(ArrayRef acceptedTypes) { + return Registry::getMatcherCompletions(acceptedTypes); +} + +bool Parser::parseExpression(llvm::StringRef &code, Sema *sema, + const NamedValueMap *namedValues, + VariantValue *value, Diagnostics *error) { + CodeTokenizer tokenizer(code, error); + Parser parser(&tokenizer, sema, namedValues, error); + if (!parser.parseExpressionImpl(value)) + return false; + auto nextToken = tokenizer.peekNextToken(); + if (nextToken.kind != TokenInfo::TK_Eof && + nextToken.kind != TokenInfo::TK_NewLine) { + error->addError(tokenizer.peekNextToken().range, + error->ET_ParserTrailingCode); + return false; + } + return true; +} + +std::vector +Parser::completeExpression(llvm::StringRef &code, unsigned completionOffset, + Sema *sema, const NamedValueMap *namedValues) { + Diagnostics error; + CodeTokenizer tokenizer(code, &error, completionOffset); + Parser parser(&tokenizer, sema, namedValues, &error); + VariantValue dummy; + parser.parseExpressionImpl(&dummy); + + return parser.completions; +} + +std::optional +Parser::parseMatcherExpression(llvm::StringRef &code, Sema *sema, + const NamedValueMap *namedValues, + Diagnostics *error) { + VariantValue value; + if (!parseExpression(code, sema, namedValues, &value, error)) + return std::nullopt; + if (!value.isMatcher()) { + error->addError(SourceRange(), error->ET_ParserNotAMatcher); + return std::nullopt; + } + std::optional result = value.getMatcher().getDynMatcher(); + if (!result) { + error->addError(SourceRange(), error->ET_ParserOverloadedType) + << value.getTypeAsString(); + } + return result; +} + +} // namespace mlir::query::matcher diff --git a/mlir/lib/Query/Matcher/Registry.cpp b/mlir/lib/Query/Matcher/Registry.cpp new file mode 100644 index 0000000000000..3c3fed8bd1059 --- /dev/null +++ b/mlir/lib/Query/Matcher/Registry.cpp @@ -0,0 +1,171 @@ +//===- MatcherRegistry.cpp - Matcher registry -------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Registry map populated at static initialization time. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Query/Matcher/Registry.h" + +#include "mlir/IR/Matchers.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/Support/ManagedStatic.h" +#include +#include + +namespace mlir::query::matcher { +namespace { + +using ConstructorMap = + llvm::StringMap>; + +// This is needed because these matchers are defined as overloaded functions. +using IsConstantOp = detail::constant_op_matcher(); +using HasOpAttrName = detail::AttrOpMatcher(StringRef); +using HasOpName = detail::NameOpMatcher(StringRef); + +class RegistryMaps { +public: + RegistryMaps(); + ~RegistryMaps(); + + const ConstructorMap &constructors() const { return constructorMap; } + +private: + void registerMatcher(llvm::StringRef matcherName, + std::unique_ptr callback); + + ConstructorMap constructorMap; +}; + +} // namespace + +void RegistryMaps::registerMatcher( + llvm::StringRef matcherName, + std::unique_ptr callback) { + assert(!constructorMap.contains(matcherName)); + constructorMap[matcherName] = std::move(callback); +} + +// Generate a registry map with all the known matchers. +RegistryMaps::RegistryMaps() { + auto registerOpMatcher = [&](const std::string &name, auto matcher) { + registerMatcher(name, internal::makeMatcherAutoMarshall(matcher, name)); + }; + + // Register matchers using the template function (added in alphabetical order + // for consistency) + registerOpMatcher("hasOpAttrName", static_cast(m_Attr)); + registerOpMatcher("hasOpName", static_cast(m_Op)); + registerOpMatcher("isConstantOp", static_cast(m_Constant)); + registerOpMatcher("isNegInfFloat", m_NegInfFloat); + registerOpMatcher("isNegZeroFloat", m_NegZeroFloat); + registerOpMatcher("isNonZero", m_NonZero); + registerOpMatcher("isOne", m_One); + registerOpMatcher("isOneFloat", m_OneFloat); + registerOpMatcher("isPosInfFloat", m_PosInfFloat); + registerOpMatcher("isPosZeroFloat", m_PosZeroFloat); + registerOpMatcher("isZero", m_Zero); + registerOpMatcher("isZeroFloat", m_AnyZeroFloat); +} + +RegistryMaps::~RegistryMaps() = default; + +static llvm::ManagedStatic registryData; + +std::optional +Registry::lookupMatcherCtor(llvm::StringRef matcherName) { + auto it = registryData->constructors().find(matcherName); + return it == registryData->constructors().end() ? std::optional() + : it->second.get(); +} + +std::vector Registry::getAcceptedCompletionTypes( + ArrayRef> context) { + // Starting with the above seed of acceptable top-level matcher types, compute + // the acceptable type set for the argument indicated by each context element. + std::set typeSet; + typeSet.insert(ArgKind(ArgKind::AK_Matcher)); + + for (const auto &ctxEntry : context) { + MatcherCtor ctor = ctxEntry.first; + unsigned argNumber = ctxEntry.second; + std::vector nextTypeSet; + + if (argNumber < ctor->getNumArgs()) + ctor->getArgKinds(argNumber, nextTypeSet); + + typeSet.insert(nextTypeSet.begin(), nextTypeSet.end()); + } + + return std::vector(typeSet.begin(), typeSet.end()); +} + +std::vector +Registry::getMatcherCompletions(ArrayRef acceptedTypes) { + std::vector completions; + + // Search the registry for acceptable matchers. + for (const auto &m : registryData->constructors()) { + const internal::MatcherDescriptor &matcher = *m.getValue(); + StringRef name = m.getKey(); + + unsigned numArgs = matcher.getNumArgs(); + std::vector> argKinds(numArgs); + + for (const ArgKind &kind : acceptedTypes) { + if (kind.getArgKind() != kind.AK_Matcher) + continue; + + for (unsigned arg = 0; arg != numArgs; ++arg) + matcher.getArgKinds(arg, argKinds[arg]); + } + + std::string decl; + llvm::raw_string_ostream OS(decl); + + std::string typedText = std::string(name); + OS << "Matcher: " << name << "("; + + for (const std::vector &arg : argKinds) { + if (&arg != &argKinds[0]) + OS << ", "; + + bool firstArgKind = true; + // Two steps. First all non-matchers, then matchers only. + for (const ArgKind &argKind : arg) { + if (!firstArgKind) + OS << "|"; + + firstArgKind = false; + OS << argKind.asString(); + } + } + + OS << ")"; + typedText += "("; + + if (argKinds.empty()) + typedText += ")"; + else if (argKinds[0][0].getArgKind() == ArgKind::AK_String) + typedText += "\""; + + completions.emplace_back(typedText, OS.str()); + } + + return completions; +} + +VariantMatcher Registry::constructMatcher(MatcherCtor ctor, + SourceRange nameRange, + ArrayRef args, + Diagnostics *error) { + return ctor->create(nameRange, args, error); +} + +} // namespace mlir::query::matcher diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp new file mode 100644 index 0000000000000..77c330450e10f --- /dev/null +++ b/mlir/lib/Query/Matcher/VariantValue.cpp @@ -0,0 +1,139 @@ +//===--- MatcherVariantvalue.cpp --------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// +//===----------------------------------------------------------------------===// + +#include "mlir/Query/Matcher/VariantValue.h" + +namespace mlir::query::matcher { + +std::string ArgKind::asString() const { + switch (getArgKind()) { + case AK_String: + return "String"; + case AK_Matcher: + return "Matcher"; + } + llvm_unreachable("Unhandled ArgKind"); +} + +VariantMatcher::Payload::~Payload() = default; + +class VariantMatcher::SinglePayload : public VariantMatcher::Payload { +public: + explicit SinglePayload(DynMatcher matcher) : matcher(std::move(matcher)) {} + + std::optional getDynMatcher() const override { return matcher; } + + std::string getTypeAsString() const override { return "Matcher"; } + +private: + DynMatcher matcher; +}; + +VariantMatcher::VariantMatcher() = default; + +VariantMatcher VariantMatcher::SingleMatcher(DynMatcher matcher) { + return VariantMatcher(std::make_shared(std::move(matcher))); +} + +std::optional VariantMatcher::getDynMatcher() const { + return value ? value->getDynMatcher() : std::nullopt; +} + +void VariantMatcher::reset() { value.reset(); } + +std::string VariantMatcher::getTypeAsString() const { return ""; } + +VariantValue::VariantValue(const VariantValue &other) : type(VT_Nothing) { + *this = other; +} + +VariantValue::VariantValue(const StringRef string) : type(VT_String) { + value.String = new StringRef(string); +} + +VariantValue::VariantValue(const VariantMatcher &matcher) : type(VT_Matcher) { + value.Matcher = new VariantMatcher(matcher); +} + +VariantValue::~VariantValue() { reset(); } + +VariantValue &VariantValue::operator=(const VariantValue &other) { + if (this == &other) + return *this; + reset(); + switch (other.type) { + case VT_String: + setString(other.getString()); + break; + case VT_Matcher: + setMatcher(other.getMatcher()); + break; + case VT_Nothing: + type = VT_Nothing; + break; + } + return *this; +} + +void VariantValue::reset() { + switch (type) { + case VT_String: + delete value.String; + break; + case VT_Matcher: + delete value.Matcher; + break; + // Cases that do nothing. + case VT_Nothing: + break; + } + type = VT_Nothing; +} + +bool VariantValue::isString() const { return type == VT_String; } + +const StringRef &VariantValue::getString() const { + assert(isString()); + return *value.String; +} + +void VariantValue::setString(const StringRef &newValue) { + reset(); + type = VT_String; + value.String = new StringRef(newValue); +} + +bool VariantValue::isMatcher() const { return type == VT_Matcher; } + +const VariantMatcher &VariantValue::getMatcher() const { + assert(isMatcher()); + return *value.Matcher; +} + +void VariantValue::setMatcher(const VariantMatcher &newValue) { + reset(); + type = VT_Matcher; + value.Matcher = new VariantMatcher(newValue); +} + +std::string VariantValue::getTypeAsString() const { + switch (type) { + case VT_String: + return "String"; + case VT_Matcher: + return "Matcher"; + case VT_Nothing: + return "Nothing"; + } + llvm_unreachable("Invalid Type"); +} + +} // namespace mlir::query::matcher diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp new file mode 100644 index 0000000000000..7d9323ab33180 --- /dev/null +++ b/mlir/lib/Query/Query.cpp @@ -0,0 +1,59 @@ +//===---- Query.cpp - mlir-query query --------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Query/Query.h" +#include "mlir/Query/Matcher/MatchFinder.h" +#include "mlir/Query/QuerySession.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir::query { + +static void printMatch(llvm::raw_ostream &OS, QuerySession &QS, Operation *op, + std::string binding) { + auto fileLoc = op->getLoc()->findInstanceOf(); + auto smloc = QS.sourceMgr->FindLocForLineAndColumn( + QS.bufferId, fileLoc.getLine(), fileLoc.getColumn()); + QS.sourceMgr->PrintMessage(OS, smloc, llvm::SourceMgr::DK_Note, + "\"" + binding + "\" binds here"); +} + +Query::~Query() {} + +bool InvalidQuery::run(llvm::raw_ostream &OS, QuerySession &QS) const { + OS << errStr << "\n"; + return false; +} + +bool NoOpQuery::run(llvm::raw_ostream &OS, QuerySession &QS) const { + return true; +} + +bool HelpQuery::run(llvm::raw_ostream &OS, QuerySession &QS) const { + OS << "Available commands:\n\n" + " match MATCHER, m MATCHER " + "Match the mlir against the given matcher.\n\n"; + return true; +} + +bool MatchQuery::run(llvm::raw_ostream &OS, QuerySession &QS) const { + int matchCount = 0; + std::vector matches = + matcher::MatchFinder().getMatches(QS.rootOp, matcher); + OS << "\n"; + for (Operation *op : matches) { + OS << "Match #" << ++matchCount << ":\n\n"; + // Placeholder "root" binding for the initial draft. + printMatch(OS, QS, op, "root"); + } + OS << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n"); + + return true; +} + +} // namespace mlir::query diff --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp new file mode 100644 index 0000000000000..a9a25166772ce --- /dev/null +++ b/mlir/lib/Query/QueryParser.cpp @@ -0,0 +1,208 @@ +//===---- QueryParser.cpp - mlir-query command parser -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Query/QueryParser.h" +#include "llvm/ADT/StringSwitch.h" + +namespace mlir::query { + +// Lex any amount of whitespace followed by a "word" (any sequence of +// non-whitespace characters) from the start of region [Begin,End). If no word +// is found before End, return StringRef(). Begin is adjusted to exclude the +// lexed region. +llvm::StringRef QueryParser::lexWord() { + line = line.drop_while([](char c) { + // Don't trim newlines. + return llvm::StringRef(" \t\v\f\r").contains(c); + }); + + if (line.empty()) + // Even though the line is empty, it contains a pointer and + // a (zero) length. The pointer is used in the LexOrCompleteWord + // code completion. + return line; + + llvm::StringRef word; + if (line.front() == '#') { + word = line.substr(0, 1); + } else { + word = line.take_until([](char c) { + // Don't trim newlines. + return llvm::StringRef(" \t\v\f\r").contains(c); + }); + } + + line = line.drop_front(word.size()); + return word; +} + +// This is the StringSwitch-alike used by lexOrCompleteWord below. See that +// function for details. +template +struct QueryParser::LexOrCompleteWord { + llvm::StringRef word; + llvm::StringSwitch stringSwitch; + + QueryParser *queryParser; + // Set to the completion point offset in word, or StringRef::npos if + // completion point not in word. + size_t wordCompletionPos; + + // Lexes a word and stores it in word. Returns a LexOrCompleteword object + // that can be used like a llvm::StringSwitch, but adds cases as possible + // completions if the lexed word contains the completion point. + LexOrCompleteWord(QueryParser *queryParser, llvm::StringRef &outWord) + : word(queryParser->lexWord()), stringSwitch(word), + queryParser(queryParser), wordCompletionPos(llvm::StringRef::npos) { + outWord = word; + if (queryParser->completionPos && + queryParser->completionPos <= word.data() + word.size()) { + if (queryParser->completionPos < word.data()) + wordCompletionPos = 0; + else + wordCompletionPos = queryParser->completionPos - word.data(); + } + } + + LexOrCompleteWord &Case(llvm::StringLiteral caseStr, const T &value, + bool isCompletion = true) { + + if (wordCompletionPos == llvm::StringRef::npos) + stringSwitch.Case(caseStr, value); + else if (caseStr.size() != 0 && isCompletion && + wordCompletionPos <= caseStr.size() && + caseStr.substr(0, wordCompletionPos) == + word.substr(0, wordCompletionPos)) { + + queryParser->completions.push_back(llvm::LineEditor::Completion( + (caseStr.substr(wordCompletionPos) + " ").str(), + std::string(caseStr))); + } + return *this; + } + + T Default(T value) { return stringSwitch.Default(value); } +}; + +QueryRef QueryParser::endQuery(QueryRef queryRef) { + llvm::StringRef extra = line; + llvm::StringRef extraTrimmed = extra.drop_while( + [](char c) { return llvm::StringRef(" \t\v\f\r").contains(c); }); + + if ((!extraTrimmed.empty() && extraTrimmed[0] == '\n') || + (extraTrimmed.size() >= 2 && extraTrimmed[0] == '\r' && + extraTrimmed[1] == '\n')) + queryRef->remainingContent = extra; + else { + llvm::StringRef trailingWord = lexWord(); + if (!trailingWord.empty() && trailingWord.front() == '#') { + line = line.drop_until([](char c) { return c == '\n'; }); + line = line.drop_while([](char c) { return c == '\n'; }); + return endQuery(queryRef); + } + if (!trailingWord.empty()) { + return new InvalidQuery("unexpected extra input: '" + extra + "'"); + } + } + return queryRef; +} + +namespace { + +enum ParsedQueryKind { + PQK_Invalid, + PQK_Comment, + PQK_NoOp, + PQK_Help, + PQK_Match, +}; + +QueryRef makeInvalidQueryFromDiagnostics(const matcher::Diagnostics &diag) { + std::string errStr; + llvm::raw_string_ostream OS(errStr); + diag.print(OS); + return new InvalidQuery(OS.str()); +} +} // namespace + +QueryRef QueryParser::completeMatcherExpression() { + std::vector comps = + matcher::Parser::completeExpression(line, completionPos - line.begin(), + nullptr, &QS.namedValues); + for (const auto &comp : comps) { + completions.emplace_back(comp.typedText, comp.matcherDecl); + } + return QueryRef(); +} + +QueryRef QueryParser::doParse() { + + llvm::StringRef commandStr; + ParsedQueryKind qKind = LexOrCompleteWord(this, commandStr) + .Case("", PQK_NoOp) + .Case("#", PQK_Comment, /*isCompletion=*/false) + .Case("help", PQK_Help) + .Case("m", PQK_Match, /*isCompletion=*/false) + .Case("match", PQK_Match) + .Default(PQK_Invalid); + + switch (qKind) { + case PQK_Comment: + case PQK_NoOp: + line = line.drop_until([](char c) { return c == '\n'; }); + line = line.drop_while([](char c) { return c == '\n'; }); + if (line.empty()) + return new NoOpQuery; + return doParse(); + + case PQK_Help: + return endQuery(new HelpQuery); + + case PQK_Match: { + if (completionPos) { + return completeMatcherExpression(); + } + + matcher::Diagnostics diag; + auto matcherSource = line.ltrim(); + auto origMatcherSource = matcherSource; + std::optional matcher = + matcher::Parser::parseMatcherExpression(matcherSource, nullptr, + &QS.namedValues, &diag); + if (!matcher) { + return makeInvalidQueryFromDiagnostics(diag); + } + auto actualSource = origMatcherSource.slice(0, origMatcherSource.size() - + matcherSource.size()); + auto *Q = new MatchQuery(actualSource, *matcher); + Q->remainingContent = matcherSource; + return Q; + } + + case PQK_Invalid: + return new InvalidQuery("unknown command: " + commandStr); + } + + llvm_unreachable("Invalid query kind"); +} + +QueryRef QueryParser::parse(llvm::StringRef line, const QuerySession &QS) { + return QueryParser(line, QS).doParse(); +} + +std::vector +QueryParser::complete(llvm::StringRef line, size_t pos, + const QuerySession &QS) { + QueryParser queryParser(line, QS); + queryParser.completionPos = line.data() + pos; + + queryParser.doParse(); + return queryParser.completions; +} + +} // namespace mlir::query diff --git a/mlir/lib/Tools/CMakeLists.txt b/mlir/lib/Tools/CMakeLists.txt index 6175a1ce5f8d1..01270fa4b0fc3 100644 --- a/mlir/lib/Tools/CMakeLists.txt +++ b/mlir/lib/Tools/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(lsp-server-support) add_subdirectory(mlir-lsp-server) add_subdirectory(mlir-opt) add_subdirectory(mlir-pdll-lsp-server) +add_subdirectory(mlir-query) add_subdirectory(mlir-reduce) add_subdirectory(mlir-tblgen) add_subdirectory(mlir-translate) diff --git a/mlir/lib/Tools/mlir-query/CMakeLists.txt b/mlir/lib/Tools/mlir-query/CMakeLists.txt new file mode 100644 index 0000000000000..b81b02d42bfca --- /dev/null +++ b/mlir/lib/Tools/mlir-query/CMakeLists.txt @@ -0,0 +1,13 @@ +set(LLVM_LINK_COMPONENTS + lineeditor + ) + +add_mlir_library(MLIRQueryLib + MlirQueryMain.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-query + + LINK_LIBS PUBLIC + MLIRQuery + ) diff --git a/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp b/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp new file mode 100644 index 0000000000000..7f8151d94c4d0 --- /dev/null +++ b/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp @@ -0,0 +1,110 @@ +//===- MlirQueryMain.cpp - MLIR Query main ----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the general framework of the MLIR query tool. It +// parses the command line arguments, parses the MLIR file and outputs the query +// results. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Tools/mlir-query/MlirQueryMain.h" +#include "mlir/Query/QueryParser.h" +#include "mlir/Support/FileUtilities.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Tools/ParseUtilities.h" +#include "llvm/LineEditor/LineEditor.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" + +//===----------------------------------------------------------------------===// +// Query Parser +//===----------------------------------------------------------------------===// + +mlir::LogicalResult mlir::mlirQueryMain(int argc, char **argv, + MLIRContext &context) { + // Override the default '-h' and use the default PrintHelpMessage() which + // won't print options in categories. + static llvm::cl::opt help("h", llvm::cl::desc("Alias for -help"), + llvm::cl::Hidden); + + static llvm::cl::OptionCategory mlirQueryCategory("mlir-query options"); + + static llvm::cl::list commands( + "c", llvm::cl::desc("Specify command to run"), + llvm::cl::value_desc("command"), llvm::cl::cat(mlirQueryCategory)); + + static llvm::cl::opt inputFilename( + llvm::cl::Positional, llvm::cl::desc(""), + llvm::cl::cat(mlirQueryCategory)); + + static llvm::cl::opt noImplicitModule{ + "no-implicit-module", + llvm::cl::desc( + "Disable implicit addition of a top-level module op during parsing"), + llvm::cl::init(false)}; + + static llvm::cl::opt allowUnregisteredDialects( + "allow-unregistered-dialect", + llvm::cl::desc("Allow operation with no registered dialects"), + llvm::cl::init(false)); + + llvm::cl::HideUnrelatedOptions(mlirQueryCategory); + + llvm::InitLLVM y(argc, argv); + + llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR test case query tool.\n"); + + if (help) { + llvm::cl::PrintHelpMessage(); + return success(); + } + + // Set up the input file. + std::string errorMessage; + auto file = openInputFile(inputFilename, &errorMessage); + if (!file) { + return failure(); + } + + auto sourceMgr = std::make_shared(); + auto bufferId = sourceMgr->AddNewSourceBuffer(std::move(file), SMLoc()); + + context.allowUnregisteredDialects(allowUnregisteredDialects); + + // Parse the input MLIR file. + OwningOpRef opRef = + parseSourceFileForTool(sourceMgr, &context, !noImplicitModule); + if (!opRef) + return failure(); + + mlir::query::QuerySession QS(opRef.get(), sourceMgr, bufferId); + if (!commands.empty()) { + for (auto &command : commands) { + mlir::query::QueryRef queryRef = + mlir::query::QueryParser::parse(command, QS); + if (!queryRef->run(llvm::outs(), QS)) + return failure(); + } + } else { + llvm::LineEditor LE("mlir-query"); + LE.setListCompleter([&QS](StringRef line, size_t pos) { + return mlir::query::QueryParser::complete(line, pos, QS); + }); + while (std::optional line = LE.readLine()) { + mlir::query::QueryRef queryRef = + mlir::query::QueryParser::parse(*line, QS); + queryRef->run(llvm::outs(), QS); + llvm::outs().flush(); + if (QS.terminate) + break; + } + } + + return success(); +} diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt index bf143d036c2f6..6fc9ae0f3fc58 100644 --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -104,6 +104,7 @@ set(MLIR_TEST_DEPENDS mlir-pdll-lsp-server mlir-opt mlir-pdll + mlir-query mlir-reduce mlir-tblgen mlir-translate diff --git a/mlir/test/mlir-query/simple-test.mlir b/mlir/test/mlir-query/simple-test.mlir new file mode 100644 index 0000000000000..a4d006598767b --- /dev/null +++ b/mlir/test/mlir-query/simple-test.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-query %s -c "m isConstantOp()" | FileCheck %s + +// CHECK: {{.*}}.mlir:5:13: note: "root" binds here +func.func @simple1() { + %c1_i32 = arith.constant 1 : i32 + return +} + +// CHECK: {{.*}}.mlir:12:11: note: "root" binds here +// CHECK: {{.*}}.mlir:13:11: note: "root" binds here +func.func @simple2() { + %cst1 = arith.constant 1.0 : f32 + %cst2 = arith.constant 2.0 : f32 + %add = arith.addf %cst1, %cst2 : f32 + return +} diff --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt index e9a1e4d625172..a01f74f737e1b 100644 --- a/mlir/tools/CMakeLists.txt +++ b/mlir/tools/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(mlir-lsp-server) add_subdirectory(mlir-opt) add_subdirectory(mlir-parser-fuzzer) add_subdirectory(mlir-pdll-lsp-server) +add_subdirectory(mlir-query) add_subdirectory(mlir-reduce) add_subdirectory(mlir-shlib) add_subdirectory(mlir-spirv-cpu-runner) diff --git a/mlir/tools/mlir-query/CMakeLists.txt b/mlir/tools/mlir-query/CMakeLists.txt new file mode 100644 index 0000000000000..ef2e5a84b5569 --- /dev/null +++ b/mlir/tools/mlir-query/CMakeLists.txt @@ -0,0 +1,20 @@ +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) + +if(MLIR_INCLUDE_TESTS) + set(test_libs + MLIRTestDialect + ) +endif() + +add_mlir_tool(mlir-query + mlir-query.cpp + ) +llvm_update_compile_flags(mlir-query) +target_link_libraries(mlir-query + PRIVATE + ${dialect_libs} + ${test_libs} + MLIRQueryLib + ) + +mlir_check_link_libraries(mlir-query) diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp new file mode 100644 index 0000000000000..1efbebad1bf34 --- /dev/null +++ b/mlir/tools/mlir-query/mlir-query.cpp @@ -0,0 +1,37 @@ +//===- mlir-query.cpp - MLIR Query Driver -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is a command line utility that queries a file from/to MLIR using one +// of the registered queries. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/InitAllDialects.h" +#include "mlir/Tools/mlir-query/MlirQueryMain.h" + +using namespace mlir; + +namespace test { +#ifdef MLIR_INCLUDE_TESTS +void registerTestDialect(DialectRegistry &); +#endif +} // namespace test + +int main(int argc, char **argv) { + + DialectRegistry registry; + registerAllDialects(registry); +#ifdef MLIR_INCLUDE_TESTS + test::registerTestDialect(registry); +#endif + MLIRContext context(registry); + + return failed(mlirQueryMain(argc, argv, context)); +} From da170ac10e0bc9e2ec2337a70a5ac2f6904ed7ba Mon Sep 17 00:00:00 2001 From: Devajith Valaparambil Sreeramaswamy Date: Mon, 31 Jul 2023 15:36:22 +0100 Subject: [PATCH 02/12] Fix uncatched missing brackets --- mlir/lib/Query/Matcher/Parser.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp index bd69b746f76db..216b62caeed0d 100644 --- a/mlir/lib/Query/Matcher/Parser.cpp +++ b/mlir/lib/Query/Matcher/Parser.cpp @@ -378,6 +378,13 @@ bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken, return false; } + // Check for the missing closing parenthesis + if (endToken.kind != TokenInfo::TK_CloseParen) { + error->addError(openToken.range, error->ET_ParserNoCloseParen) + << nameToken.text; + return false; + } + if (!ctor) return false; // Merge the start and end infos. From 9218906473dac06fe226f12bc4518c98b42de895 Mon Sep 17 00:00:00 2001 From: Devajith Valaparambil Sreeramaswamy Date: Mon, 31 Jul 2023 15:54:38 +0100 Subject: [PATCH 03/12] Unexpose diagnostics functions --- mlir/include/mlir/Query/Matcher/Diagnostics.h | 40 ++++++++++++------- mlir/include/mlir/Query/Matcher/MatchFinder.h | 2 +- mlir/lib/Query/Matcher/Diagnostics.cpp | 18 ++++----- 3 files changed, 36 insertions(+), 24 deletions(-) diff --git a/mlir/include/mlir/Query/Matcher/Diagnostics.h b/mlir/include/mlir/Query/Matcher/Diagnostics.h index 35f29721b1f82..a07f6d4047091 100644 --- a/mlir/include/mlir/Query/Matcher/Diagnostics.h +++ b/mlir/include/mlir/Query/Matcher/Diagnostics.h @@ -38,9 +38,6 @@ struct SourceRange { // Diagnostics class to manage error messages. class Diagnostics { public: - // Parser context types. - enum ContextType { CT_MatcherArg, CT_MatcherConstruct }; - // All errors from the system. enum ErrorType { ET_None, @@ -92,6 +89,22 @@ class Diagnostics { Diagnostics *const error; }; + // Add an error message with the specified range and error type. + // Returns an ArgStream object to allow constructing the error message using + // the << operator. + ArgStream addError(SourceRange range, ErrorType error); + + // Print all error messages to the specified output stream. + void print(llvm::raw_ostream &OS) const; + + // Print the full error messages, including the context information, to the + // specified output stream. + void printFull(llvm::raw_ostream &OS) const; + +private: + // Parser context types. + enum ContextType { CT_MatcherArg, CT_MatcherConstruct }; + // Context for managing overloaded matcher construction. struct OverloadContext { // Construct an overload context with the given error. @@ -105,11 +118,6 @@ class Diagnostics { unsigned beginIndex{}; }; - // Add an error message with the specified range and error type. - // Returns an ArgStream object to allow constructing the error message using - // the << operator. - ArgStream addError(SourceRange range, ErrorType error); - // Information stored for one frame of the context. struct ContextFrame { ContextType type; @@ -131,14 +139,18 @@ class Diagnostics { // Get an array reference to the error contents. llvm::ArrayRef errors() const { return errorValues; } - // Print all error messages to the specified output stream. - void print(llvm::raw_ostream &OS) const; + llvm::StringRef contextTypeToFormatString(ContextType type) const; - // Print the full error messages, including the context information, to the - // specified output stream. - void printFull(llvm::raw_ostream &OS) const; + void printContextFrameToStream(const ContextFrame &frame, + llvm::raw_ostream &OS) const; + + void printMessageToStream(const ErrorContent::Message &message, + const llvm::Twine Prefix, + llvm::raw_ostream &OS) const; + + void printErrorContentToStream(const ErrorContent &content, + llvm::raw_ostream &OS) const; -private: // Push a new context frame onto the context stack with the specified type and // range. ArgStream pushContextFrame(ContextType type, SourceRange range); diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h index 5a87e45310920..70174052aaf89 100644 --- a/mlir/include/mlir/Query/Matcher/MatchFinder.h +++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h @@ -39,4 +39,4 @@ class MatchFinder { } // namespace mlir::query::matcher -#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H \ No newline at end of file +#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H diff --git a/mlir/lib/Query/Matcher/Diagnostics.cpp b/mlir/lib/Query/Matcher/Diagnostics.cpp index aa9685ee1e436..c7d69af026b08 100644 --- a/mlir/lib/Query/Matcher/Diagnostics.cpp +++ b/mlir/lib/Query/Matcher/Diagnostics.cpp @@ -72,8 +72,8 @@ Diagnostics::ArgStream Diagnostics::addError(SourceRange range, return ArgStream(&last.messages.back().args); } -static llvm::StringRef -contextTypeToFormatString(Diagnostics::ContextType type) { +llvm::StringRef +Diagnostics::contextTypeToFormatString(Diagnostics::ContextType type) const { switch (type) { case Diagnostics::CT_MatcherConstruct: return "Error building matcher $0."; @@ -150,22 +150,22 @@ static void maybeAddLineAndColumn(SourceRange range, llvm::raw_ostream &OS) { } } -static void printContextFrameToStream(const Diagnostics::ContextFrame &frame, - llvm::raw_ostream &OS) { +void Diagnostics::printContextFrameToStream( + const Diagnostics::ContextFrame &frame, llvm::raw_ostream &OS) const { maybeAddLineAndColumn(frame.range, OS); formatErrorString(contextTypeToFormatString(frame.type), frame.args, OS); } -static void -printMessageToStream(const Diagnostics::ErrorContent::Message &message, - const llvm::Twine Prefix, llvm::raw_ostream &OS) { +void Diagnostics::printMessageToStream( + const Diagnostics::ErrorContent::Message &message, const llvm::Twine Prefix, + llvm::raw_ostream &OS) const { maybeAddLineAndColumn(message.range, OS); OS << Prefix; formatErrorString(errorTypeToFormatString(message.type), message.args, OS); } -static void printErrorContentToStream(const Diagnostics::ErrorContent &content, - llvm::raw_ostream &OS) { +void Diagnostics::printErrorContentToStream( + const Diagnostics::ErrorContent &content, llvm::raw_ostream &OS) const { if (content.messages.size() == 1) { printMessageToStream(content.messages[0], "", OS); } else { From f90349189aa40ea30c30fd3554041311fa791819 Mon Sep 17 00:00:00 2001 From: Devajith Valaparambil Sreeramaswamy Date: Mon, 31 Jul 2023 15:59:12 +0100 Subject: [PATCH 04/12] Progressively convert to enum classes --- mlir/include/mlir/Query/Matcher/Diagnostics.h | 2 +- mlir/lib/Query/Matcher/Diagnostics.cpp | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Query/Matcher/Diagnostics.h b/mlir/include/mlir/Query/Matcher/Diagnostics.h index a07f6d4047091..94f340cfcb8c5 100644 --- a/mlir/include/mlir/Query/Matcher/Diagnostics.h +++ b/mlir/include/mlir/Query/Matcher/Diagnostics.h @@ -103,7 +103,7 @@ class Diagnostics { private: // Parser context types. - enum ContextType { CT_MatcherArg, CT_MatcherConstruct }; + enum class ContextType { MatcherArg, MatcherConstruct }; // Context for managing overloaded matcher construction. struct OverloadContext { diff --git a/mlir/lib/Query/Matcher/Diagnostics.cpp b/mlir/lib/Query/Matcher/Diagnostics.cpp index c7d69af026b08..a21657769193e 100644 --- a/mlir/lib/Query/Matcher/Diagnostics.cpp +++ b/mlir/lib/Query/Matcher/Diagnostics.cpp @@ -23,14 +23,15 @@ Diagnostics::Context::Context(ConstructMatcherEnum, Diagnostics *error, llvm::StringRef matcherName, SourceRange matcherRange) : error(error) { - error->pushContextFrame(CT_MatcherConstruct, matcherRange) << matcherName; + error->pushContextFrame(ContextType::MatcherConstruct, matcherRange) + << matcherName; } Diagnostics::Context::Context(MatcherArgEnum, Diagnostics *error, llvm::StringRef matcherName, SourceRange matcherRange, int argnumber) : error(error) { - error->pushContextFrame(CT_MatcherArg, matcherRange) + error->pushContextFrame(ContextType::MatcherArg, matcherRange) << argnumber << matcherName; } @@ -75,9 +76,9 @@ Diagnostics::ArgStream Diagnostics::addError(SourceRange range, llvm::StringRef Diagnostics::contextTypeToFormatString(Diagnostics::ContextType type) const { switch (type) { - case Diagnostics::CT_MatcherConstruct: + case Diagnostics::ContextType::MatcherConstruct: return "Error building matcher $0."; - case Diagnostics::CT_MatcherArg: + case Diagnostics::ContextType::MatcherArg: return "Error parsing argument $0 for matcher $1."; } llvm_unreachable("Unknown ContextType value."); From a1116dbcb93d8e4ea6e898299e77dce024feb854 Mon Sep 17 00:00:00 2001 From: Devajith Valaparambil Sreeramaswamy Date: Mon, 31 Jul 2023 17:01:54 +0100 Subject: [PATCH 05/12] Remove all the unnnecessary diagnostics stuff --- mlir/include/mlir/Query/Matcher/Diagnostics.h | 53 +---------- mlir/lib/Query/Matcher/Diagnostics.cpp | 89 ++----------------- mlir/lib/Query/Matcher/Parser.cpp | 4 - 3 files changed, 11 insertions(+), 135 deletions(-) diff --git a/mlir/include/mlir/Query/Matcher/Diagnostics.h b/mlir/include/mlir/Query/Matcher/Diagnostics.h index 94f340cfcb8c5..a3da717b48ac1 100644 --- a/mlir/include/mlir/Query/Matcher/Diagnostics.h +++ b/mlir/include/mlir/Query/Matcher/Diagnostics.h @@ -75,20 +75,6 @@ class Diagnostics { std::vector *out; }; - // Context for constructing a matcher or parsing its argument. - struct Context { - enum ConstructMatcherEnum { ConstructMatcher }; - Context(ConstructMatcherEnum, Diagnostics *error, - llvm::StringRef matcherName, SourceRange matcherRange); - enum MatcherArgEnum { MatcherArg }; - Context(MatcherArgEnum, Diagnostics *error, llvm::StringRef matcherName, - SourceRange matcherRange, int argNumber); - ~Context(); - - private: - Diagnostics *const error; - }; - // Add an error message with the specified range and error type. // Returns an ArgStream object to allow constructing the error message using // the << operator. @@ -97,30 +83,9 @@ class Diagnostics { // Print all error messages to the specified output stream. void print(llvm::raw_ostream &OS) const; - // Print the full error messages, including the context information, to the - // specified output stream. - void printFull(llvm::raw_ostream &OS) const; - private: - // Parser context types. - enum class ContextType { MatcherArg, MatcherConstruct }; - - // Context for managing overloaded matcher construction. - struct OverloadContext { - // Construct an overload context with the given error. - OverloadContext(Diagnostics *error); - ~OverloadContext(); - // Revert all errors that occurred within this context. - void revertErrors(); - - private: - Diagnostics *const error; - unsigned beginIndex{}; - }; - // Information stored for one frame of the context. struct ContextFrame { - ContextType type; SourceRange range; std::vector args; }; @@ -139,21 +104,11 @@ class Diagnostics { // Get an array reference to the error contents. llvm::ArrayRef errors() const { return errorValues; } - llvm::StringRef contextTypeToFormatString(ContextType type) const; - - void printContextFrameToStream(const ContextFrame &frame, - llvm::raw_ostream &OS) const; - - void printMessageToStream(const ErrorContent::Message &message, - const llvm::Twine Prefix, - llvm::raw_ostream &OS) const; - - void printErrorContentToStream(const ErrorContent &content, - llvm::raw_ostream &OS) const; + void printMessage(const ErrorContent::Message &message, + const llvm::Twine Prefix, llvm::raw_ostream &OS) const; - // Push a new context frame onto the context stack with the specified type and - // range. - ArgStream pushContextFrame(ContextType type, SourceRange range); + void printErrorContent(const ErrorContent &content, + llvm::raw_ostream &OS) const; std::vector contextStack; std::vector errorValues; diff --git a/mlir/lib/Query/Matcher/Diagnostics.cpp b/mlir/lib/Query/Matcher/Diagnostics.cpp index a21657769193e..67a59a4fe08fe 100644 --- a/mlir/lib/Query/Matcher/Diagnostics.cpp +++ b/mlir/lib/Query/Matcher/Diagnostics.cpp @@ -10,52 +10,6 @@ namespace mlir::query::matcher { -Diagnostics::ArgStream Diagnostics::pushContextFrame(ContextType type, - SourceRange range) { - contextStack.emplace_back(); - ContextFrame &data = contextStack.back(); - data.type = type; - data.range = range; - return ArgStream(&data.args); -} - -Diagnostics::Context::Context(ConstructMatcherEnum, Diagnostics *error, - llvm::StringRef matcherName, - SourceRange matcherRange) - : error(error) { - error->pushContextFrame(ContextType::MatcherConstruct, matcherRange) - << matcherName; -} - -Diagnostics::Context::Context(MatcherArgEnum, Diagnostics *error, - llvm::StringRef matcherName, - SourceRange matcherRange, int argnumber) - : error(error) { - error->pushContextFrame(ContextType::MatcherArg, matcherRange) - << argnumber << matcherName; -} - -Diagnostics::Context::~Context() { error->contextStack.pop_back(); } - -Diagnostics::OverloadContext::OverloadContext(Diagnostics *error) - : error(error), beginIndex(error->errorValues.size()) {} - -Diagnostics::OverloadContext::~OverloadContext() { - // Merge all errors that happened while in this context. - if (beginIndex < error->errorValues.size()) { - Diagnostics::ErrorContent &dest = error->errorValues[beginIndex]; - for (size_t i = beginIndex + 1, e = error->errorValues.size(); i < e; ++i) { - dest.messages.push_back(error->errorValues[i].messages[0]); - } - error->errorValues.resize(beginIndex + 1); - } -} - -void Diagnostics::OverloadContext::revertErrors() { - // Revert the errors. - error->errorValues.resize(beginIndex); -} - Diagnostics::ArgStream & Diagnostics::ArgStream::operator<<(const llvm::Twine &arg) { out->push_back(arg.str()); @@ -73,17 +27,6 @@ Diagnostics::ArgStream Diagnostics::addError(SourceRange range, return ArgStream(&last.messages.back().args); } -llvm::StringRef -Diagnostics::contextTypeToFormatString(Diagnostics::ContextType type) const { - switch (type) { - case Diagnostics::ContextType::MatcherConstruct: - return "Error building matcher $0."; - case Diagnostics::ContextType::MatcherArg: - return "Error parsing argument $0 for matcher $1."; - } - llvm_unreachable("Unknown ContextType value."); -} - static llvm::StringRef errorTypeToFormatString(Diagnostics::ErrorType type) { switch (type) { case Diagnostics::ET_RegistryMatcherNotFound: @@ -151,13 +94,7 @@ static void maybeAddLineAndColumn(SourceRange range, llvm::raw_ostream &OS) { } } -void Diagnostics::printContextFrameToStream( - const Diagnostics::ContextFrame &frame, llvm::raw_ostream &OS) const { - maybeAddLineAndColumn(frame.range, OS); - formatErrorString(contextTypeToFormatString(frame.type), frame.args, OS); -} - -void Diagnostics::printMessageToStream( +void Diagnostics::printMessage( const Diagnostics::ErrorContent::Message &message, const llvm::Twine Prefix, llvm::raw_ostream &OS) const { maybeAddLineAndColumn(message.range, OS); @@ -165,16 +102,16 @@ void Diagnostics::printMessageToStream( formatErrorString(errorTypeToFormatString(message.type), message.args, OS); } -void Diagnostics::printErrorContentToStream( - const Diagnostics::ErrorContent &content, llvm::raw_ostream &OS) const { +void Diagnostics::printErrorContent(const Diagnostics::ErrorContent &content, + llvm::raw_ostream &OS) const { if (content.messages.size() == 1) { - printMessageToStream(content.messages[0], "", OS); + printMessage(content.messages[0], "", OS); } else { for (size_t i = 0, e = content.messages.size(); i != e; ++i) { if (i != 0) OS << "\n"; - printMessageToStream(content.messages[i], - "Candidate " + llvm::Twine(i + 1) + ": ", OS); + printMessage(content.messages[i], + "Candidate " + llvm::Twine(i + 1) + ": ", OS); } } } @@ -183,19 +120,7 @@ void Diagnostics::print(llvm::raw_ostream &OS) const { for (const ErrorContent &error : errorValues) { if (&error != &errorValues.front()) OS << "\n"; - printErrorContentToStream(error, OS); - } -} - -void Diagnostics::printFull(llvm::raw_ostream &OS) const { - for (const ErrorContent &error : errorValues) { - if (&error != &errorValues.front()) - OS << "\n"; - for (const ContextFrame &frame : error.contextStack) { - printContextFrameToStream(frame, OS); - OS << "\n"; - } - printErrorContentToStream(error, OS); + printErrorContent(error, OS); } } diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp index 216b62caeed0d..438e35b6685e4 100644 --- a/mlir/lib/Query/Matcher/Parser.cpp +++ b/mlir/lib/Query/Matcher/Parser.cpp @@ -339,8 +339,6 @@ bool Parser::parseMatcherArgs(std::vector &args, MatcherCtor ctor, } } - Diagnostics::Context ctx(Diagnostics::Context::MatcherArg, error, - nameToken.text, nameToken.range, args.size() + 1); ParserValue argValue; tokenizer->skipNewlines(); @@ -388,8 +386,6 @@ bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken, if (!ctor) return false; // Merge the start and end infos. - Diagnostics::Context ctx(Diagnostics::Context::ConstructMatcher, error, - nameToken.text, nameToken.range); SourceRange matcherRange = nameToken.range; matcherRange.end = endToken.range.end; VariantMatcher result = From afc6cc0319007616a76fe30783a6d9a9c6491585 Mon Sep 17 00:00:00 2001 From: Devajith Valaparambil Sreeramaswamy Date: Mon, 31 Jul 2023 17:29:43 +0100 Subject: [PATCH 06/12] Convert to enum class --- mlir/include/mlir/Query/Matcher/Diagnostics.h | 32 +++++++++---------- mlir/include/mlir/Query/Matcher/Marshallers.h | 5 +-- mlir/lib/Query/Matcher/Diagnostics.cpp | 30 ++++++++--------- mlir/lib/Query/Matcher/Parser.cpp | 29 +++++++++-------- 4 files changed, 50 insertions(+), 46 deletions(-) diff --git a/mlir/include/mlir/Query/Matcher/Diagnostics.h b/mlir/include/mlir/Query/Matcher/Diagnostics.h index a3da717b48ac1..06c7a0029a196 100644 --- a/mlir/include/mlir/Query/Matcher/Diagnostics.h +++ b/mlir/include/mlir/Query/Matcher/Diagnostics.h @@ -39,26 +39,26 @@ struct SourceRange { class Diagnostics { public: // All errors from the system. - enum ErrorType { - ET_None, + enum class ErrorType { + None, // Parser Errors - ET_ParserFailedToBuildMatcher, - ET_ParserInvalidToken, - ET_ParserNoCloseParen, - ET_ParserNoCode, - ET_ParserNoComma, - ET_ParserNoOpenParen, - ET_ParserNotAMatcher, - ET_ParserOverloadedType, - ET_ParserStringError, - ET_ParserTrailingCode, + ParserFailedToBuildMatcher, + ParserInvalidToken, + ParserNoCloseParen, + ParserNoCode, + ParserNoComma, + ParserNoOpenParen, + ParserNotAMatcher, + ParserOverloadedType, + ParserStringError, + ParserTrailingCode, // Registry Errors - ET_RegistryMatcherNotFound, - ET_RegistryValueNotFound, - ET_RegistryWrongArgCount, - ET_RegistryWrongArgType + RegistryMatcherNotFound, + RegistryValueNotFound, + RegistryWrongArgCount, + RegistryWrongArgType }; // Helper stream class for constructing error messages. diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h index 14f6507041a68..bada7c12aedb4 100644 --- a/mlir/include/mlir/Query/Matcher/Marshallers.h +++ b/mlir/include/mlir/Query/Matcher/Marshallers.h @@ -124,7 +124,7 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor { inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount, ArrayRef args, Diagnostics *error) { if (args.size() != expectedArgCount) { - error->addError(nameRange, error->ET_RegistryWrongArgCount) + error->addError(nameRange, Diagnostics::ErrorType::RegistryWrongArgCount) << expectedArgCount << args.size(); return false; } @@ -137,7 +137,8 @@ inline bool checkArgTypeAtIndex(StringRef matcherName, ArrayRef args, Diagnostics *error) { if (!ArgTypeTraits::hasCorrectType(args[Index].value)) { - error->addError(args[Index].range, error->ET_RegistryWrongArgType) + error->addError(args[Index].range, + Diagnostics::ErrorType::RegistryWrongArgType) << matcherName << Index + 1; return false; } diff --git a/mlir/lib/Query/Matcher/Diagnostics.cpp b/mlir/lib/Query/Matcher/Diagnostics.cpp index 67a59a4fe08fe..f80650071e6bc 100644 --- a/mlir/lib/Query/Matcher/Diagnostics.cpp +++ b/mlir/lib/Query/Matcher/Diagnostics.cpp @@ -29,37 +29,37 @@ Diagnostics::ArgStream Diagnostics::addError(SourceRange range, static llvm::StringRef errorTypeToFormatString(Diagnostics::ErrorType type) { switch (type) { - case Diagnostics::ET_RegistryMatcherNotFound: + case Diagnostics::ErrorType::RegistryMatcherNotFound: return "Matcher not found: $0"; - case Diagnostics::ET_RegistryWrongArgCount: + case Diagnostics::ErrorType::RegistryWrongArgCount: return "Incorrect argument count. (Expected = $0) != (Actual = $1)"; - case Diagnostics::ET_RegistryWrongArgType: + case Diagnostics::ErrorType::RegistryWrongArgType: return "Incorrect type for arg $0. (Expected = $1) != (Actual = $2)"; - case Diagnostics::ET_RegistryValueNotFound: + case Diagnostics::ErrorType::RegistryValueNotFound: return "Value not found: $0"; - case Diagnostics::ET_ParserStringError: + case Diagnostics::ErrorType::ParserStringError: return "Error parsing string token: <$0>"; - case Diagnostics::ET_ParserNoOpenParen: + case Diagnostics::ErrorType::ParserNoOpenParen: return "Error parsing matcher. Found token <$0> while looking for '('."; - case Diagnostics::ET_ParserNoCloseParen: + case Diagnostics::ErrorType::ParserNoCloseParen: return "Error parsing matcher. Found end-of-code while looking for ')'."; - case Diagnostics::ET_ParserNoComma: + case Diagnostics::ErrorType::ParserNoComma: return "Error parsing matcher. Found token <$0> while looking for ','."; - case Diagnostics::ET_ParserNoCode: + case Diagnostics::ErrorType::ParserNoCode: return "End of code found while looking for token."; - case Diagnostics::ET_ParserNotAMatcher: + case Diagnostics::ErrorType::ParserNotAMatcher: return "Input value is not a matcher expression."; - case Diagnostics::ET_ParserInvalidToken: + case Diagnostics::ErrorType::ParserInvalidToken: return "Invalid token <$0> found when looking for a value."; - case Diagnostics::ET_ParserTrailingCode: + case Diagnostics::ErrorType::ParserTrailingCode: return "Unexpected end of code."; - case Diagnostics::ET_ParserOverloadedType: + case Diagnostics::ErrorType::ParserOverloadedType: return "Input value has unresolved overloaded type: $0"; - case Diagnostics::ET_ParserFailedToBuildMatcher: + case Diagnostics::ErrorType::ParserFailedToBuildMatcher: return "Failed to build matcher: $0."; - case Diagnostics::ET_None: + case Diagnostics::ErrorType::None: return ""; } llvm_unreachable("Unknown ErrorType value."); diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp index 438e35b6685e4..577ad0a6e98bb 100644 --- a/mlir/lib/Query/Matcher/Parser.cpp +++ b/mlir/lib/Query/Matcher/Parser.cpp @@ -182,7 +182,8 @@ class Parser::CodeTokenizer { SourceRange range; range.start = result->range.start; range.end = currentLocation(); - error->addError(range, error->ET_ParserStringError) << errorText; + error->addError(range, Diagnostics::ErrorType::ParserStringError) + << errorText; result->kind = TokenInfo::TK_Error; } @@ -276,13 +277,13 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) { if (!namedValue.isMatcher()) { error->addError(tokenizer->peekNextToken().range, - error->ET_ParserNotAMatcher); + Diagnostics::ErrorType::ParserNotAMatcher); return false; } if (tokenizer->nextTokenKind() == TokenInfo::TK_NewLine) { error->addError(tokenizer->peekNextToken().range, - error->ET_ParserNoOpenParen) + Diagnostics::ErrorType::ParserNoOpenParen) << "NewLine"; return false; } @@ -294,7 +295,8 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) { tokenizer->nextTokenKind() == TokenInfo::TK_NewLine || tokenizer->nextTokenKind() == TokenInfo::TK_Eof) && !sema->lookupMatcherCtor(nameToken.text)) { - error->addError(nameToken.range, error->ET_RegistryValueNotFound) + error->addError(nameToken.range, + Diagnostics::ErrorType::RegistryValueNotFound) << nameToken.text; return false; } @@ -306,7 +308,7 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) { assert(nameToken.kind == TokenInfo::TK_Ident); TokenInfo openToken = tokenizer->consumeNextToken(); if (openToken.kind != TokenInfo::TK_OpenParen) { - error->addError(openToken.range, error->ET_ParserNoOpenParen) + error->addError(openToken.range, Diagnostics::ErrorType::ParserNoOpenParen) << openToken.text; return false; } @@ -333,7 +335,7 @@ bool Parser::parseMatcherArgs(std::vector &args, MatcherCtor ctor, // We must find a , token to continue. TokenInfo commaToken = tokenizer->consumeNextToken(); if (commaToken.kind != TokenInfo::TK_Comma) { - error->addError(commaToken.range, error->ET_ParserNoComma) + error->addError(commaToken.range, Diagnostics::ErrorType::ParserNoComma) << commaToken.text; return false; } @@ -362,7 +364,8 @@ bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken, std::optional ctor, VariantValue *value) { if (!ctor) { - error->addError(nameToken.range, error->ET_RegistryMatcherNotFound) + error->addError(nameToken.range, + Diagnostics::ErrorType::RegistryMatcherNotFound) << nameToken.text; // Do not return here. We need to continue to give completion suggestions. } @@ -378,7 +381,7 @@ bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken, // Check for the missing closing parenthesis if (endToken.kind != TokenInfo::TK_CloseParen) { - error->addError(openToken.range, error->ET_ParserNoCloseParen) + error->addError(openToken.range, Diagnostics::ErrorType::ParserNoCloseParen) << nameToken.text; return false; } @@ -454,7 +457,7 @@ bool Parser::parseExpressionImpl(VariantValue *value) { return false; case TokenInfo::TK_Eof: error->addError(tokenizer->consumeNextToken().range, - error->ET_ParserNoCode); + Diagnostics::ErrorType::ParserNoCode); return false; case TokenInfo::TK_Error: @@ -467,7 +470,7 @@ bool Parser::parseExpressionImpl(VariantValue *value) { case TokenInfo::TK_Period: case TokenInfo::TK_InvalidChar: const TokenInfo token = tokenizer->consumeNextToken(); - error->addError(token.range, error->ET_ParserInvalidToken) + error->addError(token.range, Diagnostics::ErrorType::ParserInvalidToken) << (token.kind == TokenInfo::TK_NewLine ? "NewLine" : token.text); return false; } @@ -516,7 +519,7 @@ bool Parser::parseExpression(llvm::StringRef &code, Sema *sema, if (nextToken.kind != TokenInfo::TK_Eof && nextToken.kind != TokenInfo::TK_NewLine) { error->addError(tokenizer.peekNextToken().range, - error->ET_ParserTrailingCode); + Diagnostics::ErrorType::ParserTrailingCode); return false; } return true; @@ -542,12 +545,12 @@ Parser::parseMatcherExpression(llvm::StringRef &code, Sema *sema, if (!parseExpression(code, sema, namedValues, &value, error)) return std::nullopt; if (!value.isMatcher()) { - error->addError(SourceRange(), error->ET_ParserNotAMatcher); + error->addError(SourceRange(), Diagnostics::ErrorType::ParserNotAMatcher); return std::nullopt; } std::optional result = value.getMatcher().getDynMatcher(); if (!result) { - error->addError(SourceRange(), error->ET_ParserOverloadedType) + error->addError(SourceRange(), Diagnostics::ErrorType::ParserOverloadedType) << value.getTypeAsString(); } return result; From 90d421d84f6620703f4824ea399daa3310df54aa Mon Sep 17 00:00:00 2001 From: Devajith Valaparambil Sreeramaswamy Date: Tue, 1 Aug 2023 16:15:17 +0100 Subject: [PATCH 07/12] Move internal details to internal namespace --- mlir/include/mlir/Query/Matcher/Diagnostics.h | 4 ++-- mlir/include/mlir/Query/Matcher/Parser.h | 4 ++-- mlir/include/mlir/Query/Matcher/Registry.h | 4 ++-- mlir/include/mlir/Query/Matcher/VariantValue.h | 2 +- mlir/lib/Query/Matcher/Diagnostics.cpp | 4 ++-- mlir/lib/Query/Matcher/Parser.cpp | 4 ++-- mlir/lib/Query/Matcher/Registry.cpp | 4 ++-- mlir/lib/Query/QueryParser.cpp | 13 +++++++------ 8 files changed, 20 insertions(+), 19 deletions(-) diff --git a/mlir/include/mlir/Query/Matcher/Diagnostics.h b/mlir/include/mlir/Query/Matcher/Diagnostics.h index 06c7a0029a196..3b09490409261 100644 --- a/mlir/include/mlir/Query/Matcher/Diagnostics.h +++ b/mlir/include/mlir/Query/Matcher/Diagnostics.h @@ -21,7 +21,7 @@ #include #include -namespace mlir::query::matcher { +namespace mlir::query::matcher::internal { // Represents the line and column numbers in a source query. struct SourceLocation { @@ -114,6 +114,6 @@ class Diagnostics { std::vector errorValues; }; -} // namespace mlir::query::matcher +} // namespace mlir::query::matcher::internal #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_DIAGNOSTICS_H diff --git a/mlir/include/mlir/Query/Matcher/Parser.h b/mlir/include/mlir/Query/Matcher/Parser.h index 232ab20d52189..0a1d0babbce8e 100644 --- a/mlir/include/mlir/Query/Matcher/Parser.h +++ b/mlir/include/mlir/Query/Matcher/Parser.h @@ -33,7 +33,7 @@ #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" -namespace mlir::query::matcher { +namespace mlir::query::matcher::internal { // Matcher expression parser. class Parser { @@ -169,6 +169,6 @@ class Parser { std::vector completions; }; -} // namespace mlir::query::matcher +} // namespace mlir::query::matcher::internal #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_PARSER_H diff --git a/mlir/include/mlir/Query/Matcher/Registry.h b/mlir/include/mlir/Query/Matcher/Registry.h index 4bfa1a0c1ab83..8b6f6e5586f0f 100644 --- a/mlir/include/mlir/Query/Matcher/Registry.h +++ b/mlir/include/mlir/Query/Matcher/Registry.h @@ -56,9 +56,9 @@ class Registry { getMatcherCompletions(ArrayRef acceptedTypes); static VariantMatcher constructMatcher(MatcherCtor ctor, - SourceRange nameRange, + internal::SourceRange nameRange, ArrayRef args, - Diagnostics *error); + internal::Diagnostics *error); }; } // namespace mlir::query::matcher diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h index 22182c17319f9..8c169aa025736 100644 --- a/mlir/include/mlir/Query/Matcher/VariantValue.h +++ b/mlir/include/mlir/Query/Matcher/VariantValue.h @@ -132,7 +132,7 @@ class VariantValue { struct ParserValue { ParserValue() {} llvm::StringRef text; - SourceRange range; + internal::SourceRange range; VariantValue value; }; diff --git a/mlir/lib/Query/Matcher/Diagnostics.cpp b/mlir/lib/Query/Matcher/Diagnostics.cpp index f80650071e6bc..b9e3ab7726015 100644 --- a/mlir/lib/Query/Matcher/Diagnostics.cpp +++ b/mlir/lib/Query/Matcher/Diagnostics.cpp @@ -8,7 +8,7 @@ #include "mlir/Query/Matcher/Diagnostics.h" -namespace mlir::query::matcher { +namespace mlir::query::matcher::internal { Diagnostics::ArgStream & Diagnostics::ArgStream::operator<<(const llvm::Twine &arg) { @@ -124,4 +124,4 @@ void Diagnostics::print(llvm::raw_ostream &OS) const { } } -} // namespace mlir::query::matcher +} // namespace mlir::query::matcher::internal diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp index 577ad0a6e98bb..10839463707f0 100644 --- a/mlir/lib/Query/Matcher/Parser.cpp +++ b/mlir/lib/Query/Matcher/Parser.cpp @@ -15,7 +15,7 @@ #include "llvm/Support/ManagedStatic.h" #include -namespace mlir::query::matcher { +namespace mlir::query::matcher::internal { // Simple structure to hold information for one token from the parser. struct Parser::TokenInfo { @@ -556,4 +556,4 @@ Parser::parseMatcherExpression(llvm::StringRef &code, Sema *sema, return result; } -} // namespace mlir::query::matcher +} // namespace mlir::query::matcher::internal diff --git a/mlir/lib/Query/Matcher/Registry.cpp b/mlir/lib/Query/Matcher/Registry.cpp index 3c3fed8bd1059..a56ebbdf18959 100644 --- a/mlir/lib/Query/Matcher/Registry.cpp +++ b/mlir/lib/Query/Matcher/Registry.cpp @@ -162,9 +162,9 @@ Registry::getMatcherCompletions(ArrayRef acceptedTypes) { } VariantMatcher Registry::constructMatcher(MatcherCtor ctor, - SourceRange nameRange, + internal::SourceRange nameRange, ArrayRef args, - Diagnostics *error) { + internal::Diagnostics *error) { return ctor->create(nameRange, args, error); } diff --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp index a9a25166772ce..60e6bbf1ca574 100644 --- a/mlir/lib/Query/QueryParser.cpp +++ b/mlir/lib/Query/QueryParser.cpp @@ -122,7 +122,8 @@ enum ParsedQueryKind { PQK_Match, }; -QueryRef makeInvalidQueryFromDiagnostics(const matcher::Diagnostics &diag) { +QueryRef +makeInvalidQueryFromDiagnostics(const matcher::internal::Diagnostics &diag) { std::string errStr; llvm::raw_string_ostream OS(errStr); diag.print(OS); @@ -132,8 +133,8 @@ QueryRef makeInvalidQueryFromDiagnostics(const matcher::Diagnostics &diag) { QueryRef QueryParser::completeMatcherExpression() { std::vector comps = - matcher::Parser::completeExpression(line, completionPos - line.begin(), - nullptr, &QS.namedValues); + matcher::internal::Parser::completeExpression( + line, completionPos - line.begin(), nullptr, &QS.namedValues); for (const auto &comp : comps) { completions.emplace_back(comp.typedText, comp.matcherDecl); } @@ -168,12 +169,12 @@ QueryRef QueryParser::doParse() { return completeMatcherExpression(); } - matcher::Diagnostics diag; + matcher::internal::Diagnostics diag; auto matcherSource = line.ltrim(); auto origMatcherSource = matcherSource; std::optional matcher = - matcher::Parser::parseMatcherExpression(matcherSource, nullptr, - &QS.namedValues, &diag); + matcher::internal::Parser::parseMatcherExpression( + matcherSource, nullptr, &QS.namedValues, &diag); if (!matcher) { return makeInvalidQueryFromDiagnostics(diag); } From a250a8d86728e274880da2e93b6fe7d84df62b3f Mon Sep 17 00:00:00 2001 From: Devajith Valaparambil Sreeramaswamy Date: Tue, 1 Aug 2023 16:45:43 +0100 Subject: [PATCH 08/12] Replace most enums with enum class --- mlir/include/mlir/Query/Matcher/Parser.h | 15 +++ .../include/mlir/Query/Matcher/VariantValue.h | 10 +- mlir/include/mlir/Query/Query.h | 26 +++-- mlir/lib/Query/Matcher/Parser.cpp | 101 ++++++++---------- mlir/lib/Query/Matcher/VariantValue.cpp | 38 +++---- mlir/lib/Query/QueryParser.cpp | 37 +++---- 6 files changed, 119 insertions(+), 108 deletions(-) diff --git a/mlir/include/mlir/Query/Matcher/Parser.h b/mlir/include/mlir/Query/Matcher/Parser.h index 0a1d0babbce8e..b7c22f5ecfafd 100644 --- a/mlir/include/mlir/Query/Matcher/Parser.h +++ b/mlir/include/mlir/Query/Matcher/Parser.h @@ -38,6 +38,21 @@ namespace mlir::query::matcher::internal { // Matcher expression parser. class Parser { public: + // Different possible tokens. + enum class TokenKind { + Eof, + NewLine, + OpenParen, + CloseParen, + Comma, + Period, + Literal, + Ident, + InvalidChar, + CodeCompletion, + Error + }; + // Interface to connect the parser with the registry and more. The parser uses // the Sema instance passed into parseMatcherExpression() to handle all // matcher tokens. diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h index 8c169aa025736..121aebd111c0a 100644 --- a/mlir/include/mlir/Query/Matcher/VariantValue.h +++ b/mlir/include/mlir/Query/Matcher/VariantValue.h @@ -85,7 +85,7 @@ class VariantMatcher { // - VariantMatcher class VariantValue { public: - VariantValue() : type(VT_Nothing) {} + VariantValue() : type(ValueType::Nothing) {} VariantValue(const VariantValue &other); ~VariantValue(); @@ -112,10 +112,10 @@ class VariantValue { void reset(); // All supported value types. - enum ValueType { - VT_Nothing, - VT_String, - VT_Matcher, + enum class ValueType { + Nothing, + String, + Matcher, }; // All supported value types. diff --git a/mlir/include/mlir/Query/Query.h b/mlir/include/mlir/Query/Query.h index 77cda9853b69a..9166769246443 100644 --- a/mlir/include/mlir/Query/Query.h +++ b/mlir/include/mlir/Query/Query.h @@ -16,7 +16,7 @@ namespace mlir::query { -enum QueryKind { QK_Invalid, QK_NoOp, QK_Help, QK_Match }; +enum class QueryKind { Invalid, NoOp, Help, Match }; class QuerySession; @@ -37,41 +37,49 @@ typedef llvm::IntrusiveRefCntPtr QueryRef; // Any query which resulted in a parse error. The error message is in ErrStr. struct InvalidQuery : Query { InvalidQuery(const llvm::Twine &errStr) - : Query(QK_Invalid), errStr(errStr.str()) {} + : Query(QueryKind::Invalid), errStr(errStr.str()) {} bool run(llvm::raw_ostream &OS, QuerySession &QS) const override; std::string errStr; - static bool classof(const Query *query) { return query->kind == QK_Invalid; } + static bool classof(const Query *query) { + return query->kind == QueryKind::Invalid; + } }; // No-op query (i.e. a blank line). struct NoOpQuery : Query { - NoOpQuery() : Query(QK_NoOp) {} + NoOpQuery() : Query(QueryKind::NoOp) {} bool run(llvm::raw_ostream &OS, QuerySession &QS) const override; - static bool classof(const Query *query) { return query->kind == QK_NoOp; } + static bool classof(const Query *query) { + return query->kind == QueryKind::NoOp; + } }; // Query for "help". struct HelpQuery : Query { - HelpQuery() : Query(QK_Help) {} + HelpQuery() : Query(QueryKind::Help) {} bool run(llvm::raw_ostream &OS, QuerySession &QS) const override; - static bool classof(const Query *query) { return query->kind == QK_Help; } + static bool classof(const Query *query) { + return query->kind == QueryKind::Help; + } }; // Query for "match MATCHER". struct MatchQuery : Query { MatchQuery(StringRef source, const matcher::DynMatcher &matcher) - : Query(QK_Match), matcher(matcher), source(source) {} + : Query(QueryKind::Match), matcher(matcher), source(source) {} bool run(llvm::raw_ostream &OS, QuerySession &QS) const override; const matcher::DynMatcher matcher; StringRef source; - static bool classof(const Query *query) { return query->kind == QK_Match; } + static bool classof(const Query *query) { + return query->kind == QueryKind::Match; + } }; } // namespace mlir::query diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp index 10839463707f0..2d69e933bc5d7 100644 --- a/mlir/lib/Query/Matcher/Parser.cpp +++ b/mlir/lib/Query/Matcher/Parser.cpp @@ -19,21 +19,6 @@ namespace mlir::query::matcher::internal { // Simple structure to hold information for one token from the parser. struct Parser::TokenInfo { - // Different possible tokens. - enum TokenKind { - TK_Eof, - TK_NewLine, - TK_OpenParen, - TK_CloseParen, - TK_Comma, - TK_Period, - TK_Literal, - TK_Ident, - TK_InvalidChar, - TK_CodeCompletion, - TK_Error - }; - TokenInfo() = default; // Method to set the kind and text of the token @@ -43,7 +28,7 @@ struct Parser::TokenInfo { } llvm::StringRef text; - TokenKind kind = TK_Eof; + TokenKind kind = TokenKind::Eof; SourceRange range; VariantValue value; }; @@ -76,7 +61,7 @@ class Parser::CodeTokenizer { // Skip any newline tokens TokenInfo skipNewlines() { - while (nextToken.kind == TokenInfo::TK_NewLine) + while (nextToken.kind == TokenKind::NewLine) nextToken = getNextToken(); return nextToken; } @@ -84,11 +69,11 @@ class Parser::CodeTokenizer { // Consume and return next token, ignoring newlines TokenInfo consumeNextTokenIgnoreNewlines() { skipNewlines(); - return nextToken.kind == TokenInfo::TK_Eof ? nextToken : consumeNextToken(); + return nextToken.kind == TokenKind::Eof ? nextToken : consumeNextToken(); } // Return kind of next token - TokenInfo::TokenKind nextTokenKind() const { return nextToken.kind; } + TokenKind nextTokenKind() const { return nextToken.kind; } private: // Helper function to get the first character as a new StringRef and drop it @@ -108,7 +93,7 @@ class Parser::CodeTokenizer { // Code completion case if (codeCompletionLocation && codeCompletionLocation <= code.data()) { - result.set(TokenInfo::TK_CodeCompletion, + result.set(TokenKind::CodeCompletion, llvm::StringRef(codeCompletionLocation, 0)); codeCompletionLocation = nullptr; return result; @@ -116,7 +101,7 @@ class Parser::CodeTokenizer { // End of file case if (code.empty()) { - result.set(TokenInfo::TK_Eof, ""); + result.set(TokenKind::Eof, ""); return result; } @@ -126,21 +111,21 @@ class Parser::CodeTokenizer { code = code.drop_until([](char c) { return c == '\n'; }); return getNextToken(); case ',': - result.set(TokenInfo::TK_Comma, firstCharacterAndDrop(code)); + result.set(TokenKind::Comma, firstCharacterAndDrop(code)); break; case '.': - result.set(TokenInfo::TK_Period, firstCharacterAndDrop(code)); + result.set(TokenKind::Period, firstCharacterAndDrop(code)); break; case '\n': ++line; startOfLine = code.drop_front(); - result.set(TokenInfo::TK_NewLine, firstCharacterAndDrop(code)); + result.set(TokenKind::NewLine, firstCharacterAndDrop(code)); break; case '(': - result.set(TokenInfo::TK_OpenParen, firstCharacterAndDrop(code)); + result.set(TokenKind::OpenParen, firstCharacterAndDrop(code)); break; case ')': - result.set(TokenInfo::TK_CloseParen, firstCharacterAndDrop(code)); + result.set(TokenKind::CloseParen, firstCharacterAndDrop(code)); break; case '"': case '\'': @@ -170,7 +155,7 @@ class Parser::CodeTokenizer { continue; } if (code[length] == marker) { - result->kind = TokenInfo::TK_Literal; + result->kind = TokenKind::Literal; result->text = code.substr(0, length + 1); result->value = code.substr(1, length - 1); code = code.drop_front(length + 1); @@ -184,7 +169,7 @@ class Parser::CodeTokenizer { range.end = currentLocation(); error->addError(range, Diagnostics::ErrorType::ParserStringError) << errorText; - result->kind = TokenInfo::TK_Error; + result->kind = TokenKind::Error; } void parseIdentifierOrInvalid(TokenInfo *result) { @@ -198,7 +183,7 @@ class Parser::CodeTokenizer { // location to become a code completion token. if (codeCompletionLocation == code.data() + tokenLength) { codeCompletionLocation = nullptr; - result->kind = TokenInfo::TK_CodeCompletion; + result->kind = TokenKind::CodeCompletion; result->text = code.substr(0, tokenLength); code = code.drop_front(tokenLength); return; @@ -207,11 +192,11 @@ class Parser::CodeTokenizer { break; ++tokenLength; } - result->kind = TokenInfo::TK_Ident; + result->kind = TokenKind::Ident; result->text = code.substr(0, tokenLength); code = code.drop_front(tokenLength); } else { - result->kind = TokenInfo::TK_InvalidChar; + result->kind = TokenKind::InvalidChar; result->text = code.substr(0, 1); code = code.drop_front(1); } @@ -270,7 +255,7 @@ struct Parser::ScopedContextEntry { bool Parser::parseIdentifierPrefixImpl(VariantValue *value) { const TokenInfo nameToken = tokenizer->consumeNextToken(); - if (tokenizer->nextTokenKind() != TokenInfo::TK_OpenParen) { + if (tokenizer->nextTokenKind() != TokenKind::OpenParen) { // Parse as a named value. auto namedValue = namedValues ? namedValues->lookup(nameToken.text) : VariantValue(); @@ -281,7 +266,7 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) { return false; } - if (tokenizer->nextTokenKind() == TokenInfo::TK_NewLine) { + if (tokenizer->nextTokenKind() == TokenKind::NewLine) { error->addError(tokenizer->peekNextToken().range, Diagnostics::ErrorType::ParserNoOpenParen) << "NewLine"; @@ -290,10 +275,10 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) { // If the syntax is correct and the name is not a matcher either, report // an unknown named value. - if ((tokenizer->nextTokenKind() == TokenInfo::TK_Comma || - tokenizer->nextTokenKind() == TokenInfo::TK_CloseParen || - tokenizer->nextTokenKind() == TokenInfo::TK_NewLine || - tokenizer->nextTokenKind() == TokenInfo::TK_Eof) && + if ((tokenizer->nextTokenKind() == TokenKind::Comma || + tokenizer->nextTokenKind() == TokenKind::CloseParen || + tokenizer->nextTokenKind() == TokenKind::NewLine || + tokenizer->nextTokenKind() == TokenKind::Eof) && !sema->lookupMatcherCtor(nameToken.text)) { error->addError(nameToken.range, Diagnostics::ErrorType::RegistryValueNotFound) @@ -305,9 +290,9 @@ bool Parser::parseIdentifierPrefixImpl(VariantValue *value) { tokenizer->skipNewlines(); - assert(nameToken.kind == TokenInfo::TK_Ident); + assert(nameToken.kind == TokenKind::Ident); TokenInfo openToken = tokenizer->consumeNextToken(); - if (openToken.kind != TokenInfo::TK_OpenParen) { + if (openToken.kind != TokenKind::OpenParen) { error->addError(openToken.range, Diagnostics::ErrorType::ParserNoOpenParen) << openToken.text; return false; @@ -324,8 +309,8 @@ bool Parser::parseMatcherArgs(std::vector &args, MatcherCtor ctor, const TokenInfo &nameToken, TokenInfo &endToken) { ScopedContextEntry sce(this, ctor); - while (tokenizer->nextTokenKind() != TokenInfo::TK_Eof) { - if (tokenizer->nextTokenKind() == TokenInfo::TK_CloseParen) { + while (tokenizer->nextTokenKind() != TokenKind::Eof) { + if (tokenizer->nextTokenKind() == TokenKind::CloseParen) { // end of args. endToken = tokenizer->consumeNextToken(); break; @@ -334,7 +319,7 @@ bool Parser::parseMatcherArgs(std::vector &args, MatcherCtor ctor, if (!args.empty()) { // We must find a , token to continue. TokenInfo commaToken = tokenizer->consumeNextToken(); - if (commaToken.kind != TokenInfo::TK_Comma) { + if (commaToken.kind != TokenKind::Comma) { error->addError(commaToken.range, Diagnostics::ErrorType::ParserNoComma) << commaToken.text; return false; @@ -380,7 +365,7 @@ bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken, } // Check for the missing closing parenthesis - if (endToken.kind != TokenInfo::TK_CloseParen) { + if (endToken.kind != TokenKind::CloseParen) { error->addError(openToken.range, Diagnostics::ErrorType::ParserNoCloseParen) << nameToken.text; return false; @@ -425,7 +410,7 @@ Parser::getNamedValueCompletions(ArrayRef acceptedTypes) { void Parser::addExpressionCompletions() { const TokenInfo compToken = tokenizer->consumeNextTokenIgnoreNewlines(); - assert(compToken.kind == TokenInfo::TK_CodeCompletion); + assert(compToken.kind == TokenKind::CodeCompletion); // We cannot complete code if there is an invalid element on the context // stack. @@ -447,31 +432,31 @@ void Parser::addExpressionCompletions() { // Parse an bool Parser::parseExpressionImpl(VariantValue *value) { switch (tokenizer->nextTokenKind()) { - case TokenInfo::TK_Literal: + case TokenKind::Literal: *value = tokenizer->consumeNextToken().value; return true; - case TokenInfo::TK_Ident: + case TokenKind::Ident: return parseIdentifierPrefixImpl(value); - case TokenInfo::TK_CodeCompletion: + case TokenKind::CodeCompletion: addExpressionCompletions(); return false; - case TokenInfo::TK_Eof: + case TokenKind::Eof: error->addError(tokenizer->consumeNextToken().range, Diagnostics::ErrorType::ParserNoCode); return false; - case TokenInfo::TK_Error: + case TokenKind::Error: // This error was already reported by the tokenizer. return false; - case TokenInfo::TK_NewLine: - case TokenInfo::TK_OpenParen: - case TokenInfo::TK_CloseParen: - case TokenInfo::TK_Comma: - case TokenInfo::TK_Period: - case TokenInfo::TK_InvalidChar: + case TokenKind::NewLine: + case TokenKind::OpenParen: + case TokenKind::CloseParen: + case TokenKind::Comma: + case TokenKind::Period: + case TokenKind::InvalidChar: const TokenInfo token = tokenizer->consumeNextToken(); error->addError(token.range, Diagnostics::ErrorType::ParserInvalidToken) - << (token.kind == TokenInfo::TK_NewLine ? "NewLine" : token.text); + << (token.kind == TokenKind::NewLine ? "NewLine" : token.text); return false; } @@ -516,8 +501,8 @@ bool Parser::parseExpression(llvm::StringRef &code, Sema *sema, if (!parser.parseExpressionImpl(value)) return false; auto nextToken = tokenizer.peekNextToken(); - if (nextToken.kind != TokenInfo::TK_Eof && - nextToken.kind != TokenInfo::TK_NewLine) { + if (nextToken.kind != TokenKind::Eof && + nextToken.kind != TokenKind::NewLine) { error->addError(tokenizer.peekNextToken().range, Diagnostics::ErrorType::ParserTrailingCode); return false; diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp index 77c330450e10f..03af985b9f328 100644 --- a/mlir/lib/Query/Matcher/VariantValue.cpp +++ b/mlir/lib/Query/Matcher/VariantValue.cpp @@ -51,15 +51,17 @@ void VariantMatcher::reset() { value.reset(); } std::string VariantMatcher::getTypeAsString() const { return ""; } -VariantValue::VariantValue(const VariantValue &other) : type(VT_Nothing) { +VariantValue::VariantValue(const VariantValue &other) + : type(ValueType::Nothing) { *this = other; } -VariantValue::VariantValue(const StringRef string) : type(VT_String) { +VariantValue::VariantValue(const StringRef string) : type(ValueType::String) { value.String = new StringRef(string); } -VariantValue::VariantValue(const VariantMatcher &matcher) : type(VT_Matcher) { +VariantValue::VariantValue(const VariantMatcher &matcher) + : type(ValueType::Matcher) { value.Matcher = new VariantMatcher(matcher); } @@ -70,14 +72,14 @@ VariantValue &VariantValue::operator=(const VariantValue &other) { return *this; reset(); switch (other.type) { - case VT_String: + case ValueType::String: setString(other.getString()); break; - case VT_Matcher: + case ValueType::Matcher: setMatcher(other.getMatcher()); break; - case VT_Nothing: - type = VT_Nothing; + case ValueType::Nothing: + type = ValueType::Nothing; break; } return *this; @@ -85,20 +87,20 @@ VariantValue &VariantValue::operator=(const VariantValue &other) { void VariantValue::reset() { switch (type) { - case VT_String: + case ValueType::String: delete value.String; break; - case VT_Matcher: + case ValueType::Matcher: delete value.Matcher; break; // Cases that do nothing. - case VT_Nothing: + case ValueType::Nothing: break; } - type = VT_Nothing; + type = ValueType::Nothing; } -bool VariantValue::isString() const { return type == VT_String; } +bool VariantValue::isString() const { return type == ValueType::String; } const StringRef &VariantValue::getString() const { assert(isString()); @@ -107,11 +109,11 @@ const StringRef &VariantValue::getString() const { void VariantValue::setString(const StringRef &newValue) { reset(); - type = VT_String; + type = ValueType::String; value.String = new StringRef(newValue); } -bool VariantValue::isMatcher() const { return type == VT_Matcher; } +bool VariantValue::isMatcher() const { return type == ValueType::Matcher; } const VariantMatcher &VariantValue::getMatcher() const { assert(isMatcher()); @@ -120,17 +122,17 @@ const VariantMatcher &VariantValue::getMatcher() const { void VariantValue::setMatcher(const VariantMatcher &newValue) { reset(); - type = VT_Matcher; + type = ValueType::Matcher; value.Matcher = new VariantMatcher(newValue); } std::string VariantValue::getTypeAsString() const { switch (type) { - case VT_String: + case ValueType::String: return "String"; - case VT_Matcher: + case ValueType::Matcher: return "Matcher"; - case VT_Nothing: + case ValueType::Nothing: return "Nothing"; } llvm_unreachable("Invalid Type"); diff --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp index 60e6bbf1ca574..5051dbf580c29 100644 --- a/mlir/lib/Query/QueryParser.cpp +++ b/mlir/lib/Query/QueryParser.cpp @@ -114,12 +114,12 @@ QueryRef QueryParser::endQuery(QueryRef queryRef) { namespace { -enum ParsedQueryKind { - PQK_Invalid, - PQK_Comment, - PQK_NoOp, - PQK_Help, - PQK_Match, +enum class ParsedQueryKind { + Invalid, + Comment, + NoOp, + Help, + Match, }; QueryRef @@ -144,27 +144,28 @@ QueryRef QueryParser::completeMatcherExpression() { QueryRef QueryParser::doParse() { llvm::StringRef commandStr; - ParsedQueryKind qKind = LexOrCompleteWord(this, commandStr) - .Case("", PQK_NoOp) - .Case("#", PQK_Comment, /*isCompletion=*/false) - .Case("help", PQK_Help) - .Case("m", PQK_Match, /*isCompletion=*/false) - .Case("match", PQK_Match) - .Default(PQK_Invalid); + ParsedQueryKind qKind = + LexOrCompleteWord(this, commandStr) + .Case("", ParsedQueryKind::NoOp) + .Case("#", ParsedQueryKind::Comment, /*isCompletion=*/false) + .Case("help", ParsedQueryKind::Help) + .Case("m", ParsedQueryKind::Match, /*isCompletion=*/false) + .Case("match", ParsedQueryKind::Match) + .Default(ParsedQueryKind::Invalid); switch (qKind) { - case PQK_Comment: - case PQK_NoOp: + case ParsedQueryKind::Comment: + case ParsedQueryKind::NoOp: line = line.drop_until([](char c) { return c == '\n'; }); line = line.drop_while([](char c) { return c == '\n'; }); if (line.empty()) return new NoOpQuery; return doParse(); - case PQK_Help: + case ParsedQueryKind::Help: return endQuery(new HelpQuery); - case PQK_Match: { + case ParsedQueryKind::Match: { if (completionPos) { return completeMatcherExpression(); } @@ -185,7 +186,7 @@ QueryRef QueryParser::doParse() { return Q; } - case PQK_Invalid: + case ParsedQueryKind::Invalid: return new InvalidQuery("unknown command: " + commandStr); } From 30a563f58cfbd2774e7db911d749b1b113b871af Mon Sep 17 00:00:00 2001 From: Devajith Valaparambil Sreeramaswamy Date: Tue, 1 Aug 2023 17:33:46 +0100 Subject: [PATCH 09/12] Working Argkind enum class --- mlir/include/mlir/Query/Matcher/Marshallers.h | 4 ++-- .../include/mlir/Query/Matcher/VariantValue.h | 18 ++---------------- mlir/lib/Query/Matcher/Registry.cpp | 19 +++++++++++++++---- mlir/lib/Query/Matcher/VariantValue.cpp | 10 ---------- 4 files changed, 19 insertions(+), 32 deletions(-) diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h index bada7c12aedb4..43d6f990c8583 100644 --- a/mlir/include/mlir/Query/Matcher/Marshallers.h +++ b/mlir/include/mlir/Query/Matcher/Marshallers.h @@ -41,7 +41,7 @@ struct ArgTypeTraits { return value.getString(); } - static ArgKind getKind() { return ArgKind(ArgKind::AK_String); } + static ArgKind getKind() { return ArgKind::String; } static std::optional getBestGuess(const VariantValue &) { return std::nullopt; @@ -59,7 +59,7 @@ struct ArgTypeTraits { return *value.getMatcher().getDynMatcher(); } - static ArgKind getKind() { return ArgKind(ArgKind::AK_Matcher); } + static ArgKind getKind() { return ArgKind::Matcher; } static std::optional getBestGuess(const VariantValue &) { return std::nullopt; diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h index 121aebd111c0a..42f93984a4f5a 100644 --- a/mlir/include/mlir/Query/Matcher/VariantValue.h +++ b/mlir/include/mlir/Query/Matcher/VariantValue.h @@ -19,22 +19,8 @@ namespace mlir::query::matcher { -// Kind identifier that supports all types that VariantValue can contain. -class ArgKind { -public: - enum Kind { AK_Matcher, AK_String }; - ArgKind(Kind k) : k(k) {} - - Kind getArgKind() const { return k; } - - bool operator<(const ArgKind &other) const { return k < other.k; } - - // String representation of the type. - std::string asString() const; - -private: - Kind k; -}; +// All types that VariantValue can contain. +enum class ArgKind { Matcher, String }; // A variant matcher object to abstract simple and complex matchers into a // single object type. diff --git a/mlir/lib/Query/Matcher/Registry.cpp b/mlir/lib/Query/Matcher/Registry.cpp index a56ebbdf18959..fef3398062b54 100644 --- a/mlir/lib/Query/Matcher/Registry.cpp +++ b/mlir/lib/Query/Matcher/Registry.cpp @@ -29,6 +29,17 @@ using IsConstantOp = detail::constant_op_matcher(); using HasOpAttrName = detail::AttrOpMatcher(StringRef); using HasOpName = detail::NameOpMatcher(StringRef); +// Enum to string for autocomplete. +static std::string asArgString(ArgKind kind) { + switch (kind) { + case ArgKind::Matcher: + return "Matcher"; + case ArgKind::String: + return "String"; + } + llvm_unreachable("Unhandled ArgKind"); +} + class RegistryMaps { public: RegistryMaps(); @@ -90,7 +101,7 @@ std::vector Registry::getAcceptedCompletionTypes( // Starting with the above seed of acceptable top-level matcher types, compute // the acceptable type set for the argument indicated by each context element. std::set typeSet; - typeSet.insert(ArgKind(ArgKind::AK_Matcher)); + typeSet.insert(ArgKind::Matcher); for (const auto &ctxEntry : context) { MatcherCtor ctor = ctxEntry.first; @@ -119,7 +130,7 @@ Registry::getMatcherCompletions(ArrayRef acceptedTypes) { std::vector> argKinds(numArgs); for (const ArgKind &kind : acceptedTypes) { - if (kind.getArgKind() != kind.AK_Matcher) + if (kind != ArgKind::Matcher) continue; for (unsigned arg = 0; arg != numArgs; ++arg) @@ -143,7 +154,7 @@ Registry::getMatcherCompletions(ArrayRef acceptedTypes) { OS << "|"; firstArgKind = false; - OS << argKind.asString(); + OS << asArgString(argKind); } } @@ -152,7 +163,7 @@ Registry::getMatcherCompletions(ArrayRef acceptedTypes) { if (argKinds.empty()) typedText += ")"; - else if (argKinds[0][0].getArgKind() == ArgKind::AK_String) + else if (argKinds[0][0] == ArgKind::String) typedText += "\""; completions.emplace_back(typedText, OS.str()); diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp index 03af985b9f328..d3a4916705df7 100644 --- a/mlir/lib/Query/Matcher/VariantValue.cpp +++ b/mlir/lib/Query/Matcher/VariantValue.cpp @@ -13,16 +13,6 @@ namespace mlir::query::matcher { -std::string ArgKind::asString() const { - switch (getArgKind()) { - case AK_String: - return "String"; - case AK_Matcher: - return "Matcher"; - } - llvm_unreachable("Unhandled ArgKind"); -} - VariantMatcher::Payload::~Payload() = default; class VariantMatcher::SinglePayload : public VariantMatcher::Payload { From e3150e2258bb701bb2e87c2b313e9dd58accfe28 Mon Sep 17 00:00:00 2001 From: Devajith Valaparambil Sreeramaswamy Date: Wed, 2 Aug 2023 14:08:12 +0100 Subject: [PATCH 10/12] Move registry to tools --- mlir/include/mlir/Query/Matcher/Parser.h | 47 ++++++++-------- mlir/include/mlir/Query/Matcher/Registry.h | 30 +++++++++- mlir/include/mlir/Query/QuerySession.h | 7 ++- mlir/lib/Query/Matcher/Parser.cpp | 29 +++++----- mlir/lib/Query/Matcher/Registry.cpp | 62 +++------------------ mlir/lib/Query/QueryParser.cpp | 4 +- mlir/lib/Tools/mlir-query/MlirQueryMain.cpp | 28 +++++++++- 7 files changed, 107 insertions(+), 100 deletions(-) diff --git a/mlir/include/mlir/Query/Matcher/Parser.h b/mlir/include/mlir/Query/Matcher/Parser.h index b7c22f5ecfafd..53ca30fccde38 100644 --- a/mlir/include/mlir/Query/Matcher/Parser.h +++ b/mlir/include/mlir/Query/Matcher/Parser.h @@ -84,6 +84,8 @@ class Parser { // process tokens. class RegistrySema : public Parser::Sema { public: + RegistrySema(const RegistryMaps ®istryData) + : registryData(registryData) {} ~RegistrySema() override; std::optional @@ -99,6 +101,9 @@ class Parser { std::vector getMatcherCompletions(llvm::ArrayRef acceptedTypes) override; + + private: + const RegistryMaps ®istryData; }; using NamedValueMap = llvm::StringMap; @@ -106,44 +111,36 @@ class Parser { // Methods to parse a matcher expression and return a DynMatcher object, // transferring ownership to the caller. static std::optional - parseMatcherExpression(llvm::StringRef &matcherCode, Sema *sema, + parseMatcherExpression(llvm::StringRef &matcherCode, + const RegistryMaps ®istryData, const NamedValueMap *namedValues, Diagnostics *error); static std::optional - parseMatcherExpression(llvm::StringRef &matcherCode, Sema *sema, - Diagnostics *error) { - return parseMatcherExpression(matcherCode, sema, nullptr, error); - } - static std::optional - parseMatcherExpression(llvm::StringRef &matcherCode, Diagnostics *error) { - return parseMatcherExpression(matcherCode, nullptr, error); + parseMatcherExpression(llvm::StringRef &matcherCode, + const RegistryMaps ®istryData, Diagnostics *error) { + return parseMatcherExpression(matcherCode, registryData, nullptr, error); } // Methods to parse any expression supported by this parser. - static bool parseExpression(llvm::StringRef &code, Sema *sema, + static bool parseExpression(llvm::StringRef &code, + const RegistryMaps ®istryData, const NamedValueMap *namedValues, VariantValue *value, Diagnostics *error); - static bool parseExpression(llvm::StringRef &code, Sema *sema, + static bool parseExpression(llvm::StringRef &code, + const RegistryMaps ®istryData, VariantValue *value, Diagnostics *error) { - return parseExpression(code, sema, nullptr, value, error); - } - static bool parseExpression(llvm::StringRef &code, VariantValue *value, - Diagnostics *error) { - return parseExpression(code, nullptr, value, error); + return parseExpression(code, registryData, nullptr, value, error); } // Methods to complete an expression at a given offset. static std::vector completeExpression(llvm::StringRef &code, unsigned completionOffset, - Sema *sema, const NamedValueMap *namedValues); + const RegistryMaps ®istryData, + const NamedValueMap *namedValues); static std::vector completeExpression(llvm::StringRef &code, unsigned completionOffset, - Sema *sema) { - return completeExpression(code, completionOffset, sema, nullptr); - } - static std::vector - completeExpression(llvm::StringRef &code, unsigned completionOffset) { - return completeExpression(code, completionOffset, nullptr); + const RegistryMaps ®istryData) { + return completeExpression(code, completionOffset, registryData, nullptr); } private: @@ -151,8 +148,8 @@ class Parser { struct ScopedContextEntry; struct TokenInfo; - Parser(CodeTokenizer *tokenizer, Sema *sema, const NamedValueMap *namedValues, - Diagnostics *error); + Parser(CodeTokenizer *tokenizer, const RegistryMaps ®istryData, + const NamedValueMap *namedValues, Diagnostics *error); bool parseExpressionImpl(VariantValue *value); @@ -174,7 +171,7 @@ class Parser { getNamedValueCompletions(ArrayRef acceptedTypes); CodeTokenizer *const tokenizer; - Sema *const sema; + std::unique_ptr sema; const NamedValueMap *const namedValues; Diagnostics *const error; diff --git a/mlir/include/mlir/Query/Matcher/Registry.h b/mlir/include/mlir/Query/Matcher/Registry.h index 8b6f6e5586f0f..06e5ab3abd092 100644 --- a/mlir/include/mlir/Query/Matcher/Registry.h +++ b/mlir/include/mlir/Query/Matcher/Registry.h @@ -19,12 +19,36 @@ #include "Marshallers.h" #include "VariantValue.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" #include namespace mlir::query::matcher { using MatcherCtor = const internal::MatcherDescriptor *; +using ConstructorMap = + llvm::StringMap>; + +class RegistryMaps { +public: + RegistryMaps() = default; + ~RegistryMaps() = default; + + const ConstructorMap &constructors() const { return constructorMap; } + + template + void registerMatcher(const std::string &name, MatcherType matcher) { + registerMatcherDescriptor(name, + internal::makeMatcherAutoMarshall(matcher, name)); + } + +private: + void registerMatcherDescriptor( + llvm::StringRef matcherName, + std::unique_ptr callback); + + ConstructorMap constructorMap; +}; struct MatcherCompletion { MatcherCompletion() = default; @@ -47,13 +71,15 @@ class Registry { Registry() = delete; static std::optional - lookupMatcherCtor(llvm::StringRef matcherName); + lookupMatcherCtor(llvm::StringRef matcherName, + const RegistryMaps ®istryData); static std::vector getAcceptedCompletionTypes( llvm::ArrayRef> context); static std::vector - getMatcherCompletions(ArrayRef acceptedTypes); + getMatcherCompletions(ArrayRef acceptedTypes, + const RegistryMaps ®istryData); static VariantMatcher constructMatcher(MatcherCtor ctor, internal::SourceRange nameRange, diff --git a/mlir/include/mlir/Query/QuerySession.h b/mlir/include/mlir/Query/QuerySession.h index afe3e3b26c7a1..622f162c5da0f 100644 --- a/mlir/include/mlir/Query/QuerySession.h +++ b/mlir/include/mlir/Query/QuerySession.h @@ -10,6 +10,7 @@ #define MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H #include "Query.h" +#include "mlir/Query/Matcher/Registry.h" #include "mlir/Tools/ParseUtilities.h" #include "llvm/ADT/StringMap.h" @@ -20,9 +21,10 @@ class QuerySession { public: QuerySession(Operation *rootOp, const std::shared_ptr &sourceMgr, - unsigned bufferId) + unsigned bufferId, + const mlir::query::matcher::RegistryMaps ®istryData) : rootOp(rootOp), sourceMgr(sourceMgr), bufferId(bufferId), - terminate(false) {} + registryData(registryData), terminate(false) {} const std::shared_ptr &getSourceManager() { return sourceMgr; @@ -31,6 +33,7 @@ class QuerySession { Operation *rootOp; const std::shared_ptr sourceMgr; unsigned bufferId; + const mlir::query::matcher::RegistryMaps ®istryData; bool terminate; llvm::StringMap namedValues; }; diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp index 2d69e933bc5d7..e2f14e2447227 100644 --- a/mlir/lib/Query/Matcher/Parser.cpp +++ b/mlir/lib/Query/Matcher/Parser.cpp @@ -463,18 +463,16 @@ bool Parser::parseExpressionImpl(VariantValue *value) { llvm_unreachable("Unknown token kind."); } -static llvm::ManagedStatic defaultRegistrySema; - -Parser::Parser(CodeTokenizer *tokenizer, Sema *sema, +Parser::Parser(CodeTokenizer *tokenizer, const RegistryMaps ®istryData, const NamedValueMap *namedValues, Diagnostics *error) - : tokenizer(tokenizer), sema(sema ? sema : &*defaultRegistrySema), + : tokenizer(tokenizer), sema(std::make_unique(registryData)), namedValues(namedValues), error(error) {} Parser::RegistrySema::~RegistrySema() = default; std::optional Parser::RegistrySema::lookupMatcherCtor(llvm::StringRef matcherName) { - return Registry::lookupMatcherCtor(matcherName); + return Registry::lookupMatcherCtor(matcherName, registryData); } VariantMatcher Parser::RegistrySema::actOnMatcherExpression( @@ -490,14 +488,15 @@ std::vector Parser::RegistrySema::getAcceptedCompletionTypes( std::vector Parser::RegistrySema::getMatcherCompletions(ArrayRef acceptedTypes) { - return Registry::getMatcherCompletions(acceptedTypes); + return Registry::getMatcherCompletions(acceptedTypes, registryData); } -bool Parser::parseExpression(llvm::StringRef &code, Sema *sema, +bool Parser::parseExpression(llvm::StringRef &code, + const RegistryMaps ®istryData, const NamedValueMap *namedValues, VariantValue *value, Diagnostics *error) { CodeTokenizer tokenizer(code, error); - Parser parser(&tokenizer, sema, namedValues, error); + Parser parser(&tokenizer, registryData, namedValues, error); if (!parser.parseExpressionImpl(value)) return false; auto nextToken = tokenizer.peekNextToken(); @@ -512,22 +511,22 @@ bool Parser::parseExpression(llvm::StringRef &code, Sema *sema, std::vector Parser::completeExpression(llvm::StringRef &code, unsigned completionOffset, - Sema *sema, const NamedValueMap *namedValues) { + const RegistryMaps ®istryData, + const NamedValueMap *namedValues) { Diagnostics error; CodeTokenizer tokenizer(code, &error, completionOffset); - Parser parser(&tokenizer, sema, namedValues, &error); + Parser parser(&tokenizer, registryData, namedValues, &error); VariantValue dummy; parser.parseExpressionImpl(&dummy); return parser.completions; } -std::optional -Parser::parseMatcherExpression(llvm::StringRef &code, Sema *sema, - const NamedValueMap *namedValues, - Diagnostics *error) { +std::optional Parser::parseMatcherExpression( + llvm::StringRef &code, const RegistryMaps ®istryData, + const NamedValueMap *namedValues, Diagnostics *error) { VariantValue value; - if (!parseExpression(code, sema, namedValues, &value, error)) + if (!parseExpression(code, registryData, namedValues, &value, error)) return std::nullopt; if (!value.isMatcher()) { error->addError(SourceRange(), Diagnostics::ErrorType::ParserNotAMatcher); diff --git a/mlir/lib/Query/Matcher/Registry.cpp b/mlir/lib/Query/Matcher/Registry.cpp index fef3398062b54..c64b9e58fb0fc 100644 --- a/mlir/lib/Query/Matcher/Registry.cpp +++ b/mlir/lib/Query/Matcher/Registry.cpp @@ -12,18 +12,12 @@ #include "mlir/Query/Matcher/Registry.h" -#include "mlir/IR/Matchers.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/Support/ManagedStatic.h" #include #include namespace mlir::query::matcher { namespace { -using ConstructorMap = - llvm::StringMap>; - // This is needed because these matchers are defined as overloaded functions. using IsConstantOp = detail::constant_op_matcher(); using HasOpAttrName = detail::AttrOpMatcher(StringRef); @@ -40,60 +34,21 @@ static std::string asArgString(ArgKind kind) { llvm_unreachable("Unhandled ArgKind"); } -class RegistryMaps { -public: - RegistryMaps(); - ~RegistryMaps(); - - const ConstructorMap &constructors() const { return constructorMap; } - -private: - void registerMatcher(llvm::StringRef matcherName, - std::unique_ptr callback); - - ConstructorMap constructorMap; -}; - } // namespace -void RegistryMaps::registerMatcher( +void RegistryMaps::registerMatcherDescriptor( llvm::StringRef matcherName, std::unique_ptr callback) { assert(!constructorMap.contains(matcherName)); constructorMap[matcherName] = std::move(callback); } -// Generate a registry map with all the known matchers. -RegistryMaps::RegistryMaps() { - auto registerOpMatcher = [&](const std::string &name, auto matcher) { - registerMatcher(name, internal::makeMatcherAutoMarshall(matcher, name)); - }; - - // Register matchers using the template function (added in alphabetical order - // for consistency) - registerOpMatcher("hasOpAttrName", static_cast(m_Attr)); - registerOpMatcher("hasOpName", static_cast(m_Op)); - registerOpMatcher("isConstantOp", static_cast(m_Constant)); - registerOpMatcher("isNegInfFloat", m_NegInfFloat); - registerOpMatcher("isNegZeroFloat", m_NegZeroFloat); - registerOpMatcher("isNonZero", m_NonZero); - registerOpMatcher("isOne", m_One); - registerOpMatcher("isOneFloat", m_OneFloat); - registerOpMatcher("isPosInfFloat", m_PosInfFloat); - registerOpMatcher("isPosZeroFloat", m_PosZeroFloat); - registerOpMatcher("isZero", m_Zero); - registerOpMatcher("isZeroFloat", m_AnyZeroFloat); -} - -RegistryMaps::~RegistryMaps() = default; - -static llvm::ManagedStatic registryData; - std::optional -Registry::lookupMatcherCtor(llvm::StringRef matcherName) { - auto it = registryData->constructors().find(matcherName); - return it == registryData->constructors().end() ? std::optional() - : it->second.get(); +Registry::lookupMatcherCtor(llvm::StringRef matcherName, + const RegistryMaps ®istryData) { + auto it = registryData.constructors().find(matcherName); + return it == registryData.constructors().end() ? std::optional() + : it->second.get(); } std::vector Registry::getAcceptedCompletionTypes( @@ -118,11 +73,12 @@ std::vector Registry::getAcceptedCompletionTypes( } std::vector -Registry::getMatcherCompletions(ArrayRef acceptedTypes) { +Registry::getMatcherCompletions(ArrayRef acceptedTypes, + const RegistryMaps ®istryData) { std::vector completions; // Search the registry for acceptable matchers. - for (const auto &m : registryData->constructors()) { + for (const auto &m : registryData.constructors()) { const internal::MatcherDescriptor &matcher = *m.getValue(); StringRef name = m.getKey(); diff --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp index 5051dbf580c29..05fe815bda4d1 100644 --- a/mlir/lib/Query/QueryParser.cpp +++ b/mlir/lib/Query/QueryParser.cpp @@ -134,7 +134,7 @@ makeInvalidQueryFromDiagnostics(const matcher::internal::Diagnostics &diag) { QueryRef QueryParser::completeMatcherExpression() { std::vector comps = matcher::internal::Parser::completeExpression( - line, completionPos - line.begin(), nullptr, &QS.namedValues); + line, completionPos - line.begin(), QS.registryData, &QS.namedValues); for (const auto &comp : comps) { completions.emplace_back(comp.typedText, comp.matcherDecl); } @@ -175,7 +175,7 @@ QueryRef QueryParser::doParse() { auto origMatcherSource = matcherSource; std::optional matcher = matcher::internal::Parser::parseMatcherExpression( - matcherSource, nullptr, &QS.namedValues, &diag); + matcherSource, QS.registryData, &QS.namedValues, &diag); if (!matcher) { return makeInvalidQueryFromDiagnostics(diag); } diff --git a/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp b/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp index 7f8151d94c4d0..87b757358d323 100644 --- a/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp +++ b/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp @@ -13,6 +13,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Tools/mlir-query/MlirQueryMain.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Query/Matcher/Registry.h" #include "mlir/Query/QueryParser.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Support/LogicalResult.h" @@ -28,6 +30,7 @@ mlir::LogicalResult mlir::mlirQueryMain(int argc, char **argv, MLIRContext &context) { + // Override the default '-h' and use the default PrintHelpMessage() which // won't print options in categories. static llvm::cl::opt help("h", llvm::cl::desc("Alias for -help"), @@ -83,7 +86,30 @@ mlir::LogicalResult mlir::mlirQueryMain(int argc, char **argv, if (!opRef) return failure(); - mlir::query::QuerySession QS(opRef.get(), sourceMgr, bufferId); + mlir::query::matcher::RegistryMaps registryData; + + // This is needed because these matchers are defined as overloaded functions. + using IsConstantOp = mlir::detail::constant_op_matcher(); + using HasOpAttrName = mlir::detail::AttrOpMatcher(StringRef); + using HasOpName = mlir::detail::NameOpMatcher(StringRef); + + // Matchers registered in alphabetical order for consistency: + registryData.registerMatcher("hasOpAttrName", + static_cast(m_Attr)); + registryData.registerMatcher("hasOpName", static_cast(m_Op)); + registryData.registerMatcher("isConstantOp", + static_cast(m_Constant)); + registryData.registerMatcher("isNegInfFloat", m_NegInfFloat); + registryData.registerMatcher("isNegZeroFloat", m_NegZeroFloat); + registryData.registerMatcher("isNonZero", m_NonZero); + registryData.registerMatcher("isOne", m_One); + registryData.registerMatcher("isOneFloat", m_OneFloat); + registryData.registerMatcher("isPosInfFloat", m_PosInfFloat); + registryData.registerMatcher("isPosZeroFloat", m_PosZeroFloat); + registryData.registerMatcher("isZero", m_Zero); + registryData.registerMatcher("isZeroFloat", m_AnyZeroFloat); + + mlir::query::QuerySession QS(opRef.get(), sourceMgr, bufferId, registryData); if (!commands.empty()) { for (auto &command : commands) { mlir::query::QueryRef queryRef = From c5f1761d4a1ac694d3bd6fd269b99950c1e7bdf6 Mon Sep 17 00:00:00 2001 From: Devajith Valaparambil Sreeramaswamy Date: Wed, 2 Aug 2023 14:35:17 +0100 Subject: [PATCH 11/12] Move to registryData --- mlir/include/mlir/Query/QuerySession.h | 5 ++-- .../mlir/Tools/mlir-query/MlirQueryMain.h | 5 +++- mlir/lib/Tools/mlir-query/MlirQueryMain.cpp | 30 ++----------------- mlir/tools/mlir-query/mlir-query.cpp | 28 ++++++++++++++++- 4 files changed, 36 insertions(+), 32 deletions(-) diff --git a/mlir/include/mlir/Query/QuerySession.h b/mlir/include/mlir/Query/QuerySession.h index 622f162c5da0f..84a9edc5a3953 100644 --- a/mlir/include/mlir/Query/QuerySession.h +++ b/mlir/include/mlir/Query/QuerySession.h @@ -21,8 +21,7 @@ class QuerySession { public: QuerySession(Operation *rootOp, const std::shared_ptr &sourceMgr, - unsigned bufferId, - const mlir::query::matcher::RegistryMaps ®istryData) + unsigned bufferId, const matcher::RegistryMaps ®istryData) : rootOp(rootOp), sourceMgr(sourceMgr), bufferId(bufferId), registryData(registryData), terminate(false) {} @@ -33,7 +32,7 @@ class QuerySession { Operation *rootOp; const std::shared_ptr sourceMgr; unsigned bufferId; - const mlir::query::matcher::RegistryMaps ®istryData; + const matcher::RegistryMaps ®istryData; bool terminate; llvm::StringMap namedValues; }; diff --git a/mlir/include/mlir/Tools/mlir-query/MlirQueryMain.h b/mlir/include/mlir/Tools/mlir-query/MlirQueryMain.h index 1fa5bc2b78605..54579beb5e794 100644 --- a/mlir/include/mlir/Tools/mlir-query/MlirQueryMain.h +++ b/mlir/include/mlir/Tools/mlir-query/MlirQueryMain.h @@ -14,13 +14,16 @@ #ifndef MLIR_TOOLS_MLIRQUERY_MLIRQUERYMAIN_H #define MLIR_TOOLS_MLIRQUERY_MLIRQUERYMAIN_H +#include "mlir/Query/Matcher/Registry.h" #include "mlir/Support/LogicalResult.h" namespace mlir { class MLIRContext; -LogicalResult mlirQueryMain(int argc, char **argv, MLIRContext &context); +LogicalResult +mlirQueryMain(int argc, char **argv, MLIRContext &context, + const mlir::query::matcher::RegistryMaps ®istryData); } // namespace mlir diff --git a/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp b/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp index 87b757358d323..ece3b0bf898fc 100644 --- a/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp +++ b/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp @@ -13,8 +13,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Tools/mlir-query/MlirQueryMain.h" -#include "mlir/IR/Matchers.h" -#include "mlir/Query/Matcher/Registry.h" #include "mlir/Query/QueryParser.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Support/LogicalResult.h" @@ -28,8 +26,9 @@ // Query Parser //===----------------------------------------------------------------------===// -mlir::LogicalResult mlir::mlirQueryMain(int argc, char **argv, - MLIRContext &context) { +mlir::LogicalResult +mlir::mlirQueryMain(int argc, char **argv, MLIRContext &context, + const mlir::query::matcher::RegistryMaps ®istryData) { // Override the default '-h' and use the default PrintHelpMessage() which // won't print options in categories. @@ -86,29 +85,6 @@ mlir::LogicalResult mlir::mlirQueryMain(int argc, char **argv, if (!opRef) return failure(); - mlir::query::matcher::RegistryMaps registryData; - - // This is needed because these matchers are defined as overloaded functions. - using IsConstantOp = mlir::detail::constant_op_matcher(); - using HasOpAttrName = mlir::detail::AttrOpMatcher(StringRef); - using HasOpName = mlir::detail::NameOpMatcher(StringRef); - - // Matchers registered in alphabetical order for consistency: - registryData.registerMatcher("hasOpAttrName", - static_cast(m_Attr)); - registryData.registerMatcher("hasOpName", static_cast(m_Op)); - registryData.registerMatcher("isConstantOp", - static_cast(m_Constant)); - registryData.registerMatcher("isNegInfFloat", m_NegInfFloat); - registryData.registerMatcher("isNegZeroFloat", m_NegZeroFloat); - registryData.registerMatcher("isNonZero", m_NonZero); - registryData.registerMatcher("isOne", m_One); - registryData.registerMatcher("isOneFloat", m_OneFloat); - registryData.registerMatcher("isPosInfFloat", m_PosInfFloat); - registryData.registerMatcher("isPosZeroFloat", m_PosZeroFloat); - registryData.registerMatcher("isZero", m_Zero); - registryData.registerMatcher("isZeroFloat", m_AnyZeroFloat); - mlir::query::QuerySession QS(opRef.get(), sourceMgr, bufferId, registryData); if (!commands.empty()) { for (auto &command : commands) { diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp index 1efbebad1bf34..7345ac330dc35 100644 --- a/mlir/tools/mlir-query/mlir-query.cpp +++ b/mlir/tools/mlir-query/mlir-query.cpp @@ -13,11 +13,18 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" #include "mlir/InitAllDialects.h" +#include "mlir/Query/Matcher/Registry.h" #include "mlir/Tools/mlir-query/MlirQueryMain.h" using namespace mlir; +// This is needed because these matchers are defined as overloaded functions. +using HasOpAttrName = detail::AttrOpMatcher(StringRef); +using HasOpName = detail::NameOpMatcher(StringRef); +using IsConstantOp = detail::constant_op_matcher(); + namespace test { #ifdef MLIR_INCLUDE_TESTS void registerTestDialect(DialectRegistry &); @@ -28,10 +35,29 @@ int main(int argc, char **argv) { DialectRegistry registry; registerAllDialects(registry); + + query::matcher::RegistryMaps registryData; + + // Matchers registered in alphabetical order for consistency: + registryData.registerMatcher("hasOpAttrName", + static_cast(m_Attr)); + registryData.registerMatcher("hasOpName", static_cast(m_Op)); + registryData.registerMatcher("isConstantOp", + static_cast(m_Constant)); + registryData.registerMatcher("isNegInfFloat", m_NegInfFloat); + registryData.registerMatcher("isNegZeroFloat", m_NegZeroFloat); + registryData.registerMatcher("isNonZero", m_NonZero); + registryData.registerMatcher("isOne", m_One); + registryData.registerMatcher("isOneFloat", m_OneFloat); + registryData.registerMatcher("isPosInfFloat", m_PosInfFloat); + registryData.registerMatcher("isPosZeroFloat", m_PosZeroFloat); + registryData.registerMatcher("isZero", m_Zero); + registryData.registerMatcher("isZeroFloat", m_AnyZeroFloat); + #ifdef MLIR_INCLUDE_TESTS test::registerTestDialect(registry); #endif MLIRContext context(registry); - return failed(mlirQueryMain(argc, argv, context)); + return failed(mlirQueryMain(argc, argv, context, registryData)); } From be9ee6a8ec31763069a89690a316329d20a368a5 Mon Sep 17 00:00:00 2001 From: Devajith Valaparambil Sreeramaswamy Date: Wed, 2 Aug 2023 15:49:38 +0100 Subject: [PATCH 12/12] Print error on invalid file --- mlir/lib/Tools/mlir-query/MlirQueryMain.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp b/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp index ece3b0bf898fc..7ea492b48611f 100644 --- a/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp +++ b/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp @@ -71,6 +71,7 @@ mlir::mlirQueryMain(int argc, char **argv, MLIRContext &context, std::string errorMessage; auto file = openInputFile(inputFilename, &errorMessage); if (!file) { + llvm::errs() << errorMessage << "\n"; return failure(); }