[Mlir-commits] [mlir] 9ad64a5 - [mlir:PDLL] Add support for C++ generation

River Riddle llvmlistbot at llvm.org
Sat Feb 26 11:26:41 PST 2022


Author: River Riddle
Date: 2022-02-26T11:08:51-08:00
New Revision: 9ad64a5c78a98c4f9eb4fff7c9c7665925db6907

URL: https://github.com/llvm/llvm-project/commit/9ad64a5c78a98c4f9eb4fff7c9c7665925db6907
DIFF: https://github.com/llvm/llvm-project/commit/9ad64a5c78a98c4f9eb4fff7c9c7665925db6907.diff

LOG: [mlir:PDLL] Add support for C++ generation

This commit adds a C++ generator to PDLL that generates wrapper PDL patterns
directly usable in C++ code, and also generates the definitions of native constraints/rewrites
that have code bodies specified in PDLL. This generator is effectively the PDLL equivalent of
the current DRR generator, and will allow easy replacement of DRR patterns with PDLL patterns.
A followup will start to utilize this for end-to-end integration testing and showcase how to
use this as a drop-in replacement for DRR tablegen usage.

Differential Revision: https://reviews.llvm.org/D119781

Added: 
    mlir/include/mlir/Tools/PDLL/CodeGen/CPPGen.h
    mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
    mlir/test/mlir-pdll/CodeGen/CPP/general.pdll

Modified: 
    mlir/lib/Tools/PDLL/CodeGen/CMakeLists.txt
    mlir/tools/mlir-pdll/mlir-pdll.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Tools/PDLL/CodeGen/CPPGen.h b/mlir/include/mlir/Tools/PDLL/CodeGen/CPPGen.h
new file mode 100644
index 0000000000000..58f4cef48baf7
--- /dev/null
+++ b/mlir/include/mlir/Tools/PDLL/CodeGen/CPPGen.h
@@ -0,0 +1,28 @@
+//===- CPPGen.h - MLIR PDLL CPP Code Generation -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_PDLL_CODEGEN_CPPGEN_H_
+#define MLIR_TOOLS_PDLL_CODEGEN_CPPGEN_H_
+
+#include "mlir/Support/LLVM.h"
+#include <memory>
+
+namespace mlir {
+class ModuleOp;
+
+namespace pdll {
+namespace ast {
+class Module;
+} // namespace ast
+
+void codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module,
+                      raw_ostream &os);
+} // namespace pdll
+} // namespace mlir
+
+#endif // MLIR_TOOLS_PDLL_CODEGEN_CPPGEN_H_

diff  --git a/mlir/lib/Tools/PDLL/CodeGen/CMakeLists.txt b/mlir/lib/Tools/PDLL/CodeGen/CMakeLists.txt
index f1e59126623fc..9226beefe3303 100644
--- a/mlir/lib/Tools/PDLL/CodeGen/CMakeLists.txt
+++ b/mlir/lib/Tools/PDLL/CodeGen/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_library(MLIRPDLLCodeGen
+  CPPGen.cpp
   MLIRGen.cpp
 
   LINK_LIBS PUBLIC

diff  --git a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
new file mode 100644
index 0000000000000..d5045ca07cb16
--- /dev/null
+++ b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
@@ -0,0 +1,219 @@
+//===- CPPGen.cpp ---------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a PDLL generator that outputs C++ code that defines PDLL
+// patterns as individual C++ PDLPatternModules for direct use in native code,
+// and also defines any native constraints/rewrites with bodies defined in PDLL.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDL/IR/PDLOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Tools/PDLL/AST/Nodes.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+using namespace mlir::pdll;
+
+//===----------------------------------------------------------------------===//
+// CodeGen
+//===----------------------------------------------------------------------===//
+
+namespace {
+class CodeGen {
+public:
+  CodeGen(raw_ostream &os) : os(os) {}
+
+  /// Generate C++ code for the given PDL pattern module.
+  void generate(const ast::Module &astModule, ModuleOp module);
+
+private:
+  void generate(pdl::PatternOp pattern, StringRef patternName,
+                StringSet<> &nativeFunctions);
+
+  /// Generate C++ code for all user defined constraints and rewrites with
+  /// native code.
+  void generateConstraintAndRewrites(const ast::Module &astModule,
+                                     ModuleOp module,
+                                     StringSet<> &nativeFunctions);
+  void generate(const ast::UserConstraintDecl *decl,
+                StringSet<> &nativeFunctions);
+  void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions);
+  void generateConstraintOrRewrite(StringRef name, bool isConstraint,
+                                   ArrayRef<ast::VariableDecl *> inputs,
+                                   StringRef codeBlock,
+                                   StringSet<> &nativeFunctions);
+
+  /// The stream to output to.
+  raw_ostream &os;
+};
+} // namespace
+
+void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
+  SetVector<std::string, SmallVector<std::string>, StringSet<>> patternNames;
+  StringSet<> nativeFunctions;
+
+  // Generate code for any native functions within the module.
+  generateConstraintAndRewrites(astModule, module, nativeFunctions);
+
+  os << "namespace {\n";
+  std::string basePatternName = "GeneratedPDLLPattern";
+  int patternIndex = 0;
+  for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
+    // If the pattern has a name, use that. Otherwise, generate a unique name.
+    if (Optional<StringRef> patternName = pattern.sym_name()) {
+      patternNames.insert(patternName->str());
+    } else {
+      std::string name;
+      do {
+        name = (basePatternName + Twine(patternIndex++)).str();
+      } while (!patternNames.insert(name));
+    }
+
+    generate(pattern, patternNames.back(), nativeFunctions);
+  }
+  os << "} // end namespace\n\n";
+
+  // Emit function to add the generated matchers to the pattern list.
+  os << "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
+        "::mlir::RewritePatternSet &patterns) {\n";
+  for (const auto &name : patternNames)
+    os << "  patterns.add<" << name << ">(patterns.getContext());\n";
+  os << "}\n";
+}
+
+void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
+                       StringSet<> &nativeFunctions) {
+  const char *patternClassStartStr = R"(
+struct {0} : ::mlir::PDLPatternModule {{
+  {0}(::mlir::MLIRContext *context)
+    : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
+)";
+  os << llvm::formatv(patternClassStartStr, patternName);
+
+  os << "R\"mlir(";
+  pattern->print(os, OpPrintingFlags().enableDebugInfo());
+  os << "\n    )mlir\", context)) {\n";
+
+  // Register any native functions used within the pattern.
+  StringSet<> registeredNativeFunctions;
+  auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) {
+    if (!nativeFunctions.count(fnName) ||
+        !registeredNativeFunctions.insert(fnName).second)
+      return;
+    os << "    register" << fnType << "Function(\"" << fnName << "\", "
+       << fnName << "PDLFn);\n";
+  };
+  pattern.walk([&](Operation *op) {
+    if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op))
+      checkRegisterNativeFn(constraintOp.name(), "Constraint");
+    else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op))
+      checkRegisterNativeFn(rewriteOp.name(), "Rewrite");
+  });
+  os << "  }\n};\n\n";
+}
+
+void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule,
+                                            ModuleOp module,
+                                            StringSet<> &nativeFunctions) {
+  // First check to see which constraints and rewrites are actually referenced
+  // in the module.
+  StringSet<> usedFns;
+  module.walk([&](Operation *op) {
+    TypeSwitch<Operation *>(op)
+        .Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>(
+            [&](auto op) { usedFns.insert(op.name()); });
+  });
+
+  for (const ast::Decl *decl : astModule.getChildren()) {
+    TypeSwitch<const ast::Decl *>(decl)
+        .Case<ast::UserConstraintDecl, ast::UserRewriteDecl>(
+            [&](const auto *decl) {
+              // We only generate code for inline native decls that have been
+              // referenced.
+              if (decl->getCodeBlock() &&
+                  usedFns.contains(decl->getName().getName()))
+                this->generate(decl, nativeFunctions);
+            });
+  }
+}
+
+void CodeGen::generate(const ast::UserConstraintDecl *decl,
+                       StringSet<> &nativeFunctions) {
+  return generateConstraintOrRewrite(decl->getName().getName(),
+                                     /*isConstraint=*/true, decl->getInputs(),
+                                     *decl->getCodeBlock(), nativeFunctions);
+}
+
+void CodeGen::generate(const ast::UserRewriteDecl *decl,
+                       StringSet<> &nativeFunctions) {
+  return generateConstraintOrRewrite(decl->getName().getName(),
+                                     /*isConstraint=*/false, decl->getInputs(),
+                                     *decl->getCodeBlock(), nativeFunctions);
+}
+
+void CodeGen::generateConstraintOrRewrite(StringRef name, bool isConstraint,
+                                          ArrayRef<ast::VariableDecl *> inputs,
+                                          StringRef codeBlock,
+                                          StringSet<> &nativeFunctions) {
+  nativeFunctions.insert(name);
+
+  // TODO: Should there be something explicit for handling optionality?
+  auto getCppType = [&](ast::Type type) -> StringRef {
+    return llvm::TypeSwitch<ast::Type, StringRef>(type)
+        .Case([&](ast::AttributeType) { return "::mlir::Attribute"; })
+        .Case([&](ast::OperationType) {
+          // TODO: Allow using the derived Op class when possible.
+          return "::mlir::Operation *";
+        })
+        .Case([&](ast::TypeType) { return "::mlir::Type"; })
+        .Case([&](ast::ValueType) { return "::mlir::Value"; })
+        .Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; })
+        .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; });
+  };
+
+  // FIXME: We currently do not have a modeling for the "constant params"
+  // support PDL provides. We should either figure out a modeling for this, or
+  // refactor the support within PDL to be something a bit more reasonable for
+  // what we need as a frontend.
+  os << "static " << (isConstraint ? "::mlir::LogicalResult " : "void ") << name
+     << "PDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, "
+        "::mlir::ArrayAttr constParams, ::mlir::PatternRewriter &rewriter"
+     << (isConstraint ? "" : ", ::mlir::PDLResultList &results") << ") {\n";
+
+  const char *argumentInitStr = R"(
+  {0} {1} = {{};
+  if (values[{2}])
+    {1} = values[{2}].cast<{0}>();
+  (void){1};
+)";
+  for (const auto &it : llvm::enumerate(inputs)) {
+    const ast::VariableDecl *input = it.value();
+    os << llvm::formatv(argumentInitStr, getCppType(input->getType()),
+                        input->getName().getName(), it.index());
+  }
+
+  os << "  " << codeBlock.trim() << "\n}\n";
+}
+
+//===----------------------------------------------------------------------===//
+// CPPGen
+//===----------------------------------------------------------------------===//
+
+void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module,
+                                  raw_ostream &os) {
+  CodeGen codegen(os);
+  codegen.generate(astModule, module);
+}

diff  --git a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
new file mode 100644
index 0000000000000..5b94c66b43500
--- /dev/null
+++ b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
@@ -0,0 +1,105 @@
+// RUN: mlir-pdll %s -I %S -split-input-file -x cpp | FileCheck %s
+
+// Check that we generate a wrapper pattern for each PDL pattern. Also
+// add in a pattern awkwardly named the same as our generated patterns to
+// check that we handle overlap.
+
+// CHECK: struct GeneratedPDLLPattern0 : ::mlir::PDLPatternModule {
+// CHECK:  : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
+// CHECK:  R"mlir(
+// CHECK:    pdl.pattern
+// CHECK:      operation "test.op"
+// CHECK:  )mlir", context))
+
+// CHECK: struct NamedPattern : ::mlir::PDLPatternModule {
+// CHECK:  : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
+// CHECK:  R"mlir(
+// CHECK:    pdl.pattern
+// CHECK:      operation "test.op2"
+// CHECK:  )mlir", context))
+
+// CHECK: struct GeneratedPDLLPattern1 : ::mlir::PDLPatternModule {
+
+// CHECK: struct GeneratedPDLLPattern2 : ::mlir::PDLPatternModule {
+// CHECK:  : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
+// CHECK:  R"mlir(
+// CHECK:    pdl.pattern
+// CHECK:      operation "test.op3"
+// CHECK:  )mlir", context))
+
+// CHECK:      static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(::mlir::RewritePatternSet &patterns) {
+// CHECK-NEXT:   patterns.add<GeneratedPDLLPattern0>(patterns.getContext());
+// CHECK-NEXT:   patterns.add<NamedPattern>(patterns.getContext());
+// CHECK-NEXT:   patterns.add<GeneratedPDLLPattern1>(patterns.getContext());
+// CHECK-NEXT:   patterns.add<GeneratedPDLLPattern2>(patterns.getContext());
+// CHECK-NEXT: }
+
+Pattern => erase op<test.op>;
+Pattern NamedPattern => erase op<test.op2>;
+Pattern GeneratedPDLLPattern1 => erase op<>;
+Pattern => erase op<test.op3>;
+
+// -----
+
+// Check the generation of native constraints and rewrites.
+
+// CHECK:      static ::mlir::LogicalResult TestCstPDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, ::mlir::ArrayAttr constParams,
+// CHECK-SAME:                                           ::mlir::PatternRewriter &rewriter) {
+// CHECK:   ::mlir::Attribute attr = {};
+// CHECK:   if (values[0])
+// CHECK:     attr = values[0].cast<::mlir::Attribute>();
+// CHECK:   ::mlir::Operation * op = {};
+// CHECK:   if (values[1])
+// CHECK:     op = values[1].cast<::mlir::Operation *>();
+// CHECK:   ::mlir::Type type = {};
+// CHECK:   if (values[2])
+// CHECK:     type = values[2].cast<::mlir::Type>();
+// CHECK:   ::mlir::Value value = {};
+// CHECK:   if (values[3])
+// CHECK:     value = values[3].cast<::mlir::Value>();
+// CHECK:   ::mlir::TypeRange typeRange = {};
+// CHECK:   if (values[4])
+// CHECK:     typeRange = values[4].cast<::mlir::TypeRange>();
+// CHECK:   ::mlir::ValueRange valueRange = {};
+// CHECK:   if (values[5])
+// CHECK:     valueRange = values[5].cast<::mlir::ValueRange>();
+
+// CHECK:   return success();
+// CHECK: }
+
+// CHECK-NOT: TestUnusedCst
+
+// CHECK: static void TestRewritePDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, ::mlir::ArrayAttr constParams,
+// CHECK-SAME:                         ::mlir::PatternRewriter &rewriter, ::mlir::PDLResultList &results) {
+// CHECK:   ::mlir::Attribute attr = {};
+// CHECK:   ::mlir::Operation * op = {};
+// CHECK:   ::mlir::Type type = {};
+// CHECK:   ::mlir::Value value = {};
+// CHECK:   ::mlir::TypeRange typeRange = {};
+// CHECK:   ::mlir::ValueRange valueRange = {};
+
+// CHECK: foo;
+// CHECK: }
+
+// CHECK-NOT: TestUnusedRewrite
+
+// CHECK: struct TestCstAndRewrite : ::mlir::PDLPatternModule {
+// CHECK:   registerConstraintFunction("TestCst", TestCstPDLFn);
+// CHECK:   registerRewriteFunction("TestRewrite", TestRewritePDLFn);
+
+Constraint TestCst(attr: Attr, op: Op, type: Type, value: Value, typeRange: TypeRange, valueRange: ValueRange) [{
+  return success();
+}];
+Constraint TestUnusedCst() [{ return success(); }];
+
+Rewrite TestRewrite(attr: Attr, op: Op, type: Type, value: Value, typeRange: TypeRange, valueRange: ValueRange) [{ foo; }];
+Rewrite TestUnusedRewrite(op: Op) [{}];
+
+Pattern TestCstAndRewrite {
+  let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
+  TestCst(attr<"true">, root, type, operand, types, operands);
+  rewrite root with {
+    TestRewrite(attr<"true">, root, type, operand, types, operands);
+    erase root;
+  };
+}

diff  --git a/mlir/tools/mlir-pdll/mlir-pdll.cpp b/mlir/tools/mlir-pdll/mlir-pdll.cpp
index 6ee2b8bf11a55..e133d9e45c54a 100644
--- a/mlir/tools/mlir-pdll/mlir-pdll.cpp
+++ b/mlir/tools/mlir-pdll/mlir-pdll.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Support/ToolUtilities.h"
 #include "mlir/Tools/PDLL/AST/Context.h"
 #include "mlir/Tools/PDLL/AST/Nodes.h"
+#include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
 #include "mlir/Tools/PDLL/CodeGen/MLIRGen.h"
 #include "mlir/Tools/PDLL/Parser/Parser.h"
 #include "llvm/Support/CommandLine.h"
@@ -29,6 +30,7 @@ using namespace mlir::pdll;
 enum class OutputType {
   AST,
   MLIR,
+  CPP,
 };
 
 static LogicalResult
@@ -54,7 +56,12 @@ processBuffer(raw_ostream &os, std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
   if (!pdlModule)
     return failure();
 
-  pdlModule->print(os, OpPrintingFlags().enableDebugInfo());
+  if (outputType == OutputType::MLIR) {
+    pdlModule->print(os, OpPrintingFlags().enableDebugInfo());
+    return success();
+  }
+
+  codegenPDLLToCPP(**module, *pdlModule, os);
   return success();
 }
 
@@ -82,7 +89,10 @@ int main(int argc, char **argv) {
       llvm::cl::values(clEnumValN(OutputType::AST, "ast",
                                   "generate the AST for the input file"),
                        clEnumValN(OutputType::MLIR, "mlir",
-                                  "generate the PDL MLIR for the input file")));
+                                  "generate the PDL MLIR for the input file"),
+                       clEnumValN(OutputType::CPP, "cpp",
+                                  "generate a C++ source file containing the "
+                                  "patterns for the input file")));
 
   llvm::InitLLVM y(argc, argv);
   llvm::cl::ParseCommandLineOptions(argc, argv, "PDLL Frontend");


        


More information about the Mlir-commits mailing list