[Mlir-commits] [mlir] [mlir][scf] `scf.while` uplifting: optimize op matching (PR #88813)

Ivan Butygin llvmlistbot at llvm.org
Mon Apr 15 15:57:28 PDT 2024


https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/88813

Instead of iterating over the uses of the potential induction variable to look for a suitable `arith.addi`, trace the `arith.addi` back from the corresponding yield argument.
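
As a rough illustration of the matching strategy described above, here is a minimal sketch that does not use the real MLIR API. `Value`, `Op`, `definedOutsideLoop` and `findStep` are hypothetical stand-ins invented for this example; `definedOutsideLoop` only approximates what the patch checks with `dom.properlyDominates(step, loop)`. The idea is to start from the yielded value, look at its single defining op, and accept it only if it is an add of the induction variable and a loop-invariant value.

```cpp
#include <cassert>

// Hypothetical stand-ins for MLIR values and ops (not the real MLIR classes).
struct Op;

struct Value {
  const Op *def = nullptr;         // defining op; nullptr models a block argument
  bool definedOutsideLoop = false; // stand-in for a dominance check against the loop
};

struct Op {
  enum class Kind { Add, Other };
  Kind kind;
  const Value *lhs;
  const Value *rhs;
};

// Trace back from the yielded value instead of scanning every use of the
// induction variable. Returns the step value, or nullptr if nothing matches.
const Value *findStep(const Value *yielded, const Value *inductionVar) {
  assert(yielded && inductionVar && "both values must be provided");

  // The yielded value must be produced directly by an add.
  const Op *add = yielded->def;
  if (!add || add->kind != Op::Kind::Add)
    return nullptr;

  // One operand must be the induction variable; the other is the candidate step.
  const Value *step = nullptr;
  if (add->lhs == inductionVar)
    step = add->rhs;
  else if (add->rhs == inductionVar)
    step = add->lhs;

  // The step must exist and be loop-invariant.
  if (!step || !step->definedOutsideLoop)
    return nullptr;
  return step;
}
```

In this sketch the lookup inspects a single defining op, whereas the previous approach visited every use of the induction variable until one happened to feed the yield.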

From ec26daea68f6545e93d79e33d4860cd5c0733406 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 16 Apr 2024 00:50:42 +0200
Subject: [PATCH 1/2] [mlir][scf] `scf.while` uplifting: optimize op matching

Instead of iterating over the uses of the potential induction variable to look for a suitable `arith.addi`, trace the `arith.addi` back from the corresponding yield argument.
---
 .../SCF/Transforms/UpliftWhileToFor.cpp       | 37 +++++++------------
 1 file changed, 14 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index fea2f659535bb4..461957f99ab669 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -101,38 +101,29 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
 
   Block *afterBody = loop.getAfterBody();
   scf::YieldOp afterTerm = loop.getYieldOp();
-  auto argNumber = inductionVar.getArgNumber();
-  auto afterTermIndArg = afterTerm.getResults()[argNumber];
+  unsigned argNumber = inductionVar.getArgNumber();
+  Value afterTermIndArg = afterTerm.getResults()[argNumber];
 
-  auto inductionVarAfter = afterBody->getArgument(argNumber);
-
-  Value step;
+  Value inductionVarAfter = afterBody->getArgument(argNumber);
 
   // Find suitable `addi` op inside `after` block, one of the args must be an
   // Induction var passed from `before` block and second arg must be defined
   // outside of the loop and will be considered step value.
   // TODO: Add `subi` support?
-  for (auto &use : inductionVarAfter.getUses()) {
-    auto owner = dyn_cast<arith::AddIOp>(use.getOwner());
-    if (!owner)
-      continue;
-
-    auto other =
-        (inductionVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
-    if (!dom.properlyDominates(other, loop))
-      continue;
-
-    if (afterTermIndArg != owner.getResult())
-      continue;
+  auto addOp = afterTermIndArg.getDefiningOp<arith::AddIOp>();
+  if (!addOp)
+    return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");
 
-    step = other;
-    break;
+  Value step;
+  if (addOp.getLhs() == inductionVarAfter) {
+    step = addOp.getRhs();
+  } else if (addOp.getRhs() == inductionVarAfter) {
+    step = addOp.getLhs();
+  } else {
+    return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form");
   }
 
-  if (!step)
-    return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");
-
-  auto lb = loop.getInits()[argNumber];
+  Value lb = loop.getInits()[argNumber];
 
   assert(lb.getType().isIntOrIndex());
   assert(lb.getType() == ub.getType());

From a778bbb6aae04072b28d1c1d3c7afe26ebb4e6b3 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 16 Apr 2024 00:55:41 +0200
Subject: [PATCH 2/2] fix

---
 mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index 461957f99ab669..7b4024b6861a72 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -119,10 +119,11 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
     step = addOp.getRhs();
   } else if (addOp.getRhs() == inductionVarAfter) {
     step = addOp.getLhs();
-  } else {
-    return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form");
   }
 
+  if (!step || !dom.properlyDominates(step, loop))
+    return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form");
+
   Value lb = loop.getInits()[argNumber];
 
   assert(lb.getType().isIntOrIndex());


