[Mlir-commits] [mlir] c08c6a7 - [mlir][scf] Allow unrolling loops with integer-typed IV. (#106164)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Aug 29 09:21:03 PDT 2024
Author: Hongtao Yu
Date: 2024-08-29T09:20:59-07:00
New Revision: c08c6a71cfc536e22fb7ad733fb8181a9e84e62a
URL: https://github.com/llvm/llvm-project/commit/c08c6a71cfc536e22fb7ad733fb8181a9e84e62a
DIFF: https://github.com/llvm/llvm-project/commit/c08c6a71cfc536e22fb7ad733fb8181a9e84e62a.diff
LOG: [mlir][scf] Allow unrolling loops with integer-typed IV. (#106164)
SCF loops can now operate on integer-typed IVs, so I'm changing the
loop unroller accordingly.
Added:
Modified:
mlir/lib/Dialect/SCF/Utils/Utils.cpp
mlir/test/Dialect/SCF/loop-unroll.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 9545610f10be7c..a794a121d6267b 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -270,11 +270,13 @@ bool mlir::getInnermostParallelLoops(Operation *rootOp,
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
int64_t divisor) {
assert(divisor > 0 && "expected positive divisor");
- assert(dividend.getType().isIndex() && "expected index-typed value");
+ assert(dividend.getType().isIntOrIndex() &&
+ "expected integer or index-typed value");
- Value divisorMinusOneCst =
- builder.create<arith::ConstantIndexOp>(loc, divisor - 1);
- Value divisorCst = builder.create<arith::ConstantIndexOp>(loc, divisor);
+ Value divisorMinusOneCst = builder.create<arith::ConstantOp>(
+ loc, builder.getIntegerAttr(dividend.getType(), divisor - 1));
+ Value divisorCst = builder.create<arith::ConstantOp>(
+ loc, builder.getIntegerAttr(dividend.getType(), divisor));
Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
return builder.create<arith::DivUIOp>(loc, sum, divisorCst);
}
@@ -285,9 +287,10 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
// where divis is rounding-to-zero division.
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
Value divisor) {
- assert(dividend.getType().isIndex() && "expected index-typed value");
-
- Value cstOne = builder.create<arith::ConstantIndexOp>(loc, 1);
+ assert(dividend.getType().isIntOrIndex() &&
+ "expected integer or index-typed value");
+ Value cstOne = builder.create<arith::ConstantOp>(
+ loc, builder.getOneAttr(dividend.getType()));
Value divisorMinusOne = builder.create<arith::SubIOp>(loc, divisor, cstOne);
Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOne);
return builder.create<arith::DivUIOp>(loc, sum, divisor);
@@ -409,16 +412,18 @@ LogicalResult mlir::loopUnrollByFactor(
// Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
if (generateEpilogueLoop)
- upperBoundUnrolled = boundsBuilder.create<arith::ConstantIndexOp>(
- loc, upperBoundUnrolledCst);
+ upperBoundUnrolled = boundsBuilder.create<arith::ConstantOp>(
+ loc, boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(),
+ upperBoundUnrolledCst));
else
upperBoundUnrolled = forOp.getUpperBound();
// Create constant for 'stepUnrolled'.
stepUnrolled = stepCst == stepUnrolledCst
? step
- : boundsBuilder.create<arith::ConstantIndexOp>(
- loc, stepUnrolledCst);
+ : boundsBuilder.create<arith::ConstantOp>(
+ loc, boundsBuilder.getIntegerAttr(
+ step.getType(), stepUnrolledCst));
} else {
// Dynamic loop bounds computation.
// TODO: Add dynamic asserts for negative lb/ub/step, or
@@ -428,8 +433,8 @@ LogicalResult mlir::loopUnrollByFactor(
Value
diff =
boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
Value tripCount = ceilDivPositive(boundsBuilder, loc,
diff , step);
- Value unrollFactorCst =
- boundsBuilder.create<arith::ConstantIndexOp>(loc, unrollFactor);
+ Value unrollFactorCst = boundsBuilder.create<arith::ConstantOp>(
+ loc, boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor));
Value tripCountRem =
boundsBuilder.create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
// Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
@@ -476,7 +481,9 @@ LogicalResult mlir::loopUnrollByFactor(
[&](unsigned i, Value iv, OpBuilder b) {
// iv' = iv + step * i;
auto stride = b.create<arith::MulIOp>(
- loc, step, b.create<arith::ConstantIndexOp>(loc, i));
+ loc, step,
+ b.create<arith::ConstantOp>(loc,
+ b.getIntegerAttr(iv.getType(), i)));
return b.create<arith::AddIOp>(loc, iv, stride);
},
annotateFn, iterArgs, yieldedValues);
diff --git a/mlir/test/Dialect/SCF/loop-unroll.mlir b/mlir/test/Dialect/SCF/loop-unroll.mlir
index e28efbb6ec2b91..68a11fb6a72c64 100644
--- a/mlir/test/Dialect/SCF/loop-unroll.mlir
+++ b/mlir/test/Dialect/SCF/loop-unroll.mlir
@@ -448,3 +448,44 @@ func.func @loop_unroll_yield_iter_arg() {
// CHECK-NEXT: affine.yield %[[ITER_ARG]] : index
// CHECK-NEXT: }
// CHECK-NEXT: return
+
+// -----
+
+// Test the loop unroller works with integer IV type.
+func.func @static_loop_unroll_with_integer_iv() -> (f32, f32) {
+ %0 = arith.constant 7.0 : f32
+ %lb = arith.constant 0 : i32
+ %ub = arith.constant 20 : i32
+ %step = arith.constant 1 : i32
+ %result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%arg0 = %0, %arg1 = %0) -> (f32, f32) : i32{
+ %add = arith.addf %arg0, %arg1 : f32
+ %mul = arith.mulf %arg0, %arg1 : f32
+ scf.yield %add, %mul : f32, f32
+ }
+ return %result#0, %result#1 : f32, f32
+}
+// UNROLL-BY-3-LABEL: func @static_loop_unroll_with_integer_iv
+//
+// UNROLL-BY-3-DAG: %[[CST:.*]] = arith.constant {{.*}} : f32
+// UNROLL-BY-3-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// UNROLL-BY-3-DAG: %[[C1:.*]] = arith.constant 1 : i32
+// UNROLL-BY-3-DAG: %[[C20:.*]] = arith.constant 20 : i32
+// UNROLL-BY-3-DAG: %[[C18:.*]] = arith.constant 18 : i32
+// UNROLL-BY-3-DAG: %[[C3:.*]] = arith.constant 3 : i32
+// UNROLL-BY-3: %[[FOR:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C18]] step %[[C3]]
+// UNROLL-BY-3-SAME: iter_args(%[[ARG0:.*]] = %[[CST]], %[[ARG1:.*]] = %[[CST]]) -> (f32, f32) : i32 {
+// UNROLL-BY-3-NEXT: %[[ADD0:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : f32
+// UNROLL-BY-3-NEXT: %[[MUL0:.*]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
+// UNROLL-BY-3-NEXT: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[MUL0]] : f32
+// UNROLL-BY-3-NEXT: %[[MUL1:.*]] = arith.mulf %[[ADD0]], %[[MUL0]] : f32
+// UNROLL-BY-3-NEXT: %[[ADD2:.*]] = arith.addf %[[ADD1]], %[[MUL1]] : f32
+// UNROLL-BY-3-NEXT: %[[MUL2:.*]] = arith.mulf %[[ADD1]], %[[MUL1]] : f32
+// UNROLL-BY-3-NEXT: scf.yield %[[ADD2]], %[[MUL2]] : f32, f32
+// UNROLL-BY-3-NEXT: }
+// UNROLL-BY-3: %[[EFOR:.*]]:2 = scf.for %[[EIV:.*]] = %[[C18]] to %[[C20]] step %[[C1]]
+// UNROLL-BY-3-SAME: iter_args(%[[EARG0:.*]] = %[[FOR]]#0, %[[EARG1:.*]] = %[[FOR]]#1) -> (f32, f32) : i32 {
+// UNROLL-BY-3-NEXT: %[[EADD:.*]] = arith.addf %[[EARG0]], %[[EARG1]] : f32
+// UNROLL-BY-3-NEXT: %[[EMUL:.*]] = arith.mulf %[[EARG0]], %[[EARG1]] : f32
+// UNROLL-BY-3-NEXT: scf.yield %[[EADD]], %[[EMUL]] : f32, f32
+// UNROLL-BY-3-NEXT: }
+// UNROLL-BY-3-NEXT: return %[[EFOR]]#0, %[[EFOR]]#1 : f32, f32
More information about the Mlir-commits
mailing list