[llvm] [mlir] [lldb] [mlir] Introduce replaceWithZeroTripCheck in LoopLikeOpInterface (PR #80331)

Jerry Wu via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 5 14:59:22 PST 2024


https://github.com/pzread updated https://github.com/llvm/llvm-project/pull/80331

>From 70f54b51bef87bde5e3f5ee067c0f2414d34e915 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Thu, 1 Feb 2024 19:57:26 +0000
Subject: [PATCH 1/6] Add replaceWithZeroTripCheck to LoopLikeOpInterface

---
 .../mlir/Interfaces/LoopLikeInterface.td      | 22 +++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index e2ac85a3f7725..77409cb3a8274 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -220,6 +220,28 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       /*defaultImplementation=*/[{
         return ::mlir::failure();
       }]
+    >,
+    InterfaceMethod<[{
+        Add a zero-trip-check around the loop to check if the loop body is ever
+        run and return the new loop inside the check. The loop body is moved
+        over to the new loop. Returns "failure" if the loop doesn't support
+        this transformation.
+
+        After the transformation, the ops inserted to the parent region of the
+        loop are guaranteed to be run only if the loop body runs at least one
+        iteration.
+
+        Note: Ops in the loop body might be rearranged because of loop rotating
+        to maintain the semantic. Terminators might be removed/added during this
+        transformation.
+      }],
+      /*retTy=*/"::mlir::FailureOr<::mlir::LoopLikeOpInterface>",
+      /*methodName=*/"replaceWithZeroTripCheck",
+      /*args=*/(ins "::mlir::RewriterBase &":$rewriter),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::failure();
+      }]
     >
   ];
 

>From d6703ebbeb5ddc358929672b44994a9d05683523 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Fri, 2 Feb 2024 18:59:03 +0000
Subject: [PATCH 2/6] Add tests

---
 mlir/unittests/Interfaces/CMakeLists.txt      |   3 +
 .../Interfaces/LoopLikeInterfaceTest.cpp      | 101 ++++++++++++++++++
 2 files changed, 104 insertions(+)
 create mode 100644 mlir/unittests/Interfaces/LoopLikeInterfaceTest.cpp

diff --git a/mlir/unittests/Interfaces/CMakeLists.txt b/mlir/unittests/Interfaces/CMakeLists.txt
index d192b2922d6b9..cab9503cf295b 100644
--- a/mlir/unittests/Interfaces/CMakeLists.txt
+++ b/mlir/unittests/Interfaces/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_unittest(MLIRInterfacesTests
   DataLayoutInterfacesTest.cpp
   InferIntRangeInterfaceTest.cpp
   InferTypeOpInterfaceTest.cpp
+  LoopLikeInterfaceTest.cpp
 )
 
 target_link_libraries(MLIRInterfacesTests
@@ -12,7 +13,9 @@ target_link_libraries(MLIRInterfacesTests
   MLIRDataLayoutInterfaces
   MLIRDLTIDialect
   MLIRFuncDialect
+  MLIRIR
   MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
+  MLIRLoopLikeInterface
   MLIRParser
 )
diff --git a/mlir/unittests/Interfaces/LoopLikeInterfaceTest.cpp b/mlir/unittests/Interfaces/LoopLikeInterfaceTest.cpp
new file mode 100644
index 0000000000000..b0b7680fed68e
--- /dev/null
+++ b/mlir/unittests/Interfaces/LoopLikeInterfaceTest.cpp
@@ -0,0 +1,101 @@
+//===- LoopLikeInterfaceTest.cpp - Unit tests for Loop Like Interfaces. ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Parser/Parser.h"
+
+#include <gtest/gtest.h>
+
+using namespace mlir;
+
+struct NoZeroTripCheckLoopOp
+    : public Op<NoZeroTripCheckLoopOp, LoopLikeOpInterface::Trait> {
+  using Op::Op;
+
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+  static StringRef getOperationName() {
+    return "looptest.no_zero_trip_check_loop_op";
+  }
+
+  SmallVector<Region *> getLoopRegions() { return {}; }
+};
+
+struct ImplZeroTripCheckLoopOp
+    : public Op<ImplZeroTripCheckLoopOp, LoopLikeOpInterface::Trait> {
+  using Op::Op;
+
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+  static StringRef getOperationName() {
+    return "looptest.impl_zero_trip_check_loop_op";
+  }
+
+  SmallVector<Region *> getLoopRegions() { return {}; }
+
+  FailureOr<LoopLikeOpInterface>
+  replaceWithZeroTripCheck(RewriterBase &rewriter) {
+    return cast<LoopLikeOpInterface>(this->getOperation());
+  }
+};
+
+/// A dialect putting all the above together.
+struct LoopTestDialect : Dialect {
+  explicit LoopTestDialect(MLIRContext *ctx)
+      : Dialect(getDialectNamespace(), ctx, TypeID::get<LoopTestDialect>()) {
+    addOperations<NoZeroTripCheckLoopOp, ImplZeroTripCheckLoopOp>();
+  }
+  static StringRef getDialectNamespace() { return "looptest"; }
+};
+
+TEST(LoopLikeOpInterface, NoReplaceWithZeroTripCheck) {
+  const char *ir = R"MLIR(
+  "looptest.no_zero_trip_check_loop_op"() : () -> ()
+  )MLIR";
+
+  DialectRegistry registry;
+  registry.insert<LoopTestDialect>();
+  MLIRContext ctx(registry);
+
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+  LoopLikeOpInterface testOp =
+      cast<LoopLikeOpInterface>(module->getBody()->getOperations().front());
+
+  IRRewriter rewriter(&ctx);
+  FailureOr<LoopLikeOpInterface> result =
+      testOp.replaceWithZeroTripCheck(rewriter);
+
+  EXPECT_TRUE(failed(result));
+}
+
+TEST(LoopLikeOpInterface, ImplReplaceWithZeroTripCheck) {
+  const char *ir = R"MLIR(
+  "looptest.impl_zero_trip_check_loop_op"() : () -> ()
+  )MLIR";
+
+  DialectRegistry registry;
+  registry.insert<LoopTestDialect>();
+  MLIRContext ctx(registry);
+
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+  LoopLikeOpInterface testOp =
+      cast<LoopLikeOpInterface>(module->getBody()->getOperations().front());
+
+  IRRewriter rewriter(&ctx);
+  FailureOr<LoopLikeOpInterface> result =
+      testOp.replaceWithZeroTripCheck(rewriter);
+
+  EXPECT_TRUE(succeeded(result));
+  EXPECT_EQ(*result, testOp);
+}

>From 43af5b4ff01b03cc8b69ba494d0f344ab645e3d1 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Fri, 2 Feb 2024 19:12:14 +0000
Subject: [PATCH 3/6] Update comments

---
 mlir/include/mlir/Interfaces/LoopLikeInterface.td | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 77409cb3a8274..81f202cf34186 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -233,7 +233,10 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
 
         Note: Ops in the loop body might be rearranged because of loop rotating
         to maintain the semantic. Terminators might be removed/added during this
-        transformation.
+        transformation. Also callers are not required to check the side-effect
+        of loop condition, so the transformation needs to consider that to make
+        sure the loop behavior is unchanged when moving the condtion out of the
+        loop for the zero-trip-check.
       }],
       /*retTy=*/"::mlir::FailureOr<::mlir::LoopLikeOpInterface>",
       /*methodName=*/"replaceWithZeroTripCheck",

>From 8e244d44e1939641b35ea2f550b73847d1cde26c Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Mon, 5 Feb 2024 21:39:51 +0000
Subject: [PATCH 4/6] Update comments

---
 mlir/include/mlir/Interfaces/LoopLikeInterface.td | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 81f202cf34186..572845f46d320 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -223,9 +223,8 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
     >,
     InterfaceMethod<[{
         Add a zero-trip-check around the loop to check if the loop body is ever
-        run and return the new loop inside the check. The loop body is moved
-        over to the new loop. Returns "failure" if the loop doesn't support
-        this transformation.
+        run and return the same loop (moved) or a new loop (replaced) inside the
+        check. Returns "failure" if the loop doesn't support the transformation.
 
         After the transformation, the ops inserted to the parent region of the
         loop are guaranteed to be run only if the loop body runs at least one
@@ -235,7 +234,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
         to maintain the semantic. Terminators might be removed/added during this
         transformation. Also callers are not required to check the side-effect
         of loop condition, so the transformation needs to consider that to make
-        sure the loop behavior is unchanged when moving the condtion out of the
+        sure the loop behavior is unchanged when moving the condition out of the
         loop for the zero-trip-check.
       }],
       /*retTy=*/"::mlir::FailureOr<::mlir::LoopLikeOpInterface>",

>From a985663735c8ac5cdc7582347e6888f076e6c0fb Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Mon, 5 Feb 2024 22:58:38 +0000
Subject: [PATCH 5/6] Add test pass

---
 .../LoopLikeInterface/CMakeLists.txt          |  1 +
 .../TestLoopZeroTripCheck.cpp                 | 52 +++++++++++++++++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |  2 +
 3 files changed, 55 insertions(+)
 create mode 100644 mlir/test/lib/Interfaces/LoopLikeInterface/TestLoopZeroTripCheck.cpp

diff --git a/mlir/test/lib/Interfaces/LoopLikeInterface/CMakeLists.txt b/mlir/test/lib/Interfaces/LoopLikeInterface/CMakeLists.txt
index f20219e00cb86..19a727822dc67 100644
--- a/mlir/test/lib/Interfaces/LoopLikeInterface/CMakeLists.txt
+++ b/mlir/test/lib/Interfaces/LoopLikeInterface/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_library(MLIRLoopLikeInterfaceTestPasses
   TestBlockInLoop.cpp
+  TestLoopZeroTripCheck.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
diff --git a/mlir/test/lib/Interfaces/LoopLikeInterface/TestLoopZeroTripCheck.cpp b/mlir/test/lib/Interfaces/LoopLikeInterface/TestLoopZeroTripCheck.cpp
new file mode 100644
index 0000000000000..45f0f312aceea
--- /dev/null
+++ b/mlir/test/lib/Interfaces/LoopLikeInterface/TestLoopZeroTripCheck.cpp
@@ -0,0 +1,52 @@
+//===- TestLoopZeroTripCheck.cpp -- Pass to test replaceWithZeroTripCheck -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test replaceWithZeroTripCheck on loop ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestLoopZeroTripCheck
+    : public PassWrapper<TestLoopZeroTripCheck, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopZeroTripCheck)
+
+  StringRef getArgument() const final { return "test-loop-zero-trip-check"; }
+  StringRef getDescription() const final {
+    return "test replaceWithZeroTripCheck of loop ops";
+  }
+
+  /// Applies replaceWithZeroTripCheck to every LoopLikeOpInterface op
+  /// reached by walking the function this pass runs on.
+  void runOnOperation() override {
+    IRRewriter rewriter(&getContext());
+    getOperation().walk([&](LoopLikeOpInterface loopOp) {
+      // Loops that do not implement the transformation report failure. That
+      // is expected in this test pass, so the result is intentionally
+      // discarded rather than turned into a pass failure.
+      (void)loopOp.replaceWithZeroTripCheck(rewriter);
+    });
+  }
+};
+
+} // namespace
+
+namespace mlir::test {
+/// Registers the `test-loop-zero-trip-check` pass so mlir-opt can run it.
+void registerTestLoopZeroTripCheckPass() {
+  // Constructing the PassRegistration object is what performs registration.
+  PassRegistration<TestLoopZeroTripCheck>();
+}
+} // namespace mlir::test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 428bdd9691e09..6ac3283bcb9d1 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -110,6 +110,7 @@ void registerTestLoopFusion();
 void registerTestCFGLoopInfoPass();
 void registerTestLoopMappingPass();
 void registerTestLoopUnrollingPass();
+void registerTestLoopZeroTripCheckPass();
 void registerTestLowerToLLVM();
 void registerTestLowerToNVVM();
 void registerTestMakeIsolatedFromAbovePass();
@@ -234,6 +235,7 @@ void registerTestPasses() {
   mlir::test::registerTestCFGLoopInfoPass();
   mlir::test::registerTestLoopMappingPass();
   mlir::test::registerTestLoopUnrollingPass();
+  mlir::test::registerTestLoopZeroTripCheckPass();
   mlir::test::registerTestLowerToLLVM();
   mlir::test::registerTestMakeIsolatedFromAbovePass();
   mlir::test::registerTestMatchReductionPass();

>From 6783afd3c4c2a9c3e40c5b05a63326f0c57ffcde Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Mon, 5 Feb 2024 22:59:03 +0000
Subject: [PATCH 6/6] Revert "Add tests"

This reverts commit d6703ebbeb5ddc358929672b44994a9d05683523.
---
 mlir/unittests/Interfaces/CMakeLists.txt      |   3 -
 .../Interfaces/LoopLikeInterfaceTest.cpp      | 101 ------------------
 2 files changed, 104 deletions(-)
 delete mode 100644 mlir/unittests/Interfaces/LoopLikeInterfaceTest.cpp

diff --git a/mlir/unittests/Interfaces/CMakeLists.txt b/mlir/unittests/Interfaces/CMakeLists.txt
index cab9503cf295b..d192b2922d6b9 100644
--- a/mlir/unittests/Interfaces/CMakeLists.txt
+++ b/mlir/unittests/Interfaces/CMakeLists.txt
@@ -3,7 +3,6 @@ add_mlir_unittest(MLIRInterfacesTests
   DataLayoutInterfacesTest.cpp
   InferIntRangeInterfaceTest.cpp
   InferTypeOpInterfaceTest.cpp
-  LoopLikeInterfaceTest.cpp
 )
 
 target_link_libraries(MLIRInterfacesTests
@@ -13,9 +12,7 @@ target_link_libraries(MLIRInterfacesTests
   MLIRDataLayoutInterfaces
   MLIRDLTIDialect
   MLIRFuncDialect
-  MLIRIR
   MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
-  MLIRLoopLikeInterface
   MLIRParser
 )
diff --git a/mlir/unittests/Interfaces/LoopLikeInterfaceTest.cpp b/mlir/unittests/Interfaces/LoopLikeInterfaceTest.cpp
deleted file mode 100644
index b0b7680fed68e..0000000000000
--- a/mlir/unittests/Interfaces/LoopLikeInterfaceTest.cpp
+++ /dev/null
@@ -1,101 +0,0 @@
-//===- LoopLikeInterfaceTest.cpp - Unit tests for Loop Like Interfaces. ---===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Interfaces/LoopLikeInterface.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/DialectImplementation.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Parser/Parser.h"
-
-#include <gtest/gtest.h>
-
-using namespace mlir;
-
-struct NoZeroTripCheckLoopOp
-    : public Op<NoZeroTripCheckLoopOp, LoopLikeOpInterface::Trait> {
-  using Op::Op;
-
-  static ArrayRef<StringRef> getAttributeNames() { return {}; }
-
-  static StringRef getOperationName() {
-    return "looptest.no_zero_trip_check_loop_op";
-  }
-
-  SmallVector<Region *> getLoopRegions() { return {}; }
-};
-
-struct ImplZeroTripCheckLoopOp
-    : public Op<ImplZeroTripCheckLoopOp, LoopLikeOpInterface::Trait> {
-  using Op::Op;
-
-  static ArrayRef<StringRef> getAttributeNames() { return {}; }
-
-  static StringRef getOperationName() {
-    return "looptest.impl_zero_trip_check_loop_op";
-  }
-
-  SmallVector<Region *> getLoopRegions() { return {}; }
-
-  FailureOr<LoopLikeOpInterface>
-  replaceWithZeroTripCheck(RewriterBase &rewriter) {
-    return cast<LoopLikeOpInterface>(this->getOperation());
-  }
-};
-
-/// A dialect putting all the above together.
-struct LoopTestDialect : Dialect {
-  explicit LoopTestDialect(MLIRContext *ctx)
-      : Dialect(getDialectNamespace(), ctx, TypeID::get<LoopTestDialect>()) {
-    addOperations<NoZeroTripCheckLoopOp, ImplZeroTripCheckLoopOp>();
-  }
-  static StringRef getDialectNamespace() { return "looptest"; }
-};
-
-TEST(LoopLikeOpInterface, NoReplaceWithZeroTripCheck) {
-  const char *ir = R"MLIR(
-  "looptest.no_zero_trip_check_loop_op"() : () -> ()
-  )MLIR";
-
-  DialectRegistry registry;
-  registry.insert<LoopTestDialect>();
-  MLIRContext ctx(registry);
-
-  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
-  LoopLikeOpInterface testOp =
-      cast<LoopLikeOpInterface>(module->getBody()->getOperations().front());
-
-  IRRewriter rewriter(&ctx);
-  FailureOr<LoopLikeOpInterface> result =
-      testOp.replaceWithZeroTripCheck(rewriter);
-
-  EXPECT_TRUE(failed(result));
-}
-
-TEST(LoopLikeOpInterface, ImplReplaceWithZeroTripCheck) {
-  const char *ir = R"MLIR(
-  "looptest.impl_zero_trip_check_loop_op"() : () -> ()
-  )MLIR";
-
-  DialectRegistry registry;
-  registry.insert<LoopTestDialect>();
-  MLIRContext ctx(registry);
-
-  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
-  LoopLikeOpInterface testOp =
-      cast<LoopLikeOpInterface>(module->getBody()->getOperations().front());
-
-  IRRewriter rewriter(&ctx);
-  FailureOr<LoopLikeOpInterface> result =
-      testOp.replaceWithZeroTripCheck(rewriter);
-
-  EXPECT_TRUE(succeeded(result));
-  EXPECT_EQ(*result, testOp);
-}



More information about the llvm-commits mailing list