[Mlir-commits] [mlir] [mlir][SCF] Avoid generating unnecessary div/rem operations during coalescing (PR #91562)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 8 23:04:05 PDT 2024


https://github.com/MaheshRavishankar created https://github.com/llvm/llvm-project/pull/91562

When coalescing, if some of the loops are unit-trip, we can avoid generating div/rem instructions during delinearization. Ideally we could use something like `affine.delinearize` to handle this, but that causes dependence issues.

>From bd77179deefeb4f344d3087f9d16cbedfd092047 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Wed, 8 May 2024 17:30:06 -0700
Subject: [PATCH] [mlir][SCF] Avoid generating unnecessary div/rem operations
 during coalescing.

When coalescing, if some of the loops are unit-trip, we can avoid
generating div/rem instructions during delinearization. Ideally we
could use something like `affine.delinearize` to handle this, but
that causes dependence issues.
---
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          | 57 +++++++++++---
 .../Dialect/SCF/transform-op-coalesce.mlir    | 77 +++++++++++++++++++
 2 files changed, 124 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 9279081cfd45d..6658cca03eba7 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -544,11 +544,24 @@ static void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
 static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
                                        ArrayRef<Value> values) {
   assert(!values.empty() && "unexpected empty list");
-  Value productOf = values.front();
-  for (auto v : values.drop_front()) {
-    productOf = rewriter.create<arith::MulIOp>(loc, productOf, v);
+  // Product of all factors that are not the constant 1. `Value` is already
+  // nullable, so a null handle means no such factor has been seen yet.
+  Value productOf;
+  for (Value v : values) {
+    if (isConstantIntValue(v, 1))
+      continue; // Multiplying by one is the identity.
+    if (productOf)
+      productOf = rewriter.create<arith::MulIOp>(loc, productOf, v);
+    else
+      productOf = v;
   }
-  return productOf;
+  // Every factor was the constant 1: materialize the multiplicative identity
+  // with the same type as the inputs.
+  if (!productOf) {
+    productOf = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getOneAttr(values.front().getType()));
+  }
+  return productOf;
 }
 
 /// For each original loop, the value of the
@@ -562,19 +575,43 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
 static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
 delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
                              Value linearizedIv, ArrayRef<Value> ubs) {
-  Value previous = linearizedIv;
   SmallVector<Value> delinearizedIvs(ubs.size());
   SmallPtrSet<Operation *, 2> preservedUsers;
-  for (unsigned i = 0, e = ubs.size(); i < e; ++i) {
-    unsigned idx = ubs.size() - i - 1;
-    if (i != 0) {
+
+  llvm::BitVector isUbOne(ubs.size());
+  for (auto [index, ub] : llvm::enumerate(ubs)) {
+    auto ubCst = getConstantIntValue(ub);
+    if (ubCst && ubCst.value() == 1)
+      isUbOne.set(index);
+  }
+
+  // Prune the lead ubs that are all ones.
+  unsigned numLeadingOneUbs = 0;
+  for (auto [index, ub] : llvm::enumerate(ubs)) {
+    if (!isUbOne.test(index)) {
+      break;
+    }
+    delinearizedIvs[index] = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(ub.getType()));
+    numLeadingOneUbs++;
+  }
+
+  Value previous = linearizedIv;
+  for (unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
+    unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
+    if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) {
       previous = rewriter.create<arith::DivSIOp>(loc, previous, ubs[idx + 1]);
       preservedUsers.insert(previous.getDefiningOp());
     }
     Value iv = previous;
     if (i != e - 1) {
-      iv = rewriter.create<arith::RemSIOp>(loc, previous, ubs[idx]);
-      preservedUsers.insert(iv.getDefiningOp());
+      if (!isUbOne.test(idx)) {
+        iv = rewriter.create<arith::RemSIOp>(loc, previous, ubs[idx]);
+        preservedUsers.insert(iv.getDefiningOp());
+      } else {
+        iv = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getZeroAttr(ubs[idx].getType()));
+      }
     }
     delinearizedIvs[idx] = iv;
   }
diff --git a/mlir/test/Dialect/SCF/transform-op-coalesce.mlir b/mlir/test/Dialect/SCF/transform-op-coalesce.mlir
index 4dc3e4ea0ef45..6fcd727621bae 100644
--- a/mlir/test/Dialect/SCF/transform-op-coalesce.mlir
+++ b/mlir/test/Dialect/SCF/transform-op-coalesce.mlir
@@ -299,3 +299,80 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-NOT:       scf.for
 //      CHECK:   transform.named_sequence
 
+// -----
+
+// Check that no unnecessary operations are generated when coalescing unit-trip loops.
+func.func @trip_one_loops(%arg0 : tensor<?x?xf32>, %arg1 : index, %arg2 : index) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = scf.for %iv0 = %c0 to %c1 step %c1 iter_args(%iter0 = %arg0) -> tensor<?x?xf32> {
+    %1 = scf.for %iv1 = %c0 to %c1 step %c1 iter_args(%iter1 = %iter0) -> tensor<?x?xf32> {
+      %2 = scf.for %iv2 = %c0 to %arg1 step %c1 iter_args(%iter2 = %iter1) -> tensor<?x?xf32> {
+        %3 = scf.for %iv3 = %c0 to %c1 step %c1 iter_args(%iter3 = %iter2) -> tensor<?x?xf32> {
+          %4 = scf.for %iv4 = %c0 to %arg2 step %c1 iter_args(%iter4 = %iter3) -> tensor<?x?xf32> {
+            %5 = "some_use"(%iter4, %iv0, %iv1, %iv2, %iv3, %iv4)
+              : (tensor<?x?xf32>, index, index, index, index, index) -> (tensor<?x?xf32>)
+            scf.yield %5 : tensor<?x?xf32>
+          }
+          scf.yield %4 : tensor<?x?xf32>
+        }
+        scf.yield %3 : tensor<?x?xf32>
+      }
+      scf.yield %2 : tensor<?x?xf32>
+    }
+    scf.yield %1 : tensor<?x?xf32>
+  } {coalesce}
+  return %0 : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.for">
+    %2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func @trip_one_loops
+//  CHECK-SAME:     , %[[ARG1:.+]]: index,
+//  CHECK-SAME:     %[[ARG2:.+]]: index)
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//       CHECK:   %[[UB:.+]] = arith.muli %[[ARG1]], %[[ARG2]]
+//       CHECK:   scf.for %[[IV:.+]] = %[[C0]] to %[[UB]] step %[[C1]]
+//       CHECK:     %[[IV1:.+]] = arith.remsi %[[IV]], %[[ARG2]]
+//       CHECK:     %[[IV2:.+]] = arith.divsi %[[IV]], %[[ARG2]]
+//       CHECK:     "some_use"(%{{[a-zA-Z0-9]+}}, %[[C0]], %[[C0]], %[[IV2]], %[[C0]], %[[IV1]])
+
+// -----
+
+// Check that no muli/divsi/remsi ops are generated when all loops except one are unit-trip.
+func.func @all_outer_trip_one(%arg0 : tensor<?x?xf32>, %arg1 : index) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = scf.for %iv0 = %c0 to %c1 step %c1 iter_args(%iter0 = %arg0) -> tensor<?x?xf32> {
+    %1 = scf.for %iv1 = %c0 to %c1 step %c1 iter_args(%iter1 = %iter0) -> tensor<?x?xf32> {
+      %2 = scf.for %iv2 = %c0 to %arg1 step %c1 iter_args(%iter2 = %iter1) -> tensor<?x?xf32> {
+        %3 = "some_use"(%iter2, %iv0, %iv1, %iv2)
+          : (tensor<?x?xf32>, index, index, index) -> (tensor<?x?xf32>)
+        scf.yield %3 : tensor<?x?xf32>
+      }
+      scf.yield %2 : tensor<?x?xf32>
+    }
+    scf.yield %1 : tensor<?x?xf32>
+  } {coalesce}
+  return %0 : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.for">
+    %2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func @all_outer_trip_one
+//  CHECK-SAME:     , %[[ARG1:.+]]: index)
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//       CHECK:   scf.for %[[IV:.+]] = %[[C0]] to %[[ARG1]] step %[[C1]]
+//       CHECK:     "some_use"(%{{[a-zA-Z0-9]+}}, %[[C0]], %[[C0]], %[[IV]])



More information about the Mlir-commits mailing list