[Mlir-commits] [mlir] [mlir][scf] Allow unrolling loops with integer-typed IV. (PR #106164)
Hongtao Yu
llvmlistbot at llvm.org
Mon Aug 26 17:53:27 PDT 2024
https://github.com/htyu updated https://github.com/llvm/llvm-project/pull/106164
>From 488b3b62c7607dbb6805b0118b7dc7437d434757 Mon Sep 17 00:00:00 2001
From: Hongtao Yu <hoy at fb.com>
Date: Mon, 26 Aug 2024 16:28:20 -0700
Subject: [PATCH 1/2] [mlir][scf] Allow unrolling loops with integer-typed IV.
---
mlir/include/mlir/Dialect/Arith/Utils/Utils.h | 2 ++
mlir/lib/Dialect/Arith/Utils/Utils.cpp | 10 ++++++
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 33 ++++++++++---------
mlir/test/Dialect/SCF/loop-unroll.mlir | 22 ++++++-------
4 files changed, 41 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 76f5825025739b..2202a2d62ebd92 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -93,6 +93,8 @@ Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
int64_t value);
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
const APFloat &value);
+Value createIntOrIndexConstant(OpBuilder &builder, Location loc, Type type,
+ int64_t value);
/// Returns the int type of the integer in ofr.
/// Other attribute types are not supported.
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index e75db84b75e280..5fb1cda3211828 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -302,6 +302,16 @@ Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc,
return builder.createOrFold<arith::ConstantOp>(loc, type, splat);
}
+Value mlir::createIntOrIndexConstant(OpBuilder &b, Location loc, Type type,
+ int64_t value) {
+ assert(type.isIntOrIndex() &&
+ "unexpected type other than integers and index");
+ if (type.isIndex())
+ return b.create<arith::ConstantIndexOp>(loc, value);
+ else
+ return b.create<arith::ConstantOp>(loc, b.getIntegerAttr(type, value));
+}
+
Type mlir::getType(OpFoldResult ofr) {
if (auto value = dyn_cast_if_present<Value>(ofr))
return value.getType();
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index ff5e3a002263d3..49bda65bad2df2 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -264,11 +264,13 @@ bool mlir::getInnermostParallelLoops(Operation *rootOp,
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
int64_t divisor) {
assert(divisor > 0 && "expected positive divisor");
- assert(dividend.getType().isIndex() && "expected index-typed value");
+ assert(dividend.getType().isIntOrIndex() &&
+ "expected integer or index-typed value");
Value divisorMinusOneCst =
- builder.create<arith::ConstantIndexOp>(loc, divisor - 1);
- Value divisorCst = builder.create<arith::ConstantIndexOp>(loc, divisor);
+ createIntOrIndexConstant(builder, loc, dividend.getType(), divisor - 1);
+ Value divisorCst =
+ createIntOrIndexConstant(builder, loc, dividend.getType(), divisor);
Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
return builder.create<arith::DivUIOp>(loc, sum, divisorCst);
}
@@ -279,9 +281,9 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
// where divis is rounding-to-zero division.
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
Value divisor) {
- assert(dividend.getType().isIndex() && "expected index-typed value");
-
- Value cstOne = builder.create<arith::ConstantIndexOp>(loc, 1);
+ assert(dividend.getType().isIntOrIndex() &&
+ "expected integer or index-typed value");
+ Value cstOne = createIntOrIndexConstant(builder, loc, dividend.getType(), 1);
Value divisorMinusOne = builder.create<arith::SubIOp>(loc, divisor, cstOne);
Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOne);
return builder.create<arith::DivUIOp>(loc, sum, divisor);
@@ -388,16 +390,17 @@ LogicalResult mlir::loopUnrollByFactor(
// Create constant for 'upperBoundUnrolled' and set epilogue loop flag.
generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
if (generateEpilogueLoop)
- upperBoundUnrolled = boundsBuilder.create<arith::ConstantIndexOp>(
- loc, upperBoundUnrolledCst);
+ upperBoundUnrolled = createIntOrIndexConstant(
+ boundsBuilder, loc, forOp.getUpperBound().getType(), upperBoundUnrolledCst);
else
upperBoundUnrolled = forOp.getUpperBound();
// Create constant for 'stepUnrolled'.
- stepUnrolled = stepCst == stepUnrolledCst
- ? step
- : boundsBuilder.create<arith::ConstantIndexOp>(
- loc, stepUnrolledCst);
+ stepUnrolled =
+ stepCst == stepUnrolledCst
+ ? step
+ : createIntOrIndexConstant(boundsBuilder, loc, step.getType(),
+ stepUnrolledCst);
} else {
// Dynamic loop bounds computation.
// TODO: Add dynamic asserts for negative lb/ub/step, or
@@ -407,8 +410,8 @@ LogicalResult mlir::loopUnrollByFactor(
Value diff =
boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step);
- Value unrollFactorCst =
- boundsBuilder.create<arith::ConstantIndexOp>(loc, unrollFactor);
+ Value unrollFactorCst = createIntOrIndexConstant(
+ boundsBuilder, loc, tripCount.getType(), unrollFactor);
Value tripCountRem =
boundsBuilder.create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
// Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor)
@@ -455,7 +458,7 @@ LogicalResult mlir::loopUnrollByFactor(
[&](unsigned i, Value iv, OpBuilder b) {
// iv' = iv + step * i;
auto stride = b.create<arith::MulIOp>(
- loc, step, b.create<arith::ConstantIndexOp>(loc, i));
+ loc, step, createIntOrIndexConstant(b, loc, iv.getType(), i));
return b.create<arith::AddIOp>(loc, iv, stride);
},
annotateFn, iterArgs, yieldedValues);
diff --git a/mlir/test/Dialect/SCF/loop-unroll.mlir b/mlir/test/Dialect/SCF/loop-unroll.mlir
index e28efbb6ec2b91..ae994a8e6c5d7b 100644
--- a/mlir/test/Dialect/SCF/loop-unroll.mlir
+++ b/mlir/test/Dialect/SCF/loop-unroll.mlir
@@ -311,10 +311,10 @@ func.func @static_loop_unroll_up_to_factor(%arg0 : memref<?xf32>) {
// Test that epilogue's arguments are correctly renamed.
func.func @static_loop_unroll_by_3_rename_epilogue_arguments() -> (f32, f32) {
%0 = arith.constant 7.0 : f32
- %lb = arith.constant 0 : index
- %ub = arith.constant 20 : index
- %step = arith.constant 1 : index
- %result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%arg0 = %0, %arg1 = %0) -> (f32, f32) {
+ %lb = arith.constant 0 : i32
+ %ub = arith.constant 20 : i32
+ %step = arith.constant 1 : i32
+ %result:2 = scf.for %i0 = %lb to %ub step %step iter_args(%arg0 = %0, %arg1 = %0) -> (f32, f32) : i32{
%add = arith.addf %arg0, %arg1 : f32
%mul = arith.mulf %arg0, %arg1 : f32
scf.yield %add, %mul : f32, f32
@@ -324,13 +324,13 @@ func.func @static_loop_unroll_by_3_rename_epilogue_arguments() -> (f32, f32) {
// UNROLL-BY-3-LABEL: func @static_loop_unroll_by_3_rename_epilogue_arguments
//
// UNROLL-BY-3-DAG: %[[CST:.*]] = arith.constant {{.*}} : f32
-// UNROLL-BY-3-DAG: %[[C0:.*]] = arith.constant 0 : index
-// UNROLL-BY-3-DAG: %[[C1:.*]] = arith.constant 1 : index
-// UNROLL-BY-3-DAG: %[[C20:.*]] = arith.constant 20 : index
-// UNROLL-BY-3-DAG: %[[C18:.*]] = arith.constant 18 : index
-// UNROLL-BY-3-DAG: %[[C3:.*]] = arith.constant 3 : index
+// UNROLL-BY-3-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// UNROLL-BY-3-DAG: %[[C1:.*]] = arith.constant 1 : i32
+// UNROLL-BY-3-DAG: %[[C20:.*]] = arith.constant 20 : i32
+// UNROLL-BY-3-DAG: %[[C18:.*]] = arith.constant 18 : i32
+// UNROLL-BY-3-DAG: %[[C3:.*]] = arith.constant 3 : i32
// UNROLL-BY-3: %[[FOR:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C18]] step %[[C3]]
-// UNROLL-BY-3-SAME: iter_args(%[[ARG0:.*]] = %[[CST]], %[[ARG1:.*]] = %[[CST]]) -> (f32, f32) {
+// UNROLL-BY-3-SAME: iter_args(%[[ARG0:.*]] = %[[CST]], %[[ARG1:.*]] = %[[CST]]) -> (f32, f32) : i32 {
// UNROLL-BY-3-NEXT: %[[ADD0:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : f32
// UNROLL-BY-3-NEXT: %[[MUL0:.*]] = arith.mulf %[[ARG0]], %[[ARG1]] : f32
// UNROLL-BY-3-NEXT: %[[ADD1:.*]] = arith.addf %[[ADD0]], %[[MUL0]] : f32
@@ -340,7 +340,7 @@ func.func @static_loop_unroll_by_3_rename_epilogue_arguments() -> (f32, f32) {
// UNROLL-BY-3-NEXT: scf.yield %[[ADD2]], %[[MUL2]] : f32, f32
// UNROLL-BY-3-NEXT: }
// UNROLL-BY-3: %[[EFOR:.*]]:2 = scf.for %[[EIV:.*]] = %[[C18]] to %[[C20]] step %[[C1]]
-// UNROLL-BY-3-SAME: iter_args(%[[EARG0:.*]] = %[[FOR]]#0, %[[EARG1:.*]] = %[[FOR]]#1) -> (f32, f32) {
+// UNROLL-BY-3-SAME: iter_args(%[[EARG0:.*]] = %[[FOR]]#0, %[[EARG1:.*]] = %[[FOR]]#1) -> (f32, f32) : i32 {
// UNROLL-BY-3-NEXT: %[[EADD:.*]] = arith.addf %[[EARG0]], %[[EARG1]] : f32
// UNROLL-BY-3-NEXT: %[[EMUL:.*]] = arith.mulf %[[EARG0]], %[[EARG1]] : f32
// UNROLL-BY-3-NEXT: scf.yield %[[EADD]], %[[EMUL]] : f32, f32
>From 92cf042759136889bd457855130d18cae6011d01 Mon Sep 17 00:00:00 2001
From: Hongtao Yu <hoy at fb.com>
Date: Mon, 26 Aug 2024 17:50:51 -0700
Subject: [PATCH 2/2] Simplifying, formatting and adding doc.
---
mlir/include/mlir/Dialect/Arith/Utils/Utils.h | 3 +++
mlir/lib/Dialect/Arith/Utils/Utils.cpp | 5 +----
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 3 ++-
3 files changed, 6 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 2202a2d62ebd92..34abf215aceb62 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -93,6 +93,9 @@ Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
int64_t value);
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
const APFloat &value);
+
+/// Create a constant of type `type` at location `loc` whose value is `value`.
+/// Only integer types and the index type are supported.
Value createIntOrIndexConstant(OpBuilder &builder, Location loc, Type type,
int64_t value);
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index 5fb1cda3211828..371d325c40d2da 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -306,10 +306,7 @@ Value mlir::createIntOrIndexConstant(OpBuilder &b, Location loc, Type type,
int64_t value) {
assert(type.isIntOrIndex() &&
"unexpected type other than integers and index");
- if (type.isIndex())
- return b.create<arith::ConstantIndexOp>(loc, value);
- else
- return b.create<arith::ConstantOp>(loc, b.getIntegerAttr(type, value));
+ return b.create<arith::ConstantOp>(loc, b.getIntegerAttr(type, value));
}
Type mlir::getType(OpFoldResult ofr) {
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 49bda65bad2df2..7e1feced24aa88 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -391,7 +391,8 @@ LogicalResult mlir::loopUnrollByFactor(
generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
if (generateEpilogueLoop)
upperBoundUnrolled = createIntOrIndexConstant(
- boundsBuilder, loc, forOp.getUpperBound().getType(), upperBoundUnrolledCst);
+ boundsBuilder, loc, forOp.getUpperBound().getType(),
+ upperBoundUnrolledCst);
else
upperBoundUnrolled = forOp.getUpperBound();
More information about the Mlir-commits
mailing list