[Mlir-commits] [mlir] 628e136 - [mlir][ODS] Fix copy ctor for generated Pass classes

Vladislav Vinogradov llvmlistbot at llvm.org
Mon Jun 21 04:06:55 PDT 2021


Author: Vladislav Vinogradov
Date: 2021-06-21T14:07:31+03:00
New Revision: 628e136738823c3f69f4d0fb406c78f017009d1e

URL: https://github.com/llvm/llvm-project/commit/628e136738823c3f69f4d0fb406c78f017009d1e
DIFF: https://github.com/llvm/llvm-project/commit/628e136738823c3f69f4d0fb406c78f017009d1e.diff

LOG: [mlir][ODS] Fix copy ctor for generated Pass classes

Make the generated base class's copy constructor delegate to its parent's copy
constructor instead of re-initializing the parent via the `TypeID`-based one.

This allows the final Pass classes to have extra fields and custom copy logic.

Reviewed By: lattner

Differential Revision: https://reviews.llvm.org/D104302

Added: 
    mlir/unittests/TableGen/PassGenTest.cpp
    mlir/unittests/TableGen/passes.td

Modified: 
    mlir/include/mlir/Pass/Pass.h
    mlir/tools/mlir-tblgen/PassGen.cpp
    mlir/unittests/TableGen/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index da91ef303cd1d..3a9865d669dd4 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -339,6 +339,7 @@ class Pass {
 template <typename OpT = void> class OperationPass : public Pass {
 protected:
   OperationPass(TypeID passID) : Pass(passID, OpT::getOperationName()) {}
+  OperationPass(const OperationPass &) = default;
 
   /// Support isa/dyn_cast functionality.
   static bool classof(const Pass *pass) {
@@ -371,6 +372,7 @@ template <typename OpT = void> class OperationPass : public Pass {
 template <> class OperationPass<void> : public Pass {
 protected:
   OperationPass(TypeID passID) : Pass(passID) {}
+  OperationPass(const OperationPass &) = default;
 };
 
 /// A model for providing function pass specific utilities.
@@ -409,6 +411,7 @@ template <typename PassT, typename BaseT> class PassWrapper : public BaseT {
 
 protected:
   PassWrapper() : BaseT(TypeID::get<PassT>()) {}
+  PassWrapper(const PassWrapper &) = default;
 
   /// Returns the derived pass name.
   StringRef getName() const override { return llvm::getTypeName<PassT>(); }

diff  --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
index 8f3a19daaa5fb..e09746bdfb4ad 100644
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -48,7 +48,7 @@ class {0}Base : public {1} {
   using Base = {0}Base;
 
   {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
-  {0}Base(const {0}Base &) : {1}(::mlir::TypeID::get<DerivedT>()) {{}
+  {0}Base(const {0}Base &other) : {1}(other) {{}
 
   /// Returns the command-line argument attached to this pass.
   static constexpr ::llvm::StringLiteral getArgumentName() {

diff  --git a/mlir/unittests/TableGen/CMakeLists.txt b/mlir/unittests/TableGen/CMakeLists.txt
index 7cee691de4006..421133be473fa 100644
--- a/mlir/unittests/TableGen/CMakeLists.txt
+++ b/mlir/unittests/TableGen/CMakeLists.txt
@@ -8,15 +8,21 @@ mlir_tablegen(StructAttrGenTest.h.inc -gen-struct-attr-decls)
 mlir_tablegen(StructAttrGenTest.cpp.inc -gen-struct-attr-defs)
 add_public_tablegen_target(MLIRTableGenStructAttrIncGen)
 
+set(LLVM_TARGET_DEFINITIONS passes.td)
+mlir_tablegen(PassGenTest.h.inc -gen-pass-decls -name TableGenTest)
+add_public_tablegen_target(MLIRTableGenTestPassIncGen)
+
 add_mlir_unittest(MLIRTableGenTests
   EnumsGenTest.cpp
   StructsGenTest.cpp
   FormatTest.cpp
   OpBuildGen.cpp
+  PassGenTest.cpp
 )
 
 add_dependencies(MLIRTableGenTests MLIRTableGenEnumsIncGen)
 add_dependencies(MLIRTableGenTests MLIRTableGenStructAttrIncGen)
+add_dependencies(MLIRTableGenTests MLIRTableGenTestPassIncGen)
 add_dependencies(MLIRTableGenTests MLIRTestDialect)
 
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../test/lib/Dialect/Test)

diff  --git a/mlir/unittests/TableGen/PassGenTest.cpp b/mlir/unittests/TableGen/PassGenTest.cpp
new file mode 100644
index 0000000000000..33bd1606051f4
--- /dev/null
+++ b/mlir/unittests/TableGen/PassGenTest.cpp
@@ -0,0 +1,59 @@
+//===- PassGenTest.cpp - TableGen PassGen Tests ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/Pass.h"
+
+#include "gmock/gmock.h"
+
+// passes.td names `::createTestPass()` as the pass constructor, so it must be
+// declared before the tblgen-generated code below is included.
+std::unique_ptr<mlir::Pass> createTestPass(int v = 0);
+
+#define GEN_PASS_REGISTRATION
+#include "PassGenTest.h.inc"
+
+#define GEN_PASS_CLASSES
+#include "PassGenTest.h.inc"
+
+namespace {
+
+/// Final pass class built on the generated `TestPassBase`. It carries an extra
+/// field that a correct copy constructor must preserve when the pass is cloned.
+struct TestPass : public TestPassBase<TestPass> {
+  explicit TestPass(int v) : extraVal(v) {}
+
+  void runOnOperation() override {}
+
+  /// Re-exports the base class `clone()` so the test below can call it.
+  std::unique_ptr<mlir::Pass> clone() const {
+    return TestPassBase<TestPass>::clone();
+  }
+
+  /// State owned by the final class; must survive cloning.
+  int extraVal;
+};
+
+} // namespace
+
+std::unique_ptr<mlir::Pass> createTestPass(int v) {
+  return std::make_unique<TestPass>(v);
+}
+
+TEST(PassGenTest, PassClone) {
+  const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) {
+    return static_cast<const TestPass *>(pass.get());
+  };
+
+  const auto origPass = createTestPass(10);
+  const auto clonePass = unwrap(origPass)->clone();
+
+  // The clone must be a distinct object that kept the extra field's value.
+  ASSERT_TRUE(clonePass);
+  EXPECT_NE(origPass.get(), clonePass.get());
+  EXPECT_EQ(unwrap(origPass)->extraVal, unwrap(clonePass)->extraVal);
+}

diff  --git a/mlir/unittests/TableGen/passes.td b/mlir/unittests/TableGen/passes.td
new file mode 100644
index 0000000000000..f730390ebc8eb
--- /dev/null
+++ b/mlir/unittests/TableGen/passes.td
@@ -0,0 +1,20 @@
+//===-- passes.td - PassGen test definition file -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+include "mlir/Pass/PassBase.td"
+include "mlir/Rewrite/PassUtil.td"
+
+// Test pass consumed by PassGenTest.cpp via `-gen-pass-decls`. Its C++
+// constructor, `::createTestPass()`, is declared and defined in that file.
+def TestPass : Pass<"test"> {
+  let summary = "Test pass";
+
+  let constructor = "::createTestPass()";
+
+  let options = RewritePassUtils.options;
+}


        


More information about the Mlir-commits mailing list