[Mlir-commits] [mlir] [mlir][scf] Allow unrolling loops with integer-typed IV. (PR #106164)

Hongtao Yu llvmlistbot at llvm.org
Mon Aug 26 17:53:27 PDT 2024


https://github.com/htyu updated https://github.com/llvm/llvm-project/pull/106164

>From 488b3b62c7607dbb6805b0118b7dc7437d434757 Mon Sep 17 00:00:00 2001
From: Hongtao Yu <hoy at fb.com>
Date: Mon, 26 Aug 2024 16:28:20 -0700
Subject: [PATCH 1/2] [mlir][scf] Allow unrolling loops with integer-typed IV.

---
 mlir/include/mlir/Dialect/Arith/Utils/Utils.h |  2 ++
 mlir/lib/Dialect/Arith/Utils/Utils.cpp        | 10 ++++++
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          | 33 ++++++++++---------
 mlir/test/Dialect/SCF/loop-unroll.mlir        | 22 ++++++-------
 4 files changed, 41 insertions(+), 26 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 76f5825025739b..2202a2d62ebd92 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -93,6 +93,8 @@ Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
                                   int64_t value);
 Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
                                   const APFloat &value);
+Value createIntOrIndexConstant(OpBuilder &builder, Location loc, Type type,
+                               int64_t value);
 
 /// Returns the int type of the integer in ofr.
 /// Other attribute types are not supported.
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index e75db84b75e280..5fb1cda3211828 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -302,6 +302,16 @@ Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
   return builder.createOrFold<arith::ConstantOp>(loc, type, splat);
 }
 
+Value mlir::createIntOrIndexConstant(OpBuilder &b, Location loc, Type type,
+                                     int64_t value) {
+  assert(type.isIntOrIndex() &&
+         "unexpected type other than integers and index");
+  if (type.isIndex())
+    return b.create<arith::ConstantIndexOp>(loc, value);
+  else
+    return b.create<arith::ConstantOp>(loc, b.getIntegerAttr(type, value));
+}
+
 Type mlir::getType(OpFoldResult ofr) {
   if (auto value = dyn_cast_if_present<Value>(ofr))
     return value.getType();
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index ff5e3a002263d3..49bda65bad2df2 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -264,11 +264,13 @@ bool mlir::getInnermostParallelLoops(Operation *rootOp,
 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
                              int64_t divisor) {
   assert(divisor > 0 && "expected positive divisor");
-  assert(dividend.getType().isIndex() && "expected index-typed value");
+  assert(dividend.getType().isIntOrIndex() &&
+         "expected integer or index-typed value");
 
   Value divisorMinusOneCst =
-      builder.create<arith::ConstantIndexOp>(loc, divisor - 1);
-  Value divisorCst = builder.create<arith::ConstantIndexOp>(loc, divisor);
+      createIntOrIndexConstant(builder, loc, dividend.getType(), divisor - 1);
+  Value divisorCst =
+      createIntOrIndexConstant(builder, loc, dividend.getType(), divisor);
   Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
   return builder.create<arith::DivUIOp>(loc, sum, divisorCst);
 }
@@ -279,9 +281,9 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
 // where divis is rounding-to-zero division.
 static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
                              Value divisor) {
-  assert(dividend.getType().isIndex() && "expected index-typed value");
-
-  Value cstOne = builder.create<arith::ConstantIndexOp>(loc, 1);
+  assert(dividend.getType().isIntOrIndex() &&
+         "expected integer or index-typed value");
+  Value cstOne = createIntOrIndexConstant(builder, loc, dividend.getType(), 1);
   Value divisorMinusOne = builder.create<arith::SubIOp>(loc, divisor, cstOne);
   Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOne);
   return builder.create<arith::DivUIOp>(loc, sum, divisor);
@@ -388,16 +390,17 @@ LogicalResult mlir::loopUnrollByFactor(
     // Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
     generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
     if (generateEpilogueLoop)
-      upperBoundUnrolled = boundsBuilder.create<arith::ConstantIndexOp>(
-          loc, upperBoundUnrolledCst);
+      upperBoundUnrolled = createIntOrIndexConstant(
+          boundsBuilder, loc, forOp.getUpperBound().getType(), upperBoundUnrolledCst);
     else
       upperBoundUnrolled = forOp.getUpperBound();
 
     // Create constant for 'stepUnrolled'.
-    stepUnrolled = stepCst == stepUnrolledCst
-                       ? step
-                       : boundsBuilder.create<arith::ConstantIndexOp>(
-                             loc, stepUnrolledCst);
+    stepUnrolled =
+        stepCst == stepUnrolledCst
+            ? step
+            : createIntOrIndexConstant(boundsBuilder, loc, step.getType(),
+                                       stepUnrolledCst);
   } else {
     // Dynamic loop bounds computation.
     // TODO: Add dynamic asserts for negative lb/ub/step, or
@@ -407,8 +410,8 @@ LogicalResult mlir::loopUnrollByFactor(
     Value diff =
         boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
     Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
-    Value unrollFactorCst =
-        boundsBuilder.create<arith::ConstantIndexOp>(loc, unrollFactor);
+    Value unrollFactorCst = createIntOrIndexConstant(
+        boundsBuilder, loc, tripCount.getType(), unrollFactor);
     Value tripCountRem =
         boundsBuilder.create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
     // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
@@ -455,7 +458,7 @@ LogicalResult mlir::loopUnrollByFactor(
       [&](unsigned i, Value iv, OpBuilder b) {
         // iv' = iv + step * i;
         auto stride = b.create<arith::MulIOp>(
-            loc, step, b.create<arith::ConstantIndexOp>(loc, i));
+            loc, step, createIntOrIndexConstant(b, loc, iv.getType(), i));
         return b.create<arith::AddIOp>(loc, iv, stride);
       },
       annotateFn, iterArgs, yieldedValues);
diff --git a/mlir/test/Dialect/SCF/loop-unroll.mlir b/mlir/test/Dialect/SCF/loop-unroll.mlir
index e28efbb6ec2b91..ae994a8e6c5d7b 100644
--- a/mlir/test/Dialect/SCF/loop-unroll.mlir
+++ b/mlir/test/Dialect/SCF/loop-unroll.mlir
@@ -311,10 +311,10 @@ func.func @static_loop_unroll_up_to_factor(%arg0 : memref<?xf32>) {
 // Test that epilogue's arguments are correctly renamed.
 func.func @static_loop_unroll_by_3_rename_epilogue_arguments() -> (f32, f32) {
   %0 = arith.constant 7.0 : f32
-  %lb = arith.constant 0 : index
-  %ub = arith.constant 20 : index
-  %step = arith.constant 1 : index
-  %result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%arg0 = %0, %arg1 = %0) -> (f32, f32) {
+  %lb = arith.constant 0 : i32
+  %ub = arith.constant 20 : i32
+  %step = arith.constant 1 : i32
+  %result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%arg0 = %0, %arg1 = %0) -> (f32, f32) : i32{
     %add = arith.addf %arg0, %arg1 : f32
     %mul = arith.mulf %arg0, %arg1 : f32
     scf.yield %add, %mul : f32, f32
@@ -324,13 +324,13 @@ func.func @static_loop_unroll_by_3_rename_epilogue_arguments() -> (f32, f32) {
 // UNROLL-BY-3-LABEL: func @static_loop_unroll_by_3_rename_epilogue_arguments
 //
 //   UNROLL-BY-3-DAG:   %[[CST:.*]] = arith.constant {{.*}} : f32
-//   UNROLL-BY-3-DAG:   %[[C0:.*]] = arith.constant 0 : index
-//   UNROLL-BY-3-DAG:   %[[C1:.*]] = arith.constant 1 : index
-//   UNROLL-BY-3-DAG:   %[[C20:.*]] = arith.constant 20 : index
-//   UNROLL-BY-3-DAG:   %[[C18:.*]] = arith.constant 18 : index
-//   UNROLL-BY-3-DAG:   %[[C3:.*]] = arith.constant 3 : index
+//   UNROLL-BY-3-DAG:   %[[C0:.*]] = arith.constant 0 : i32
+//   UNROLL-BY-3-DAG:   %[[C1:.*]] = arith.constant 1 : i32
+//   UNROLL-BY-3-DAG:   %[[C20:.*]] = arith.constant 20 : i32
+//   UNROLL-BY-3-DAG:   %[[C18:.*]] = arith.constant 18 : i32
+//   UNROLL-BY-3-DAG:   %[[C3:.*]] = arith.constant 3 : i32
 //       UNROLL-BY-3:   %[[FOR:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C18]] step %[[C3]]
-//  UNROLL-BY-3-SAME:     iter_args(%[[ARG0:.*]] = %[[CST]], %[[ARG1:.*]] = %[[CST]]) -> (f32, f32) {
+//  UNROLL-BY-3-SAME:     iter_args(%[[ARG0:.*]] = %[[CST]], %[[ARG1:.*]] = %[[CST]]) -> (f32, f32)  : i32 {
 //  UNROLL-BY-3-NEXT:     %[[ADD0:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : f32
 //  UNROLL-BY-3-NEXT:     %[[MUL0:.*]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
 //  UNROLL-BY-3-NEXT:     %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[MUL0]] : f32
@@ -340,7 +340,7 @@ func.func @static_loop_unroll_by_3_rename_epilogue_arguments() -> (f32, f32) {
 //  UNROLL-BY-3-NEXT:     scf.yield %[[ADD2]], %[[MUL2]] : f32, f32
 //  UNROLL-BY-3-NEXT:   }
 //       UNROLL-BY-3:   %[[EFOR:.*]]:2 = scf.for %[[EIV:.*]] = %[[C18]] to %[[C20]] step %[[C1]]
-//  UNROLL-BY-3-SAME:     iter_args(%[[EARG0:.*]] = %[[FOR]]#0, %[[EARG1:.*]] = %[[FOR]]#1) -> (f32, f32) {
+//  UNROLL-BY-3-SAME:     iter_args(%[[EARG0:.*]] = %[[FOR]]#0, %[[EARG1:.*]] = %[[FOR]]#1) -> (f32, f32)  : i32 {
 //  UNROLL-BY-3-NEXT:     %[[EADD:.*]] = arith.addf %[[EARG0]], %[[EARG1]] : f32
 //  UNROLL-BY-3-NEXT:     %[[EMUL:.*]] = arith.mulf %[[EARG0]], %[[EARG1]] : f32
 //  UNROLL-BY-3-NEXT:     scf.yield %[[EADD]], %[[EMUL]] : f32, f32

>From 92cf042759136889bd457855130d18cae6011d01 Mon Sep 17 00:00:00 2001
From: Hongtao Yu <hoy at fb.com>
Date: Mon, 26 Aug 2024 17:50:51 -0700
Subject: [PATCH 2/2] Simplifying, formatting and adding doc.

---
 mlir/include/mlir/Dialect/Arith/Utils/Utils.h | 3 +++
 mlir/lib/Dialect/Arith/Utils/Utils.cpp        | 5 +----
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          | 3 ++-
 3 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 2202a2d62ebd92..34abf215aceb62 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -93,6 +93,9 @@ Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
                                   int64_t value);
 Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
                                   const APFloat &value);
+
+/// Create a constant of type `type` at location `loc` whose value is `value`.
+/// This works for integer types or the index type only.
 Value createIntOrIndexConstant(OpBuilder &builder, Location loc, Type type,
                                int64_t value);
 
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index 5fb1cda3211828..371d325c40d2da 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -306,10 +306,7 @@ Value mlir::createIntOrIndexConstant(OpBuilder &b, Location loc, Type type,
                                      int64_t value) {
   assert(type.isIntOrIndex() &&
          "unexpected type other than integers and index");
-  if (type.isIndex())
-    return b.create<arith::ConstantIndexOp>(loc, value);
-  else
-    return b.create<arith::ConstantOp>(loc, b.getIntegerAttr(type, value));
+  return b.create<arith::ConstantOp>(loc, b.getIntegerAttr(type, value));
 }
 
 Type mlir::getType(OpFoldResult ofr) {
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 49bda65bad2df2..7e1feced24aa88 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -391,7 +391,8 @@ LogicalResult mlir::loopUnrollByFactor(
     generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
     if (generateEpilogueLoop)
       upperBoundUnrolled = createIntOrIndexConstant(
-          boundsBuilder, loc, forOp.getUpperBound().getType(), upperBoundUnrolledCst);
+          boundsBuilder, loc, forOp.getUpperBound().getType(),
+          upperBoundUnrolledCst);
     else
       upperBoundUnrolled = forOp.getUpperBound();
 



More information about the Mlir-commits mailing list