[Mlir-commits] [mlir] 1562161 - [mlir] IntRangeNarrowing: Narrow loop induction variables. (#175455)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 13 02:42:57 PST 2026
Author: Ivan Butygin
Date: 2026-01-13T10:42:52Z
New Revision: 1562161cd20cf13386b03556ba2c3deafa26f4fe
URL: https://github.com/llvm/llvm-project/commit/1562161cd20cf13386b03556ba2c3deafa26f4fe
DIFF: https://github.com/llvm/llvm-project/commit/1562161cd20cf13386b03556ba2c3deafa26f4fe.diff
LOG: [mlir] IntRangeNarrowing: Narrow loop induction variables. (#175455)
There are 2 parts:
* Update `LoopLikeOpInterface` to check the supported induction var type
and to update the loop bounds.
* Implement `NarrowLoopBounds` pattern which tries to narrow loop
induction var and bounds using this new interface.
Added:
Modified:
mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/include/mlir/Interfaces/LoopLikeInterface.td
mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Interfaces/LoopLikeInterface.cpp
mlir/test/Dialect/Arith/int-range-narrowing.mlir
mlir/test/Dialect/SCF/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index b03cf2db78041..18ac0dbc8d13e 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -87,6 +87,12 @@ void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
DataFlowSolver &solver,
ArrayRef<unsigned> bitwidthsSupported);
+/// Add patterns for narrowing control flow values (loop bounds, steps, etc.)
+/// based on int range analysis.
+void populateControlFlowValuesNarrowingPatterns(
+ RewritePatternSet &patterns, DataFlowSolver &solver,
+ ArrayRef<unsigned> bitwidthsSupported);
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 8bdf3e0b566ef..7c26d78683de9 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -154,7 +154,8 @@ def ForOp : SCF_Op<"for",
"getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
"getLoopUpperBounds", "getStaticTripCount", "getYieldedValuesMutable",
"promoteIfSingleIteration", "replaceWithAdditionalYields",
- "yieldTiledValuesAndReplace"]>,
+ "yieldTiledValuesAndReplace", "isValidInductionVarType",
+ "setLoopLowerBounds", "setLoopUpperBounds", "setLoopSteps"]>,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
ConditionallySpeculatable,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index e09b8672f2d08..89526e92a4c92 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -243,6 +243,54 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/*defaultImplementation=*/[{
return ::std::nullopt;
}]
+ >,
+ InterfaceMethod<[{
+ Check if the given type is supported as a loop induction variable.
+ By default, only IndexType is supported.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"isValidInductionVarType",
+ /*args=*/(ins "::mlir::Type":$type),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return type.isIndex();
+ }]
+ >,
+ InterfaceMethod<[{
+ Update the loop lower bounds. Returns failure if the loop doesn't
+ support updating bounds or if the number of bounds doesn't match.
+ }],
+ /*retTy=*/"::mlir::LogicalResult",
+ /*methodName=*/"setLoopLowerBounds",
+ /*args=*/(ins "::llvm::ArrayRef<::mlir::OpFoldResult>":$bounds),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::failure();
+ }]
+ >,
+ InterfaceMethod<[{
+ Update the loop upper bounds. Returns failure if the loop doesn't
+ support updating bounds or if the number of bounds doesn't match.
+ }],
+ /*retTy=*/"::mlir::LogicalResult",
+ /*methodName=*/"setLoopUpperBounds",
+ /*args=*/(ins "::llvm::ArrayRef<::mlir::OpFoldResult>":$bounds),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::failure();
+ }]
+ >,
+ InterfaceMethod<[{
+ Update the loop steps. Returns failure if the loop doesn't support
+ updating steps or if the number of steps doesn't match.
+ }],
+ /*retTy=*/"::mlir::LogicalResult",
+ /*methodName=*/"setLoopSteps",
+ /*args=*/(ins "::llvm::ArrayRef<::mlir::OpFoldResult>":$steps),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return ::mlir::failure();
+ }]
>
];
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 4780dbb4338bb..637e16a3963d6 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -26,6 +26,7 @@ add_mlir_dialect_library(MLIRArithTransforms
MLIRFuncTransforms
MLIRInferIntRangeInterface
MLIRIR
+ MLIRLoopLikeInterface
MLIRMemRefDialect
MLIRShardDialect
MLIRPass
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index a85d3bb43a372..fefbba989b996 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -21,6 +21,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -50,8 +51,6 @@ static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
Value newVal) {
- assert(oldVal.getType() == newVal.getType() &&
- "Can't copy integer ranges between different types");
auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
if (!oldState)
return;
@@ -479,6 +478,187 @@ struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
SmallVector<unsigned, 4> targetBitwidths;
};
+/// Rewrite pattern that narrows loop bounds, steps and induction variables of
+/// any op implementing `LoopLikeOpInterface` to one of the supported target
+/// bitwidths, based on the integer ranges held by the `DataFlowSolver`.
+///
+/// For each (induction variable, lower bound, upper bound, step) tuple whose
+/// bounds and step are all SSA values, the first bitwidth in `targetBitwidths`
+/// that passes every check is selected. The bounds and step are cast to the
+/// narrow type, the induction variable block argument is retyped, and a cast
+/// back to the original type is inserted at the start of the block argument's
+/// owner block for the existing uses.
+///
+/// If the op rejects the new bounds or steps, the original bounds and steps
+/// are restored (best effort) and the op is tagged with the
+/// `arith.bounds_narrowing_failed` unit attribute so the pattern skips it on
+/// later invocations.
+struct NarrowLoopBounds final : OpInterfaceRewritePattern<LoopLikeOpInterface> {
+  /// \param context MLIR context used to build the pattern and the marker
+  ///                attribute name.
+  /// \param s       Solver queried for the integer ranges of loop values.
+  /// \param target  Candidate bitwidths, tried in the order given.
+  NarrowLoopBounds(MLIRContext *context, DataFlowSolver &s,
+                   ArrayRef<unsigned> target)
+      : OpInterfaceRewritePattern<LoopLikeOpInterface>(context), solver(s),
+        targetBitwidths(target),
+        boundsNarrowingFailedAttr(
+            StringAttr::get(context, "arith.bounds_narrowing_failed")) {}
+
+  /// Tries to narrow every induction variable of `loopLike`.
+  ///
+  /// Returns success when the IR was changed (either the loop was narrowed or
+  /// the failure marker attribute was attached), and a match failure when
+  /// nothing was modified.
+  LogicalResult matchAndRewrite(LoopLikeOpInterface loopLike,
+                                PatternRewriter &rewriter) const override {
+    // Skip ops where bounds narrowing previously failed.
+    if (loopLike->hasAttr(boundsNarrowingFailedAttr))
+      return rewriter.notifyMatchFailure(loopLike,
+                                         "bounds narrowing previously failed");
+
+    std::optional<SmallVector<Value>> inductionVars =
+        loopLike.getLoopInductionVars();
+    if (!inductionVars.has_value() || inductionVars->empty())
+      return rewriter.notifyMatchFailure(loopLike, "no induction variables");
+
+    std::optional<SmallVector<OpFoldResult>> lowerBounds =
+        loopLike.getLoopLowerBounds();
+    std::optional<SmallVector<OpFoldResult>> upperBounds =
+        loopLike.getLoopUpperBounds();
+    std::optional<SmallVector<OpFoldResult>> steps = loopLike.getLoopSteps();
+
+    if (!lowerBounds.has_value() || !upperBounds.has_value() ||
+        !steps.has_value())
+      return rewriter.notifyMatchFailure(loopLike, "no loop bounds or steps");
+
+    if (lowerBounds->size() != inductionVars->size() ||
+        upperBounds->size() != inductionVars->size() ||
+        steps->size() != inductionVars->size())
+      return rewriter.notifyMatchFailure(loopLike,
+                                         "mismatched bounds/steps count");
+
+    Location loc = loopLike->getLoc();
+    // Start from copies of the current bounds/steps; only the entries that
+    // get narrowed are overwritten below.
+    SmallVector<OpFoldResult> newLowerBounds(*lowerBounds);
+    SmallVector<OpFoldResult> newUpperBounds(*upperBounds);
+    SmallVector<OpFoldResult> newSteps(*steps);
+    // (induction variable index, narrowed type, cast kind) per narrowed tuple.
+    SmallVector<std::tuple<size_t, Type, CastKind>> narrowings;
+
+    // Check each (indVar, lb, ub, step) tuple.
+    for (auto [idx, indVar, lbOFR, ubOFR, stepOFR] :
+         llvm::enumerate(*inductionVars, *lowerBounds, *upperBounds, *steps)) {
+
+      // Only process value operands, skip attributes.
+      auto maybeLb = dyn_cast<Value>(lbOFR);
+      auto maybeUb = dyn_cast<Value>(ubOFR);
+      auto maybeStep = dyn_cast<Value>(stepOFR);
+
+      if (!maybeLb || !maybeUb || !maybeStep)
+        continue;
+
+      // Collect ranges for (lb, ub, step, indVar).
+      SmallVector<ConstantIntRanges> ranges;
+      if (failed(collectRanges(
+              solver, ValueRange{maybeLb, maybeUb, maybeStep, indVar}, ranges)))
+        continue;
+
+      // Indices follow the order of the ValueRange passed to collectRanges.
+      const ConstantIntRanges &stepRange = ranges[2];
+      const ConstantIntRanges &indVarRange = ranges[3];
+
+      Type srcType = maybeLb.getType();
+
+      // Try each target bitwidth.
+      for (unsigned targetBitwidth : targetBitwidths) {
+        Type targetType = getTargetType(srcType, targetBitwidth);
+        if (targetType == srcType)
+          continue;
+
+        // Check if the target type is valid for this loop's induction
+        // variables.
+        if (!loopLike.isValidInductionVarType(targetType))
+          continue;
+
+        // Check if all values in this tuple can be truncated.
+        CastKind castKind = CastKind::Both;
+        for (const ConstantIntRanges &range : ranges) {
+          castKind = mergeCastKinds(castKind,
+                                    checkTruncatability(range, targetBitwidth));
+          if (castKind == CastKind::None)
+            break;
+        }
+
+        if (castKind == CastKind::None)
+          continue;
+
+        // Check if indVar + step fits in the narrowed type.
+        // This is critical for loop correctness: the loop computes
+        // iv_next = iv_current + step in the narrowed type, then compares
+        // iv_next < ub. If iv_current + step overflows, the comparison may
+        // produce incorrect results and break loop termination.
+        // Both signed and unsigned interpretations must fit because loop
+        // semantics are unknown (integer types are signless).
+        ConstantIntRanges indVarPlusStepRange(
+            indVarRange.smin().sadd_sat(stepRange.smin()),
+            indVarRange.smax().sadd_sat(stepRange.smax()),
+            indVarRange.umin().uadd_sat(stepRange.umin()),
+            indVarRange.umax().uadd_sat(stepRange.umax()));
+
+        if (checkTruncatability(indVarPlusStepRange, targetBitwidth) !=
+            CastKind::Both)
+          continue;
+
+        // Narrow the bounds and step values.
+        Value newLb = doCast(rewriter, loc, maybeLb, targetType, castKind);
+        Value newUb = doCast(rewriter, loc, maybeUb, targetType, castKind);
+        Value newStep = doCast(rewriter, loc, maybeStep, targetType, castKind);
+
+        newLowerBounds[idx] = newLb;
+        newUpperBounds[idx] = newUb;
+        newSteps[idx] = newStep;
+        narrowings.push_back({idx, targetType, castKind});
+        // The first bitwidth that passes all checks wins for this tuple.
+        break;
+      }
+    }
+
+    if (narrowings.empty())
+      return rewriter.notifyMatchFailure(loopLike, "no narrowings found");
+
+    // Save original types before modifying.
+    SmallVector<Type> origTypes;
+    for (auto [idx, targetType, castKind] : narrowings) {
+      Value indVar = (*inductionVars)[idx];
+      origTypes.push_back(indVar.getType());
+    }
+
+    // Attempt to update bounds and induction variable types.
+    // If this fails, mark the op so we don't try again.
+    bool updateFailed = false;
+    rewriter.modifyOpInPlace(loopLike, [&]() {
+      // Update the loop bounds and steps.
+      if (failed(loopLike.setLoopLowerBounds(newLowerBounds)) ||
+          failed(loopLike.setLoopUpperBounds(newUpperBounds)) ||
+          failed(loopLike.setLoopSteps(newSteps))) {
+        // One of the earlier setters may already have succeeded. Restore the
+        // original bounds and steps (best effort; results are intentionally
+        // ignored) so the op is not left with a mix of narrowed and original
+        // operand types.
+        (void)loopLike.setLoopLowerBounds(*lowerBounds);
+        (void)loopLike.setLoopUpperBounds(*upperBounds);
+        (void)loopLike.setLoopSteps(*steps);
+
+        // Mark op to prevent future attempts. IR was modified (attribute
+        // added), so we must return success() from the pattern.
+        loopLike->setAttr(boundsNarrowingFailedAttr, rewriter.getUnitAttr());
+        updateFailed = true;
+        return;
+      }
+
+      // Update induction variable types.
+      for (auto [idx, targetType, castKind] : narrowings) {
+        Value indVar = (*inductionVars)[idx];
+        auto blockArg = cast<BlockArgument>(indVar);
+
+        // Change the block argument type.
+        blockArg.setType(targetType);
+      }
+    });
+
+    if (updateFailed)
+      return success();
+
+    // Insert casts back to original type for uses.
+    for (auto [narrowingIdx, narrowingInfo] : llvm::enumerate(narrowings)) {
+      auto [idx, targetType, castKind] = narrowingInfo;
+      Value indVar = (*inductionVars)[idx];
+      auto blockArg = cast<BlockArgument>(indVar);
+      Type origType = origTypes[narrowingIdx];
+
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPointToStart(blockArg.getOwner());
+      Value casted = doCast(rewriter, loc, blockArg, origType, castKind);
+      // Carry the range recorded for the induction variable over to the value
+      // that now stands in for it at the original type.
+      copyIntegerRange(solver, blockArg, casted);
+
+      // Replace all uses of the narrowed indVar with the casted value.
+      rewriter.replaceAllUsesExcept(blockArg, casted, casted.getDefiningOp());
+    }
+
+    return success();
+  }
+
+private:
+  /// Solver providing integer range information for loop values.
+  DataFlowSolver &solver;
+  /// Candidate bitwidths, tried in order for each induction variable.
+  SmallVector<unsigned, 4> targetBitwidths;
+  /// Name of the unit attribute marking ops whose bounds update failed.
+  StringAttr boundsNarrowingFailedAttr;
+};
+
struct IntRangeOptimizationsPass final
: arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
@@ -520,6 +700,8 @@ struct IntRangeNarrowingPass final
RewritePatternSet patterns(ctx);
populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
+ populateControlFlowValuesNarrowingPatterns(patterns, solver,
+ bitwidthsSupported);
// We specifically need bottom-up traversal as cmpi pattern needs range
// data, attached to its original argument values.
@@ -548,6 +730,13 @@ void mlir::arith::populateIntRangeNarrowingPatterns(
bitwidthsSupported);
}
+/// Registers the `NarrowLoopBounds` pattern in `patterns`, forwarding the
+/// solver and the list of supported bitwidths to it.
+void mlir::arith::populateControlFlowValuesNarrowingPatterns(
+    RewritePatternSet &patterns, DataFlowSolver &solver,
+    ArrayRef<unsigned> bitwidthsSupported) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<NarrowLoopBounds>(ctx, solver, bitwidthsSupported);
+}
+
std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
return std::make_unique<IntRangeOptimizationsPass>();
}
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 95a854b655a53..81a167d3514a3 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -429,6 +429,40 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
return SmallVector<OpFoldResult>{OpFoldResult(getUpperBound())};
}
+/// Returns true when `type` is `index` or a signless integer type.
+bool ForOp::isValidInductionVarType(Type type) {
+  if (type.isIndex())
+    return true;
+  return type.isSignlessInteger();
+}
+
+/// Replaces the lower bound with the single entry of `bounds`. Fails unless
+/// exactly one entry is given and that entry is a `Value`.
+LogicalResult ForOp::setLoopLowerBounds(ArrayRef<OpFoldResult> bounds) {
+  if (bounds.size() != 1)
+    return failure();
+  Value newLowerBound = dyn_cast<Value>(bounds.front());
+  if (!newLowerBound)
+    return failure();
+  setLowerBound(newLowerBound);
+  return success();
+}
+
+/// Replaces the upper bound with the single entry of `bounds`. Fails unless
+/// exactly one entry is given and that entry is a `Value`.
+LogicalResult ForOp::setLoopUpperBounds(ArrayRef<OpFoldResult> bounds) {
+  if (bounds.size() != 1)
+    return failure();
+  Value newUpperBound = dyn_cast<Value>(bounds.front());
+  if (!newUpperBound)
+    return failure();
+  setUpperBound(newUpperBound);
+  return success();
+}
+
+/// Replaces the step with the single entry of `steps`. Fails unless exactly
+/// one entry is given and that entry is a `Value`.
+LogicalResult ForOp::setLoopSteps(ArrayRef<OpFoldResult> steps) {
+  if (steps.size() != 1)
+    return failure();
+  Value newStep = dyn_cast<Value>(steps.front());
+  if (!newStep)
+    return failure();
+  setStep(newStep);
+  return success();
+}
+
std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
/// Promotes the loop body of a forOp to its containing block if the forOp
diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
index d4cef29008c2a..ae2ac13187dff 100644
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -110,5 +110,16 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
++i;
}
+ // Verify that all induction variables have valid types.
+ auto inductionVars = loopLikeOp.getLoopInductionVars();
+ if (inductionVars.has_value()) {
+ for (auto [index, inductionVar] : llvm::enumerate(*inductionVars)) {
+ if (!loopLikeOp.isValidInductionVarType(inductionVar.getType()))
+ return op->emitOpError(std::to_string(index))
+ << "-th induction variable has invalid type: "
+ << inductionVar.getType();
+ }
+ }
+
return success();
}
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 42dd4294b86e9..9107bf649b561 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -314,21 +314,28 @@ func.func @i32_overflows_to_i64(%arg0: i32) -> i64 {
// Motivating example for negative number support, added as a test case
// and simplified
// CHECK-LABEL: func.func @clamp_to_loop_bound_and_id()
+// CHECK-DAG: %[[C64_I8:.+]] = arith.constant 64 : i8
+// CHECK-DAG: %[[C64_I16:.+]] = arith.constant 64 : i16
+// CHECK-DAG: %[[C16_I16:.+]] = arith.constant 16 : i16
+// CHECK-DAG: %[[C0_I16:.+]] = arith.constant 0 : i16
// CHECK: %[[TID:.+]] = test.with_bounds
// CHECK-SAME: umax = 63
// CHECK: %[[BOUND:.+]] = test.with_bounds
// CHECK-SAME: umax = 112
-// CHECK: scf.for %[[ARG0:.+]] = %{{.*}} to %[[BOUND]] step %{{.*}}
-// CHECK-DAG: %[[BOUND_I8:.+]] = arith.index_castui %[[BOUND]] : index to i8
-// CHECK-DAG: %[[ARG0_I8:.+]] = arith.index_castui %[[ARG0]] : index to i8
-// CHECK: %[[V0_I8:.+]] = arith.subi %[[BOUND_I8]], %[[ARG0_I8]] : i8
-// CHECK: %[[V1_I8:.+]] = arith.minsi %[[V0_I8]], %{{.*}} : i8
-// CHECK: %[[V1_INDEX:.+]] = arith.index_cast %[[V1_I8]] : i8 to index
-// CHECK: %[[V1_I16:.+]] = arith.index_cast %[[V1_INDEX]] : index to i16
-// CHECK: %[[TID_I16:.+]] = arith.index_castui %[[TID]] : index to i16
-// CHECK: %[[V2_I16:.+]] = arith.subi %[[V1_I16]], %[[TID_I16]] : i16
-// CHECK: %[[V3:.+]] = arith.cmpi slt, %[[V2_I16]], %{{.*}} : i16
-// CHECK: scf.if %[[V3]]
+// Loop narrows to i16 (not i8) because indVar+step=[80,144] doesn't fit in signed i8.
+// CHECK: %[[BOUND_I16:.+]] = arith.index_castui %[[BOUND]] : index to i16
+// CHECK: scf.for %[[ARG0:.+]] = %[[C16_I16]] to %[[BOUND_I16]] step %[[C64_I16]] : i16 {
+// CHECK: %[[ARG0_INDEX:.+]] = arith.index_castui %[[ARG0]] : i16 to index
+// CHECK: %[[BOUND_I8:.+]] = arith.index_castui %[[BOUND]] : index to i8
+// CHECK: %[[ARG0_I8:.+]] = arith.index_castui %[[ARG0_INDEX]] : index to i8
+// CHECK: %[[V0_I8:.+]] = arith.subi %[[BOUND_I8]], %[[ARG0_I8]] : i8
+// CHECK: %[[V1_I8:.+]] = arith.minsi %[[V0_I8]], %[[C64_I8]] : i8
+// CHECK: %[[V1_INDEX:.+]] = arith.index_cast %[[V1_I8]] : i8 to index
+// CHECK: %[[V1_I16:.+]] = arith.index_cast %[[V1_INDEX]] : index to i16
+// CHECK: %[[TID_I16:.+]] = arith.index_castui %[[TID]] : index to i16
+// CHECK: %[[V2_I16:.+]] = arith.subi %[[V1_I16]], %[[TID_I16]] : i16
+// CHECK: %[[V3:.+]] = arith.cmpi slt, %[[V2_I16]], %[[C0_I16]] : i16
+// CHECK: scf.if %[[V3]]
func.func @clamp_to_loop_bound_and_id() {
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
@@ -349,6 +356,7 @@ func.func @clamp_to_loop_bound_and_id() {
}
func.func private @use(index)
+func.func private @use_i64(i64)
// CHECK-LABEL: func.func @loop_with_iter_arg
func.func @loop_with_iter_arg() {
@@ -377,3 +385,35 @@ func.func @loop_with_iter_arg() {
}
return
}
+
+//===----------------------------------------------------------------------===//
+// Loop bounds narrowing
+//===----------------------------------------------------------------------===//
+
+// An i64 `scf.for` whose lower bound, upper bound and step carry known
+// single-value ranges (0, 10 and 1) is expected to be rewritten to iterate in
+// i8: the bounds and step are truncated, the body addition is done in i8, and
+// the result is extended back to i64 for the call.
+// CHECK-LABEL: func.func @narrow_loop_bounds
+func.func @narrow_loop_bounds() {
+ %c0_i64 = arith.constant 0 : i64
+ %c10_i64 = arith.constant 10 : i64
+ %c1_i64 = arith.constant 1 : i64
+
+ // CHECK-DAG: %[[C1_I8:.*]] = arith.constant 1 : i8
+ // CHECK-DAG: %[[LB:.*]] = test.with_bounds {smax = 0 : i64, smin = 0 : i64, umax = 0 : i64, umin = 0 : i64} : i64
+ // CHECK-DAG: %[[UB:.*]] = test.with_bounds {smax = 10 : i64, smin = 10 : i64, umax = 10 : i64, umin = 10 : i64} : i64
+ // CHECK-DAG: %[[STEP:.*]] = test.with_bounds {smax = 1 : i64, smin = 1 : i64, umax = 1 : i64, umin = 1 : i64} : i64
+ %lb = test.with_bounds {smin = 0 : i64, smax = 0 : i64, umin = 0 : i64, umax = 0 : i64} : i64
+ %ub = test.with_bounds {smin = 10 : i64, smax = 10 : i64, umin = 10 : i64, umax = 10 : i64} : i64
+ %step = test.with_bounds {smin = 1 : i64, smax = 1 : i64, umin = 1 : i64, umax = 1 : i64} : i64
+
+ // CHECK-DAG: %[[LB_I8:.*]] = arith.trunci %[[LB]] : i64 to i8
+ // CHECK-DAG: %[[UB_I8:.*]] = arith.trunci %[[UB]] : i64 to i8
+ // CHECK-DAG: %[[STEP_I8:.*]] = arith.trunci %[[STEP]] : i64 to i8
+ // CHECK: scf.for %[[IV:.*]] = %[[LB_I8]] to %[[UB_I8]] step %[[STEP_I8]] : i8 {
+ // CHECK: %[[ADD_I8:.*]] = arith.addi %[[IV]], %[[C1_I8]] : i8
+ // CHECK: %[[ADD_I64:.*]] = arith.extui %[[ADD_I8]] : i8 to i64
+ // CHECK: call @use_i64(%[[ADD_I64]])
+ scf.for %iv = %lb to %ub step %step : i64 {
+ %add = arith.addi %iv, %c1_i64 : i64
+ func.call @use_i64(%add) : (i64) -> ()
+ }
+ return
+}
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 13a9b1cd38d88..985e347fbf5ee 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -60,6 +60,19 @@ func.func @loop_for_single_block(%arg0: index) {
func.func @loop_for_single_index_argument(%arg0: index) {
// expected-error@+1 {{expected induction variable to be same type as bounds}}
+ "scf.for"(%arg0, %arg0, %arg0) (
+ {
+ ^bb0(%i0 : i32):
+ scf.yield
+ }
+ ) : (index, index, index) -> ()
+ return
+}
+
+// -----
+
+func.func @loop_for_invalid_ind_type(%arg0: index) {
+ // expected-error@+1 {{0-th induction variable has invalid type: 'f32'}}
"scf.for"(%arg0, %arg0, %arg0) (
{
^bb0(%i0 : f32):
More information about the Mlir-commits
mailing list