[Mlir-commits] [mlir] [mlir][WIP] Implement replaceWithZeroTripCheck for scf.while (PR #80349)

Jerry Wu llvmlistbot at llvm.org
Tue Feb 6 11:22:28 PST 2024


https://github.com/pzread updated https://github.com/llvm/llvm-project/pull/80349

>From c67e04074c262aadf9804d1bb0f14eef52f4077f Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Thu, 1 Feb 2024 19:57:26 +0000
Subject: [PATCH 01/10] Add replaceWithZeroTripCheck to LoopLikeOpInterface

---
 .../mlir/Interfaces/LoopLikeInterface.td      | 22 +++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index e2ac85a3f7725d..77409cb3a8274b 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -220,6 +220,28 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       /*defaultImplementation=*/[{
         return ::mlir::failure();
       }]
+    >,
+    InterfaceMethod<[{
+        Add a zero-trip-check around the loop to check if the loop body is ever
+        run and return the new loop inside the check. The loop body is moved
+        over to the new loop. Returns "failure" if the loop doesn't support
+        this transformation.
+
+        After the transformation, the ops inserted to the parent region of the
+        loop are guaranteed to be run only if the loop body runs at least one
+        iteration.
+
+        Note: Ops in the loop body might be rearranged because of loop rotating
+        to maintain the semantic. Terminators might be removed/added during this
+        transformation.
+      }],
+      /*retTy=*/"::mlir::FailureOr<::mlir::LoopLikeOpInterface>",
+      /*methodName=*/"replaceWithZeroTripCheck",
+      /*args=*/(ins "::mlir::RewriterBase &":$rewriter),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::failure();
+      }]
     >
   ];
 

>From 059df4e60bdc0ce4dd9d36b78f89c34daf262ded Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Fri, 2 Feb 2024 18:59:03 +0000
Subject: [PATCH 02/10] Add tests

---
 mlir/unittests/Interfaces/CMakeLists.txt      |   3 +
 .../Interfaces/LoopLikeInterfaceTest.cpp      | 101 ++++++++++++++++++
 2 files changed, 104 insertions(+)
 create mode 100644 mlir/unittests/Interfaces/LoopLikeInterfaceTest.cpp

diff --git a/mlir/unittests/Interfaces/CMakeLists.txt b/mlir/unittests/Interfaces/CMakeLists.txt
index d192b2922d6b9d..cab9503cf295b1 100644
--- a/mlir/unittests/Interfaces/CMakeLists.txt
+++ b/mlir/unittests/Interfaces/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_unittest(MLIRInterfacesTests
   DataLayoutInterfacesTest.cpp
   InferIntRangeInterfaceTest.cpp
   InferTypeOpInterfaceTest.cpp
+  LoopLikeInterfaceTest.cpp
 )
 
 target_link_libraries(MLIRInterfacesTests
@@ -12,7 +13,9 @@ target_link_libraries(MLIRInterfacesTests
   MLIRDataLayoutInterfaces
   MLIRDLTIDialect
   MLIRFuncDialect
+  MLIRIR
   MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
+  MLIRLoopLikeInterface
   MLIRParser
 )
diff --git a/mlir/unittests/Interfaces/LoopLikeInterfaceTest.cpp b/mlir/unittests/Interfaces/LoopLikeInterfaceTest.cpp
new file mode 100644
index 00000000000000..b0b7680fed68e7
--- /dev/null
+++ b/mlir/unittests/Interfaces/LoopLikeInterfaceTest.cpp
@@ -0,0 +1,101 @@
+//===- LoopLikeInterfaceTest.cpp - Unit tests for Loop Like Interfaces. ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Parser/Parser.h"
+
+#include <gtest/gtest.h>
+
+using namespace mlir;
+
+struct NoZeroTripCheckLoopOp
+    : public Op<NoZeroTripCheckLoopOp, LoopLikeOpInterface::Trait> {
+  using Op::Op;
+
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+  static StringRef getOperationName() {
+    return "looptest.no_zero_trip_check_loop_op";
+  }
+
+  SmallVector<Region *> getLoopRegions() { return {}; }
+};
+
+struct ImplZeroTripCheckLoopOp
+    : public Op<ImplZeroTripCheckLoopOp, LoopLikeOpInterface::Trait> {
+  using Op::Op;
+
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+  static StringRef getOperationName() {
+    return "looptest.impl_zero_trip_check_loop_op";
+  }
+
+  SmallVector<Region *> getLoopRegions() { return {}; }
+
+  FailureOr<LoopLikeOpInterface>
+  replaceWithZeroTripCheck(RewriterBase &rewriter) {
+    return cast<LoopLikeOpInterface>(this->getOperation());
+  }
+};
+
+/// A dialect putting all the above together.
+struct LoopTestDialect : Dialect {
+  explicit LoopTestDialect(MLIRContext *ctx)
+      : Dialect(getDialectNamespace(), ctx, TypeID::get<LoopTestDialect>()) {
+    addOperations<NoZeroTripCheckLoopOp, ImplZeroTripCheckLoopOp>();
+  }
+  static StringRef getDialectNamespace() { return "looptest"; }
+};
+
+TEST(LoopLikeOpInterface, NoReplaceWithZeroTripCheck) {
+  const char *ir = R"MLIR(
+  "looptest.no_zero_trip_check_loop_op"() : () -> ()
+  )MLIR";
+
+  DialectRegistry registry;
+  registry.insert<LoopTestDialect>();
+  MLIRContext ctx(registry);
+
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+  LoopLikeOpInterface testOp =
+      cast<LoopLikeOpInterface>(module->getBody()->getOperations().front());
+
+  IRRewriter rewriter(&ctx);
+  FailureOr<LoopLikeOpInterface> result =
+      testOp.replaceWithZeroTripCheck(rewriter);
+
+  EXPECT_TRUE(failed(result));
+}
+
+TEST(LoopLikeOpInterface, ImplReplaceWithZeroTripCheck) {
+  const char *ir = R"MLIR(
+  "looptest.impl_zero_trip_check_loop_op"() : () -> ()
+  )MLIR";
+
+  DialectRegistry registry;
+  registry.insert<LoopTestDialect>();
+  MLIRContext ctx(registry);
+
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+  LoopLikeOpInterface testOp =
+      cast<LoopLikeOpInterface>(module->getBody()->getOperations().front());
+
+  IRRewriter rewriter(&ctx);
+  FailureOr<LoopLikeOpInterface> result =
+      testOp.replaceWithZeroTripCheck(rewriter);
+
+  EXPECT_TRUE(succeeded(result));
+  EXPECT_EQ(*result, testOp);
+}

>From 2cf88e0ac78d84e908c32bfdf93ed7cd0693df7d Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Fri, 2 Feb 2024 19:12:14 +0000
Subject: [PATCH 03/10] Update comments

---
 mlir/include/mlir/Interfaces/LoopLikeInterface.td | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 77409cb3a8274b..81f202cf341864 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -233,7 +233,10 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
 
         Note: Ops in the loop body might be rearranged because of loop rotating
         to maintain the semantic. Terminators might be removed/added during this
-        transformation.
+        transformation. Also callers are not required to check the side-effect
+        of loop condition, so the transformation needs to consider that to make
+        sure the loop behavior is unchanged when moving the condtion out of the
+        loop for the zero-trip-check.
       }],
       /*retTy=*/"::mlir::FailureOr<::mlir::LoopLikeOpInterface>",
       /*methodName=*/"replaceWithZeroTripCheck",

>From ab65656260ec3325634b5196a18ed8caf8b87523 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Mon, 5 Feb 2024 21:39:51 +0000
Subject: [PATCH 04/10] Update comments

---
 mlir/include/mlir/Interfaces/LoopLikeInterface.td | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 81f202cf341864..572845f46d320b 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -223,9 +223,8 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
     >,
     InterfaceMethod<[{
         Add a zero-trip-check around the loop to check if the loop body is ever
-        run and return the new loop inside the check. The loop body is moved
-        over to the new loop. Returns "failure" if the loop doesn't support
-        this transformation.
+        run and return the same loop (moved) or a new loop (replaced) inside the
+        check. Returns "failure" if the loop doesn't support the transformation.
 
         After the transformation, the ops inserted to the parent region of the
         loop are guaranteed to be run only if the loop body runs at least one
@@ -235,7 +234,7 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
         to maintain the semantic. Terminators might be removed/added during this
         transformation. Also callers are not required to check the side-effect
         of loop condition, so the transformation needs to consider that to make
-        sure the loop behavior is unchanged when moving the condtion out of the
+        sure the loop behavior is unchanged when moving the condition out of the
         loop for the zero-trip-check.
       }],
       /*retTy=*/"::mlir::FailureOr<::mlir::LoopLikeOpInterface>",

>From bd919fb2d0ee6ff00c5090fa0e2c9dc44252e6be Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Mon, 5 Feb 2024 22:58:38 +0000
Subject: [PATCH 05/10] Add test pass

---
 .../LoopLikeInterface/CMakeLists.txt          |  1 +
 .../TestLoopZeroTripCheck.cpp                 | 52 +++++++++++++++++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |  2 +
 3 files changed, 55 insertions(+)
 create mode 100644 mlir/test/lib/Interfaces/LoopLikeInterface/TestLoopZeroTripCheck.cpp

diff --git a/mlir/test/lib/Interfaces/LoopLikeInterface/CMakeLists.txt b/mlir/test/lib/Interfaces/LoopLikeInterface/CMakeLists.txt
index f20219e00cb865..19a727822dc672 100644
--- a/mlir/test/lib/Interfaces/LoopLikeInterface/CMakeLists.txt
+++ b/mlir/test/lib/Interfaces/LoopLikeInterface/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_library(MLIRLoopLikeInterfaceTestPasses
   TestBlockInLoop.cpp
+  TestLoopZeroTripCheck.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
diff --git a/mlir/test/lib/Interfaces/LoopLikeInterface/TestLoopZeroTripCheck.cpp b/mlir/test/lib/Interfaces/LoopLikeInterface/TestLoopZeroTripCheck.cpp
new file mode 100644
index 00000000000000..45f0f312aceeab
--- /dev/null
+++ b/mlir/test/lib/Interfaces/LoopLikeInterface/TestLoopZeroTripCheck.cpp
@@ -0,0 +1,52 @@
+//===- TestLoopZeroTripCheck.cpp -- Pass to test replaceWithZeroTripCheck -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the passes to test replaceWithZeroTripCheck of loop ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestLoopZeroTripCheck
+    : public PassWrapper<TestLoopZeroTripCheck, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopZeroTripCheck)
+
+  StringRef getArgument() const final { return "test-loop-zero-trip-check"; }
+  StringRef getDescription() const final {
+    return "test replaceWithZeroTripCheck of loop ops";
+  }
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+    MLIRContext *context = &getContext();
+    IRRewriter rewriter(context);
+    func.walk([&](LoopLikeOpInterface op) {
+      auto result = op.replaceWithZeroTripCheck(rewriter);
+      if (failed(result)) {
+        // Ignore failures (e.g. not implemented) in tests.
+      }
+    });
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestLoopZeroTripCheckPass() {
+  PassRegistration<TestLoopZeroTripCheck>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 428bdd9691e095..6ac3283bcb9d19 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -110,6 +110,7 @@ void registerTestLoopFusion();
 void registerTestCFGLoopInfoPass();
 void registerTestLoopMappingPass();
 void registerTestLoopUnrollingPass();
+void registerTestLoopZeroTripCheckPass();
 void registerTestLowerToLLVM();
 void registerTestLowerToNVVM();
 void registerTestMakeIsolatedFromAbovePass();
@@ -234,6 +235,7 @@ void registerTestPasses() {
   mlir::test::registerTestCFGLoopInfoPass();
   mlir::test::registerTestLoopMappingPass();
   mlir::test::registerTestLoopUnrollingPass();
+  mlir::test::registerTestLoopZeroTripCheckPass();
   mlir::test::registerTestLowerToLLVM();
   mlir::test::registerTestMakeIsolatedFromAbovePass();
   mlir::test::registerTestMatchReductionPass();

>From 9c2095013188c32936360a0eb44d89f3817a1f68 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Mon, 5 Feb 2024 22:59:03 +0000
Subject: [PATCH 06/10] Revert "Add tests"

This reverts commit d6703ebbeb5ddc358929672b44994a9d05683523.
---
 mlir/unittests/Interfaces/CMakeLists.txt      |   3 -
 .../Interfaces/LoopLikeInterfaceTest.cpp      | 101 ------------------
 2 files changed, 104 deletions(-)
 delete mode 100644 mlir/unittests/Interfaces/LoopLikeInterfaceTest.cpp

diff --git a/mlir/unittests/Interfaces/CMakeLists.txt b/mlir/unittests/Interfaces/CMakeLists.txt
index cab9503cf295b1..d192b2922d6b9d 100644
--- a/mlir/unittests/Interfaces/CMakeLists.txt
+++ b/mlir/unittests/Interfaces/CMakeLists.txt
@@ -3,7 +3,6 @@ add_mlir_unittest(MLIRInterfacesTests
   DataLayoutInterfacesTest.cpp
   InferIntRangeInterfaceTest.cpp
   InferTypeOpInterfaceTest.cpp
-  LoopLikeInterfaceTest.cpp
 )
 
 target_link_libraries(MLIRInterfacesTests
@@ -13,9 +12,7 @@ target_link_libraries(MLIRInterfacesTests
   MLIRDataLayoutInterfaces
   MLIRDLTIDialect
   MLIRFuncDialect
-  MLIRIR
   MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
-  MLIRLoopLikeInterface
   MLIRParser
 )
diff --git a/mlir/unittests/Interfaces/LoopLikeInterfaceTest.cpp b/mlir/unittests/Interfaces/LoopLikeInterfaceTest.cpp
deleted file mode 100644
index b0b7680fed68e7..00000000000000
--- a/mlir/unittests/Interfaces/LoopLikeInterfaceTest.cpp
+++ /dev/null
@@ -1,101 +0,0 @@
-//===- LoopLikeInterfaceTest.cpp - Unit tests for Loop Like Interfaces. ---===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Interfaces/LoopLikeInterface.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/DialectImplementation.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Parser/Parser.h"
-
-#include <gtest/gtest.h>
-
-using namespace mlir;
-
-struct NoZeroTripCheckLoopOp
-    : public Op<NoZeroTripCheckLoopOp, LoopLikeOpInterface::Trait> {
-  using Op::Op;
-
-  static ArrayRef<StringRef> getAttributeNames() { return {}; }
-
-  static StringRef getOperationName() {
-    return "looptest.no_zero_trip_check_loop_op";
-  }
-
-  SmallVector<Region *> getLoopRegions() { return {}; }
-};
-
-struct ImplZeroTripCheckLoopOp
-    : public Op<ImplZeroTripCheckLoopOp, LoopLikeOpInterface::Trait> {
-  using Op::Op;
-
-  static ArrayRef<StringRef> getAttributeNames() { return {}; }
-
-  static StringRef getOperationName() {
-    return "looptest.impl_zero_trip_check_loop_op";
-  }
-
-  SmallVector<Region *> getLoopRegions() { return {}; }
-
-  FailureOr<LoopLikeOpInterface>
-  replaceWithZeroTripCheck(RewriterBase &rewriter) {
-    return cast<LoopLikeOpInterface>(this->getOperation());
-  }
-};
-
-/// A dialect putting all the above together.
-struct LoopTestDialect : Dialect {
-  explicit LoopTestDialect(MLIRContext *ctx)
-      : Dialect(getDialectNamespace(), ctx, TypeID::get<LoopTestDialect>()) {
-    addOperations<NoZeroTripCheckLoopOp, ImplZeroTripCheckLoopOp>();
-  }
-  static StringRef getDialectNamespace() { return "looptest"; }
-};
-
-TEST(LoopLikeOpInterface, NoReplaceWithZeroTripCheck) {
-  const char *ir = R"MLIR(
-  "looptest.no_zero_trip_check_loop_op"() : () -> ()
-  )MLIR";
-
-  DialectRegistry registry;
-  registry.insert<LoopTestDialect>();
-  MLIRContext ctx(registry);
-
-  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
-  LoopLikeOpInterface testOp =
-      cast<LoopLikeOpInterface>(module->getBody()->getOperations().front());
-
-  IRRewriter rewriter(&ctx);
-  FailureOr<LoopLikeOpInterface> result =
-      testOp.replaceWithZeroTripCheck(rewriter);
-
-  EXPECT_TRUE(failed(result));
-}
-
-TEST(LoopLikeOpInterface, ImplReplaceWithZeroTripCheck) {
-  const char *ir = R"MLIR(
-  "looptest.impl_zero_trip_check_loop_op"() : () -> ()
-  )MLIR";
-
-  DialectRegistry registry;
-  registry.insert<LoopTestDialect>();
-  MLIRContext ctx(registry);
-
-  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
-  LoopLikeOpInterface testOp =
-      cast<LoopLikeOpInterface>(module->getBody()->getOperations().front());
-
-  IRRewriter rewriter(&ctx);
-  FailureOr<LoopLikeOpInterface> result =
-      testOp.replaceWithZeroTripCheck(rewriter);
-
-  EXPECT_TRUE(succeeded(result));
-  EXPECT_EQ(*result, testOp);
-}

>From 0e90f3934fa188d3c016776e610be9465dcc40c1 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Mon, 5 Feb 2024 23:03:06 +0000
Subject: [PATCH 07/10] Add missing mlir file

---
 .../Dialect/SCF/loop-zero-trip-check.mlir     | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)
 create mode 100644 mlir/test/Dialect/SCF/loop-zero-trip-check.mlir

diff --git a/mlir/test/Dialect/SCF/loop-zero-trip-check.mlir b/mlir/test/Dialect/SCF/loop-zero-trip-check.mlir
new file mode 100644
index 00000000000000..654dad896b56a1
--- /dev/null
+++ b/mlir/test/Dialect/SCF/loop-zero-trip-check.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s -test-loop-zero-trip-check -split-input-file  | FileCheck %s
+
+func.func @replace_scf_while_with_zero_trip_check(%bound : i32) -> i32 {
+  %cst0 = arith.constant 0 : i32
+  %cst5 = arith.constant 5 : i32
+  %res:2 = scf.while (%iter = %cst0) : (i32) -> (i32, i32) {
+    %cond = arith.cmpi slt, %iter, %bound : i32
+    %inv = arith.addi %bound, %cst5 : i32
+    scf.condition(%cond) %iter, %inv : i32, i32
+  } do {
+  ^bb0(%arg1: i32, %arg2: i32):
+    %next = arith.addi %arg1, %arg2 : i32
+    scf.yield %next : i32
+  }
+  return %res#0 : i32
+}
+
+// CHECK-LABEL: func.func @replace_scf_while_with_zero_trip_check
+// CHECK:         %[[IF:.*]]:2 = scf.if %{{.*}} -> (i32, i32) {
+// CHECK:           scf.while
+// CHECK:         return %[[IF]]#0 : i32

>From b3cf941d43007b035fc53c6087054b12a8e93560 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Tue, 6 Feb 2024 19:12:52 +0000
Subject: [PATCH 08/10] Improve comments

---
 .../lib/Interfaces/LoopLikeInterface/TestLoopZeroTripCheck.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/test/lib/Interfaces/LoopLikeInterface/TestLoopZeroTripCheck.cpp b/mlir/test/lib/Interfaces/LoopLikeInterface/TestLoopZeroTripCheck.cpp
index 45f0f312aceeab..e908da6b03f11f 100644
--- a/mlir/test/lib/Interfaces/LoopLikeInterface/TestLoopZeroTripCheck.cpp
+++ b/mlir/test/lib/Interfaces/LoopLikeInterface/TestLoopZeroTripCheck.cpp
@@ -35,7 +35,8 @@ struct TestLoopZeroTripCheck
     func.walk([&](LoopLikeOpInterface op) {
       auto result = op.replaceWithZeroTripCheck(rewriter);
       if (failed(result)) {
-        // Ignore failures (e.g. not implemented) in tests.
+        // Ignore not implemented failure in tests. The expected output should
+        // catch problems (e.g. transformation doesn't happen).
       }
     });
   }

>From 83c0cd4d1a7372895db52b928065be59f8279a72 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Thu, 1 Feb 2024 21:51:09 +0000
Subject: [PATCH 09/10] Implement replaceWithZeroTripCheck for scf.while

---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td |   4 +-
 mlir/lib/Dialect/SCF/IR/SCF.cpp            | 104 +++++++++++++++++++++
 2 files changed, 107 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index b3d085bfff1af9..7873020c5a1819 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -939,7 +939,9 @@ def WhileOp : SCF_Op<"while",
     [DeclareOpInterfaceMethods<RegionBranchOpInterface,
         ["getEntrySuccessorOperands"]>,
      DeclareOpInterfaceMethods<LoopLikeOpInterface,
-        ["getRegionIterArgs", "getYieldedValuesMutable"]>,
+        ["getRegionIterArgs",
+         "getYieldedValuesMutable",
+         "replaceWithZeroTripCheck"]>,
      RecursiveMemoryEffects, SingleBlock]> {
   let summary = "a generic 'while' loop";
   let description = [{
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9822ee522c6ed8..58f13204f95cb1 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3254,6 +3254,110 @@ LogicalResult scf::WhileOp::verify() {
   return success(afterTerminator != nullptr);
 }
 
+/// Create zero-trip-check for a `while` op. Given an example below:
+///
+///   scf.while (%arg0 = %init) : (i32) -> i64 {
+///     %val = .., %arg0 : i64
+///     %cond = arith.cmpi .., %arg0 : i32
+///     scf.condition(%cond) %val : i64
+///   } do {
+///   ^bb0(%arg1: i64):
+///     %next = .., %arg1 : i32
+///     scf.yield %next : i32
+///   }
+///
+/// First clone before block to the front of the loop:
+///
+///   %val0 = .., %init : i64
+///   %cond0 = arith.cmpi .., %init : i32
+///   scf.while (%arg0 = %init) : (i32) -> i64 {
+///     %val = .., %arg0 : i64
+///     %cond = arith.cmpi .., %arg0 : i32
+///     scf.condition(%cond) %val : i64
+///   } do {
+///   ^bb0(%arg1: i64):
+///     %next = .., %arg1 : i32
+///     scf.yield %next : i32
+///   }
+///
+/// Create `if` op with the condition, rotate and move the loop into the then
+/// branch (the else branch yields the cloned `scf.condition` operands):
+///
+///   %val0 = .., %init : i64
+///   %cond0 = arith.cmpi .., %init : i32
+///   scf.if %cond0 -> i64 {
+///     %res = scf.while (%arg1 = %val0) : (i64) -> i64 {
+///       // Original after block
+///       %next = .., %arg1 : i32
+///       // Original before block
+///       %val = .., %next : i64
+///       %cond = arith.cmpi .., %next : i32
+///       scf.condition(%cond) %val : i64
+///     } do {
+///     ^bb0(%arg2: i64):
+///       scf.yield %arg2 : i64
+///     }
+///     scf.yield %res : i64
+///   } else {
+///     scf.yield %val0 : i64
+///   }
+FailureOr<LoopLikeOpInterface>
+scf::WhileOp::replaceWithZeroTripCheck(RewriterBase &rewriter) {
+  // Restore the caller's insertion point on return: the point set below is
+  // anchored at this op, which `replaceOp` erases at the end.
+  OpBuilder::InsertionGuard guard(rewriter);
+  IRMapping mapper;
+  Block *beforeBlock = this->getBeforeBody();
+  // Clone before block before the loop for zero-trip-check.
+  for (auto [arg, init] :
+       llvm::zip_equal(beforeBlock->getArguments(), this->getInits())) {
+    mapper.map(arg, init);
+  }
+  rewriter.setInsertionPoint(*this);
+  // Clone everything except the `scf.condition` terminator, whose operands are
+  // looked up through `mapper` below. In a single block every def precedes its
+  // uses, so cloning in order records each def in `mapper` before it's needed.
+  for (Operation &op : beforeBlock->without_terminator())
+    rewriter.clone(op, mapper);
+
+  auto condOp = this->getConditionOp();
+  auto clonedCondition = mapper.lookupOrDefault(condOp.getCondition());
+  auto clonedCondArgs = llvm::map_to_vector(
+      condOp.getArgs(), [&](Value arg) { return mapper.lookupOrDefault(arg); });
+
+  // Create zero-trip-check and move the while loop in.
+  scf::WhileOp newLoop = nullptr;
+  auto ifOp = rewriter.create<scf::IfOp>(
+      this->getLoc(), clonedCondition,
+      [&](OpBuilder &builder, Location loc) {
+        // Then runs the while loop.
+        newLoop = builder.create<scf::WhileOp>(
+            loc, this->getResultTypes(), clonedCondArgs,
+            [&](OpBuilder &builder, Location loc, ValueRange args) {
+              // Rotate and move the loop body into before block.
+              auto newBlock = builder.getBlock();
+              rewriter.mergeBlocks(this->getAfterBody(), newBlock, args);
+              auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
+              rewriter.mergeBlocks(this->getBeforeBody(), newBlock,
+                                   yieldOp.getResults());
+              rewriter.eraseOp(yieldOp);
+            },
+            [&](OpBuilder &builder, Location loc, ValueRange args) {
+              // Pass-through values in after block.
+              builder.create<scf::YieldOp>(loc, args);
+            });
+        builder.create<scf::YieldOp>(loc, newLoop.getResults());
+      },
+      [&](OpBuilder &builder, Location loc) {
+        // Else returns the results from zero-trip-check.
+        builder.create<scf::YieldOp>(loc, clonedCondArgs);
+      });
+
+  rewriter.replaceOp(*this, ifOp);
+
+  return cast<LoopLikeOpInterface>(newLoop.getOperation());
+}
+
 namespace {
 /// Replace uses of the condition within the do block with true, since otherwise
 /// the block would not be evaluated.

>From 3c2dd2b4e9eab8ac2400bc932855b5358b77e664 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Thu, 1 Feb 2024 23:13:37 +0000
Subject: [PATCH 10/10] Add tests

---
 .../SCF/while-loop-zero-trip-check.mlir       | 40 +++++++++++++
 mlir/test/lib/Dialect/SCF/CMakeLists.txt      |  1 +
 .../lib/Dialect/SCF/TestLoopZeroTripCheck.cpp | 59 +++++++++++++++++++
 3 files changed, 100 insertions(+)
 create mode 100644 mlir/test/Dialect/SCF/while-loop-zero-trip-check.mlir
 create mode 100644 mlir/test/lib/Dialect/SCF/TestLoopZeroTripCheck.cpp

diff --git a/mlir/test/Dialect/SCF/while-loop-zero-trip-check.mlir b/mlir/test/Dialect/SCF/while-loop-zero-trip-check.mlir
new file mode 100644
index 00000000000000..f5f0a55ad4f165
--- /dev/null
+++ b/mlir/test/Dialect/SCF/while-loop-zero-trip-check.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt %s -test-scf-while-zero-trip-check -split-input-file  | FileCheck %s
+
+func.func @replace_scf_while_with_zero_trip_check(%bound : i32) -> i32 {
+  %cst0 = arith.constant 0 : i32
+  %cst5 = arith.constant 5 : i32
+  %res:2 = scf.while (%iter = %cst0) : (i32) -> (i32, i32) {
+    %cond = arith.cmpi slt, %iter, %bound : i32
+    %inv = arith.addi %bound, %cst5 : i32
+    scf.condition(%cond) %iter, %inv : i32, i32
+  } do {
+  ^bb0(%arg1: i32, %arg2: i32):
+    %next = arith.addi %arg1, %arg2 : i32
+    scf.yield %next : i32
+  }
+  return %res#0 : i32
+}
+
+// CHECK-LABEL: func.func @replace_scf_while_with_zero_trip_check(
+// CHECK-SAME:      %[[ARG0:.*]]: i32) -> i32 {
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG:     %[[C5:.*]] = arith.constant 5 : i32
+// CHECK-DAG:     %[[PRE_COND:.*]] = arith.cmpi slt, %[[C0]], %[[ARG0]] : i32
+// CHECK-DAG:     %[[PRE_INV:.*]] = arith.addi %[[ARG0]], %[[C5]] : i32
+// CHECK:         %[[IF:.*]]:2 = scf.if %[[PRE_COND]] -> (i32, i32) {
+// CHECK:           %[[WHILE:.*]]:2 = scf.while (
+// CHECK-SAME:          %[[ARG1:.*]] = %[[C0]], %[[ARG2:.*]] = %[[PRE_INV]]
+// CHECK-SAME:      ) : (i32, i32) -> (i32, i32) {
+// CHECK:             %[[NEXT:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : i32
+// CHECK:             %[[COND:.*]] = arith.cmpi slt, %[[NEXT]], %[[ARG0]] : i32
+// CHECK:             %[[INV:.*]] = arith.addi %[[ARG0]], %[[C5]] : i32
+// CHECK:             scf.condition(%[[COND]]) %[[NEXT]], %[[INV]] : i32, i32
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32):
+// CHECK:             scf.yield %[[ARG3]], %[[ARG4]] : i32, i32
+// CHECK:           }
+// CHECK:           scf.yield %[[WHILE]]#0, %[[WHILE]]#1 : i32, i32
+// CHECK:         } else {
+// CHECK:           scf.yield %[[C0]], %[[PRE_INV]] : i32, i32
+// CHECK:         }
+// CHECK:         return %[[IF]]#0 : i32
diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
index 22c2f2388de69b..d704fe6fe81e38 100644
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -2,6 +2,7 @@
 add_mlir_library(MLIRSCFTestPasses
   TestLoopParametricTiling.cpp
   TestLoopUnrolling.cpp
+  TestLoopZeroTripCheck.cpp
   TestSCFUtils.cpp
   TestWhileOpBuilder.cpp
 
diff --git a/mlir/test/lib/Dialect/SCF/TestLoopZeroTripCheck.cpp b/mlir/test/lib/Dialect/SCF/TestLoopZeroTripCheck.cpp
new file mode 100644
index 00000000000000..8e203b82a7faec
--- /dev/null
+++ b/mlir/test/lib/Dialect/SCF/TestLoopZeroTripCheck.cpp
@@ -0,0 +1,59 @@
+//===- TestLoopZeroTripCheck.cpp -- Pass to test replaceWithZeroTripCheck -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the passes to test replaceWithZeroTripCheck for SCF
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Applies `replaceWithZeroTripCheck` to every `scf.while` in a function. A
+/// failure is diagnosed on the loop and fails the pass, since scf.while is
+/// expected to implement the transformation.
+struct TestSCFWhileZeroTripCheckPass
+    : public PassWrapper<TestSCFWhileZeroTripCheckPass,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFWhileZeroTripCheckPass)
+
+  StringRef getArgument() const final {
+    return "test-scf-while-zero-trip-check";
+  }
+  StringRef getDescription() const final {
+    return "test replaceWithZeroTripCheck of scf.while";
+  }
+
+  void runOnOperation() override {
+    IRRewriter rewriter(&getContext());
+    // Default walk order is post-order, which permits the callback to erase
+    // the visited loop (replaceWithZeroTripCheck replaces it with an scf.if).
+    getOperation().walk([&](scf::WhileOp op) {
+      if (failed(op.replaceWithZeroTripCheck(rewriter))) {
+        op.emitError("failed to create zero-trip-check for scf.while");
+        signalPassFailure();
+      }
+    });
+  }
+};
+
+} // namespace
+
+namespace mlir::test {
+// Distinct from registerTestLoopZeroTripCheckPass (the generic loop-interface
+// test pass) to avoid a duplicate symbol; mlir-opt must declare and call this.
+void registerTestSCFWhileZeroTripCheckPass() {
+  PassRegistration<TestSCFWhileZeroTripCheckPass>();
+}
+} // namespace mlir::test



More information about the Mlir-commits mailing list