Skip to content

[mlir][mlir-query] Add inital draft for mlir-query tool #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
119 changes: 119 additions & 0 deletions mlir/include/mlir/Query/Matcher/Diagnostics.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
//===--- Diagnostics.h - Helper class for error diagnostics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Diagnostics class to manage error messages. Implementation shares similarity
// to clang-query Diagnostics.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_DIAGNOSTICS_H
#define MLIR_TOOLS_MLIRQUERY_MATCHER_DIAGNOSTICS_H

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
#include <string>
#include <vector>

namespace mlir::query::matcher::internal {

// Represents the line and column numbers in a source query.
struct SourceLocation {
unsigned line{};
unsigned column{};
};

// Represents a range in a source query, defined by its start and end locations.
struct SourceRange {
SourceLocation start{};
SourceLocation end{};
};

// Diagnostics class to manage error messages.
class Diagnostics {
public:
// All errors from the system.
enum class ErrorType {
None,

// Parser Errors
ParserFailedToBuildMatcher,
ParserInvalidToken,
ParserNoCloseParen,
ParserNoCode,
ParserNoComma,
ParserNoOpenParen,
ParserNotAMatcher,
ParserOverloadedType,
ParserStringError,
ParserTrailingCode,

// Registry Errors
RegistryMatcherNotFound,
RegistryValueNotFound,
RegistryWrongArgCount,
RegistryWrongArgType
};

// Helper stream class for constructing error messages.
class ArgStream {
public:
ArgStream(std::vector<std::string> *out) : out(out) {}
template <class T>
ArgStream &operator<<(const T &arg) {
return operator<<(llvm::Twine(arg));
}
ArgStream &operator<<(const llvm::Twine &arg);

private:
std::vector<std::string> *out;
};

// Add an error message with the specified range and error type.
// Returns an ArgStream object to allow constructing the error message using
// the << operator.
ArgStream addError(SourceRange range, ErrorType error);

// Print all error messages to the specified output stream.
void print(llvm::raw_ostream &OS) const;

private:
// Information stored for one frame of the context.
struct ContextFrame {
SourceRange range;
std::vector<std::string> args;
};

// Information stored for each error found.
struct ErrorContent {
std::vector<ContextFrame> contextStack;
struct Message {
SourceRange range;
ErrorType type;
std::vector<std::string> args;
};
std::vector<Message> messages;
};

// Get an array reference to the error contents.
llvm::ArrayRef<ErrorContent> errors() const { return errorValues; }

void printMessage(const ErrorContent::Message &message,
const llvm::Twine Prefix, llvm::raw_ostream &OS) const;

void printErrorContent(const ErrorContent &content,
llvm::raw_ostream &OS) const;

std::vector<ContextFrame> contextStack;
std::vector<ErrorContent> errorValues;
};

} // namespace mlir::query::matcher::internal

#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_DIAGNOSTICS_H
196 changes: 196 additions & 0 deletions mlir/include/mlir/Query/Matcher/Marshallers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
//===--- Marshallers.h - Generic matcher function marshallers ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains function templates and classes to wrap matcher construct
// functions. It provides a collection of template function and classes that
// present a generic marshalling layer on top of matcher construct functions.
// The registry uses these to export all marshaller constructors with a uniform
// interface. This mechanism takes inspiration from clang-query.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H

#include "Diagnostics.h"
#include "VariantValue.h"

namespace mlir::query::matcher::internal {

// Helper template class for jumping from argument type to the correct is/get
// functions in VariantValue. This is used for verifying and extracting the
// matcher arguments.
template <class T>
struct ArgTypeTraits;
template <class T>
struct ArgTypeTraits<const T &> : public ArgTypeTraits<T> {};

template <>
struct ArgTypeTraits<StringRef> {

static bool hasCorrectType(const VariantValue &value) {
return value.isString();
}

static const StringRef &get(const VariantValue &value) {
return value.getString();
}

static ArgKind getKind() { return ArgKind::String; }

static std::optional<std::string> getBestGuess(const VariantValue &) {
return std::nullopt;
}
};

template <>
struct ArgTypeTraits<DynMatcher> {

static bool hasCorrectType(const VariantValue &value) {
return value.isMatcher();
}

static DynMatcher get(const VariantValue &value) {
return *value.getMatcher().getDynMatcher();
}

static ArgKind getKind() { return ArgKind::Matcher; }

static std::optional<std::string> getBestGuess(const VariantValue &) {
return std::nullopt;
}
};

// Interface for generic matcher descriptor.
// Offers a create() method that constructs the matcher from the provided
// arguments.
class MatcherDescriptor {
public:
virtual ~MatcherDescriptor() = default;
virtual VariantMatcher create(SourceRange nameRange,
const ArrayRef<ParserValue> args,
Diagnostics *error) const = 0;

// Returns the number of arguments accepted by the matcher.
virtual unsigned getNumArgs() const = 0;

// Append the set of argument types accepted for argument 'ArgNo' to
// 'ArgKinds'.
virtual void getArgKinds(unsigned argNo,
std::vector<ArgKind> &argKinds) const = 0;
};

class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
public:
using MarshallerType = VariantMatcher (*)(void (*func)(),
StringRef matcherName,
SourceRange nameRange,
ArrayRef<ParserValue> args,
Diagnostics *error);

// Marshaller Function to unpack the arguments and call Func. Func is the
// Matcher construct function. This is the function that the matcher
// expressions would use to create the matcher.
FixedArgCountMatcherDescriptor(MarshallerType marshaller, void (*func)(),
StringRef matcherName,
ArrayRef<ArgKind> argKinds)
: marshaller(marshaller), func(func), matcherName(matcherName),
argKinds(argKinds.begin(), argKinds.end()) {}

VariantMatcher create(SourceRange nameRange, ArrayRef<ParserValue> args,
Diagnostics *error) const override {
return marshaller(func, matcherName, nameRange, args, error);
}

unsigned getNumArgs() const override { return argKinds.size(); }

void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
kinds.push_back(argKinds[argNo]);
}

private:
const MarshallerType marshaller;
void (*const func)();
const StringRef matcherName;
const std::vector<ArgKind> argKinds;
};

// Helper function to check if argument count matches expected count
inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
ArrayRef<ParserValue> args, Diagnostics *error) {
if (args.size() != expectedArgCount) {
error->addError(nameRange, Diagnostics::ErrorType::RegistryWrongArgCount)
<< expectedArgCount << args.size();
return false;
}
return true;
}

// Helper function for checking argument type
template <typename ArgType, size_t Index>
inline bool checkArgTypeAtIndex(StringRef matcherName,
ArrayRef<ParserValue> args,
Diagnostics *error) {
if (!ArgTypeTraits<ArgType>::hasCorrectType(args[Index].value)) {
error->addError(args[Index].range,
Diagnostics::ErrorType::RegistryWrongArgType)
<< matcherName << Index + 1;
return false;
}
return true;
}

// Marshaller function for fixed number of arguments
template <typename ReturnType, typename... ArgTypes, size_t... Is>
static VariantMatcher
matcherMarshallFixedImpl(void (*func)(), StringRef matcherName,
SourceRange nameRange, ArrayRef<ParserValue> args,
Diagnostics *error, std::index_sequence<Is...>) {
using FuncType = ReturnType (*)(ArgTypes...);

// Check if the argument count matches the expected count
if (!checkArgCount(nameRange, sizeof...(ArgTypes), args, error)) {
return VariantMatcher();
}

// Check if each argument at the corresponding index has the correct type
if ((... && checkArgTypeAtIndex<ArgTypes, Is>(matcherName, args, error))) {
ReturnType fnPointer = reinterpret_cast<FuncType>(func)(
ArgTypeTraits<ArgTypes>::get(args[Is].value)...);
return VariantMatcher::SingleMatcher(
*DynMatcher::constructDynMatcherFromMatcherFn(fnPointer));
} else {
return VariantMatcher();
}
}

template <typename ReturnType, typename... ArgTypes>
static VariantMatcher
matcherMarshallFixed(void (*func)(), StringRef matcherName,
SourceRange nameRange, ArrayRef<ParserValue> args,
Diagnostics *error) {
return matcherMarshallFixedImpl<ReturnType, ArgTypes...>(
func, matcherName, nameRange, args, error,
std::index_sequence_for<ArgTypes...>{});
}

// Fixed number of arguments overload
template <typename ReturnType, typename... ArgTypes>
std::unique_ptr<MatcherDescriptor>
makeMatcherAutoMarshall(ReturnType (*func)(ArgTypes...),
StringRef matcherName) {
// Create a vector of argument kinds
std::vector<ArgKind> argKinds = {ArgTypeTraits<ArgTypes>::getKind()...};
return std::make_unique<FixedArgCountMatcherDescriptor>(
matcherMarshallFixed<ReturnType, ArgTypes...>,
reinterpret_cast<void (*)()>(func), matcherName, argKinds);
}

} // namespace mlir::query::matcher::internal

#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
42 changes: 42 additions & 0 deletions mlir/include/mlir/Query/Matcher/MatchFinder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@

//===- MatchFinder.h - Structural query framework ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the MatchFinder class, which is used to find operations
// that match a given matcher.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H

#include "MatchersInternal.h"

namespace mlir::query::matcher {

// MatchFinder is used to find all operations that match a given matcher.
class MatchFinder {
public:
// Returns all operations that match the given matcher.
static std::vector<Operation *> getMatches(Operation *root,
DynMatcher matcher) {
std::vector<Operation *> matches;

// Simple match finding with walk.
root->walk([&](Operation *subOp) {
if (matcher.match(subOp))
matches.push_back(subOp);
});

return matches;
}
};

} // namespace mlir::query::matcher

#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
Loading