[Mlir-commits] [mlir] [mlir][Arith] Let integer range narrowing handle negative values (PR #119642)

Krzysztof Drewniak llvmlistbot at llvm.org
Fri Dec 13 11:06:00 PST 2024


https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/119642

>From 367fbe8be65b0b1669a83a23b3482126785d1c76 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Wed, 11 Dec 2024 21:31:59 +0000
Subject: [PATCH 1/2] [mlir][Arith] Let integer range narrowing handle negative
 values

Update integer range narrowing to handle negative values.

The previous restriction to only narrowing known-non-negative values
wasn't needed, as both the signed and unsigned ranges represent bounds
on the values of each variable in the program ... except that one
might be more accurate than the other. So, if either the signed or
unsigned interpretation of the inputs and outputs allows for integer
narrowing, the narrowing is permitted.

This commit also updates the integer optimization rewrites to preserve
the state of constant-like operations and those that are narrowed so
that rewrites of other operations don't lose that range information.
---
 .../Transforms/IntRangeOptimizations.cpp      | 255 ++++++++++--------
 .../Dialect/Arith/int-range-narrowing.mlir    | 126 +++++++--
 2 files changed, 242 insertions(+), 139 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index efc4db7e4c9961..7619964304720f 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -46,6 +46,16 @@ static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
   return inferredRange.getConstantValue();
 }
 
+static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
+                             Value newVal) {
+  assert(oldVal.getType() == newVal.getType() &&
+         "Can't copy integer ranges between different types");
+  auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
+  if (!oldState)
+    return;
+  solver.getOrCreateState<IntegerValueRangeLattice>(newVal)->join(*oldState);
+}
+
 /// Patterned after SCCP
 static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
                                               PatternRewriter &rewriter,
@@ -80,6 +90,7 @@ static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
   if (!constOp)
     return failure();
 
+  copyIntegerRange(solver, value, constOp->getResult(0));
   rewriter.replaceAllUsesWith(value, constOp->getResult(0));
   return success();
 }
@@ -195,56 +206,21 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
   DataFlowSolver &solver;
 };
 
-/// Check if `type` is index or integer type with `getWidth() > targetBitwidth`.
-static LogicalResult checkIntType(Type type, unsigned targetBitwidth) {
-  Type elemType = getElementTypeOrSelf(type);
-  if (isa<IndexType>(elemType))
-    return success();
-
-  if (auto intType = dyn_cast<IntegerType>(elemType))
-    if (intType.getWidth() > targetBitwidth)
-      return success();
-
-  return failure();
-}
-
-/// Check if op have same type for all operands and results and this type
-/// is suitable for truncation.
-static LogicalResult checkElementwiseOpType(Operation *op,
-                                            unsigned targetBitwidth) {
-  if (op->getNumOperands() == 0 || op->getNumResults() == 0)
-    return failure();
-
-  Type type;
-  for (Value val : llvm::concat<Value>(op->getOperands(), op->getResults())) {
-    if (!type) {
-      type = val.getType();
-      continue;
-    }
-
-    if (type != val.getType())
-      return failure();
-  }
-
-  return checkIntType(type, targetBitwidth);
-}
-
-/// Return union of all operands values ranges.
-static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
-                                                         ValueRange operands) {
-  std::optional<ConstantIntRanges> ret;
-  for (Value value : operands) {
+/// Gather ranges for all the values in `values`. Appends to the existing
+/// vector.
+static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
+                                   SmallVectorImpl<ConstantIntRanges> &ranges) {
+  for (Value val : values) {
     auto *maybeInferredRange =
-        solver.lookupState<IntegerValueRangeLattice>(value);
+        solver.lookupState<IntegerValueRangeLattice>(val);
     if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
-      return std::nullopt;
+      return failure();
 
     const ConstantIntRanges &inferredRange =
         maybeInferredRange->getValue().getValue();
-
-    ret = (ret ? ret->rangeUnion(inferredRange) : inferredRange);
+    ranges.push_back(inferredRange);
   }
-  return ret;
+  return success();
 }
 
 /// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
@@ -258,41 +234,59 @@ static Type getTargetType(Type srcType, unsigned targetBitwidth) {
   return dstType;
 }
 
-/// Check provided `range` is inside `smin, smax, umin, umax` bounds.
-static LogicalResult checkRange(const ConstantIntRanges &range, APInt smin,
-                                APInt smax, APInt umin, APInt umax) {
-  auto sge = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.sext(width);
-    val2 = val2.sext(width);
-    return val1.sge(val2);
-  };
-  auto sle = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.sext(width);
-    val2 = val2.sext(width);
-    return val1.sle(val2);
-  };
-  auto uge = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.zext(width);
-    val2 = val2.zext(width);
-    return val1.uge(val2);
-  };
-  auto ule = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.zext(width);
-    val2 = val2.zext(width);
-    return val1.ule(val2);
-  };
-  return success(sge(range.smin(), smin) && sle(range.smax(), smax) &&
-                 uge(range.umin(), umin) && ule(range.umax(), umax));
+namespace {
+// Enum for tracking which type of truncation should be performed
+// to narrow an operation, if any.
+enum class CastKind : uint8_t { None, Signed, Unsigned, Both };
+} // namespace
+
+/// If the values within `range` can be represented using only `width` bits,
+/// return the kind of truncation needed to preserve that property.
+///
+/// This check relies on the fact that the signed and unsigned ranges are both
+/// always correct, but that one might be an approximation of the other,
+/// so we want to use the correct truncation operation.
+static CastKind checkTruncatability(const ConstantIntRanges &range,
+                                    unsigned targetWidth) {
+  unsigned srcWidth = range.smin().getBitWidth();
+  if (srcWidth <= targetWidth)
+    return CastKind::None;
+  unsigned removedWidth = srcWidth - targetWidth;
+  // The sign bits need to extend into the sign bit of the target width. For
+  // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign
+  // bits.
+  bool canTruncateSigned =
+      range.smin().getNumSignBits() >= (removedWidth + 1) &&
+      range.smax().getNumSignBits() >= (removedWidth + 1);
+  bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth &&
+                             range.umax().countLeadingZeros() >= removedWidth;
+  if (canTruncateSigned && canTruncateUnsigned)
+    return CastKind::Both;
+  if (canTruncateSigned)
+    return CastKind::Signed;
+  if (canTruncateUnsigned)
+    return CastKind::Unsigned;
+  return CastKind::None;
+}
+
+static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
+  if (lhs == CastKind::None || rhs == CastKind::None)
+    return CastKind::None;
+  if (lhs == CastKind::Both)
+    return rhs;
+  if (rhs == CastKind::Both)
+    return lhs;
+  if (lhs == rhs)
+    return lhs;
+  return CastKind::None;
 }
 
-static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
+static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
+                    CastKind castKind) {
   Type srcType = src.getType();
   assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
          "Mixing vector and non-vector types");
+  assert(castKind != CastKind::None && "Can't cast when casting isn't allowed");
   Type srcElemType = getElementTypeOrSelf(srcType);
   Type dstElemType = getElementTypeOrSelf(dstType);
   assert(srcElemType.isIntOrIndex() && "Invalid src type");
@@ -300,14 +294,19 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
   if (srcType == dstType)
     return src;
 
-  if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType))
+  if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
+    if (castKind == CastKind::Signed)
+      return builder.create<arith::IndexCastOp>(loc, dstType, src);
     return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
+  }
 
   auto srcInt = cast<IntegerType>(srcElemType);
   auto dstInt = cast<IntegerType>(dstElemType);
   if (dstInt.getWidth() < srcInt.getWidth())
     return builder.create<arith::TruncIOp>(loc, dstType, src);
 
+  if (castKind == CastKind::Signed)
+    return builder.create<arith::ExtSIOp>(loc, dstType, src);
   return builder.create<arith::ExtUIOp>(loc, dstType, src);
 }
 
@@ -319,36 +318,46 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
   using OpTraitRewritePattern::OpTraitRewritePattern;
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    std::optional<ConstantIntRanges> range =
-        getOperandsRange(solver, op->getResults());
-    if (!range)
-      return failure();
+    SmallVector<ConstantIntRanges> ranges;
+    if (failed(collectRanges(solver, op->getOperands(), ranges)))
+      return rewriter.notifyMatchFailure(op, "input without specified range");
+    if (failed(collectRanges(solver, op->getResults(), ranges)))
+      return rewriter.notifyMatchFailure(op, "output without specified range");
+
+    if (op->getNumResults() == 0)
+      return rewriter.notifyMatchFailure(op, "can't narrow resultless op");
+    Type srcType = op->getResult(0).getType();
+    if (!llvm::all_equal(op->getResultTypes()))
+      return rewriter.notifyMatchFailure(op, "mismatched result types");
+    if (op->getNumOperands() == 0 ||
+        !llvm::all_of(op->getOperandTypes(),
+                      [=](Type t) { return t == srcType; }))
+      return rewriter.notifyMatchFailure(
+          op, "no operands or operand types don't match result type");
 
     for (unsigned targetBitwidth : targetBitwidths) {
-      if (failed(checkElementwiseOpType(op, targetBitwidth)))
-        continue;
-
-      Type srcType = op->getResult(0).getType();
-
-      // We are truncating op args to the desired bitwidth before the op and
-      // then extending op results back to the original width after. extui and
-      // exti will produce different results for negative values, so limit
-      // signed range to non-negative values.
-      auto smin = APInt::getZero(targetBitwidth);
-      auto smax = APInt::getSignedMaxValue(targetBitwidth);
-      auto umin = APInt::getMinValue(targetBitwidth);
-      auto umax = APInt::getMaxValue(targetBitwidth);
-      if (failed(checkRange(*range, smin, smax, umin, umax)))
+      CastKind castKind = CastKind::Both;
+      for (const ConstantIntRanges &range : ranges) {
+        castKind = mergeCastKinds(castKind,
+                                  checkTruncatability(range, targetBitwidth));
+        if (castKind == CastKind::None)
+          break;
+      }
+      if (castKind == CastKind::None)
         continue;
-
       Type targetType = getTargetType(srcType, targetBitwidth);
       if (targetType == srcType)
         continue;
 
       Location loc = op->getLoc();
       IRMapping mapping;
-      for (Value arg : op->getOperands()) {
-        Value newArg = doCast(rewriter, loc, arg, targetType);
+      for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) {
+        CastKind argCastKind = castKind;
+        // When dealing with `index` values, preserve non-negativity in the
+        // index_casts since we can't recover this in unsigned when equivalent.
+        if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
+          argCastKind = CastKind::Both;
+        Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
         mapping.map(arg, newArg);
       }
 
@@ -359,8 +368,12 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
         }
       });
       SmallVector<Value> newResults;
-      for (Value res : newOp->getResults())
-        newResults.emplace_back(doCast(rewriter, loc, res, srcType));
+      for (auto [newRes, oldRes] :
+           llvm::zip_equal(newOp->getResults(), op->getResults())) {
+        Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
+        copyIntegerRange(solver, oldRes, castBack);
+        newResults.push_back(castBack);
+      }
 
       rewriter.replaceOp(op, newResults);
       return success();
@@ -382,21 +395,19 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
     Value lhs = op.getLhs();
     Value rhs = op.getRhs();
 
-    std::optional<ConstantIntRanges> range =
-        getOperandsRange(solver, {lhs, rhs});
-    if (!range)
+    SmallVector<ConstantIntRanges> ranges;
+    if (failed(collectRanges(solver, op.getOperands(), ranges)))
       return failure();
+    const ConstantIntRanges &lhsRange = ranges[0];
+    const ConstantIntRanges &rhsRange = ranges[1];
 
+    Type srcType = lhs.getType();
     for (unsigned targetBitwidth : targetBitwidths) {
-      Type srcType = lhs.getType();
-      if (failed(checkIntType(srcType, targetBitwidth)))
-        continue;
-
-      auto smin = APInt::getSignedMinValue(targetBitwidth);
-      auto smax = APInt::getSignedMaxValue(targetBitwidth);
-      auto umin = APInt::getMinValue(targetBitwidth);
-      auto umax = APInt::getMaxValue(targetBitwidth);
-      if (failed(checkRange(*range, smin, smax, umin, umax)))
+      CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
+      CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
+      CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
+      // Note: this includes target width > src width.
+      if (castKind == CastKind::None)
         continue;
 
       Type targetType = getTargetType(srcType, targetBitwidth);
@@ -406,11 +417,12 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
       Location loc = op->getLoc();
       IRMapping mapping;
       for (Value arg : op->getOperands()) {
-        Value newArg = doCast(rewriter, loc, arg, targetType);
+        Value newArg = doCast(rewriter, loc, arg, targetType, castKind);
         mapping.map(arg, newArg);
       }
 
       Operation *newOp = rewriter.clone(*op, mapping);
+      copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
       rewriter.replaceOp(op, newOp->getResults());
       return success();
     }
@@ -425,19 +437,24 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
 /// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
 /// This pattern assumes all passed `targetBitwidths` are not wider than index
 /// type.
-struct FoldIndexCastChain final : OpRewritePattern<arith::IndexCastUIOp> {
+template <typename CastOp>
+struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
   FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
-      : OpRewritePattern(context), targetBitwidths(target) {}
+      : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}
 
-  LogicalResult matchAndRewrite(arith::IndexCastUIOp op,
+  LogicalResult matchAndRewrite(CastOp op,
                                 PatternRewriter &rewriter) const override {
-    auto srcOp = op.getIn().getDefiningOp<arith::IndexCastUIOp>();
+    auto srcOp = op.getIn().template getDefiningOp<CastOp>();
     if (!srcOp)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");
+    ;
 
     Value src = srcOp.getIn();
     if (src.getType() != op.getType())
-      return failure();
+      return rewriter.notifyMatchFailure(op, "outer types don't match");
+
+    if (!srcOp.getType().isIndex())
+      return rewriter.notifyMatchFailure(op, "intermediate type isn't index");
 
     auto intType = dyn_cast<IntegerType>(op.getType());
     if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
@@ -517,7 +534,9 @@ void mlir::arith::populateIntRangeNarrowingPatterns(
     ArrayRef<unsigned> bitwidthsSupported) {
   patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
                                               bitwidthsSupported);
-  patterns.add<FoldIndexCastChain>(patterns.getContext(), bitwidthsSupported);
+  patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
+               FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
+                                                       bitwidthsSupported);
 }
 
 std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 8893f299177ceb..ac0132480e4e18 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -4,9 +4,14 @@
 // Some basic tests
 //===----------------------------------------------------------------------===//
 
-// Do not truncate negative values
+// Truncate possibly-negative values in a signed way
 // CHECK-LABEL: func @test_addi_neg
-//       CHECK:  %[[RES:.*]] = arith.addi %{{.*}}, %{{.*}} : index
+//       CHECK:  %[[POS:.*]] = test.with_bounds {smax = 1 : index, smin = 0 : index, umax = 1 : index, umin = 0 : index} : index
+//       CHECK:  %[[NEG:.*]] = test.with_bounds {smax = 0 : index, smin = -1 : index, umax = -1 : index, umin = 0 : index} : index
+//       CHECK:  %[[POS_I8:.*]] = arith.index_castui %[[POS]] : index to i8
+//       CHECK:  %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] : index to i8
+//       CHECK:  %[[RES_I8:.*]] = arith.addi %[[POS_I8]], %[[NEG_I8]] : i8
+//       CHECK:  %[[RES:.*]] = arith.index_cast %[[RES_I8]] : i8 to index
 //       CHECK:  return %[[RES]] : index
 func.func @test_addi_neg() -> index {
   %0 = test.with_bounds { umin = 0 : index, umax = 1 : index, smin = 0 : index, smax = 1 : index } : index
@@ -146,14 +151,18 @@ func.func @addi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
   return %r : i32
 }
 
-// This case should not get optimized because of mixed extensions.
+// This can be optimized to i16 since we're dealing in [-128, 127] + [0, 255],
+// which is [-128, 382]
 //
 // CHECK-LABEL: func.func @addi_mixed_ext_i8
 // CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
 // CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
 // CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[ADD:.+]]  = arith.addi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT:    return %[[ADD]] : i32
+// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT:    %[[ADD:.+]]  = arith.addi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[ADD]] : i16 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
 func.func @addi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
   %a = arith.extsi %lhs : i8 to i32
   %b = arith.extui %rhs : i8 to i32
@@ -181,15 +190,15 @@ func.func @addi_extsi_i16(%lhs: i8, %rhs: i8) -> i16 {
 // arith.subi
 //===----------------------------------------------------------------------===//
 
-// This patterns should only apply to `arith.subi` ops with sign-extended
-// arguments.
-//
 // CHECK-LABEL: func.func @subi_extui_i8
 // CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
 // CHECK-NEXT:    %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
 // CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[SUB:.+]]  = arith.subi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT:    return %[[SUB]] : i32
+// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT:    %[[SUB:.+]]  = arith.subi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[SUB]] : i16 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
 func.func @subi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
   %a = arith.extui %lhs : i8 to i32
   %b = arith.extui %rhs : i8 to i32
@@ -197,14 +206,17 @@ func.func @subi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
   return %r : i32
 }
 
-// This case should not get optimized because of mixed extensions.
+// Despite the mixed sign and zero extensions, we can optimize here
 //
 // CHECK-LABEL: func.func @subi_mixed_ext_i8
 // CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
 // CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
 // CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[ADD:.+]]  = arith.subi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT:    return %[[ADD]] : i32
+// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT:    %[[ADD:.+]]  = arith.subi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[ADD]] : i16 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
 func.func @subi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
   %a = arith.extsi %lhs : i8 to i32
   %b = arith.extui %rhs : i8 to i32
@@ -216,15 +228,14 @@ func.func @subi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
 // arith.muli
 //===----------------------------------------------------------------------===//
 
-// TODO: This should be optimized into i16
 // CHECK-LABEL: func.func @muli_extui_i8
 // CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
 // CHECK-NEXT:    %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
 // CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i24
-// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i24
-// CHECK-NEXT:    %[[MUL:.+]]  = arith.muli %[[LHS]], %[[RHS]] : i24
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[MUL]] : i24 to i32
+// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT:    %[[MUL:.+]]  = arith.muli %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[MUL]] : i16 to i32
 // CHECK-NEXT:    return %[[RET]] : i32
 func.func @muli_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
   %a = arith.extui %lhs : i8 to i32
@@ -249,17 +260,90 @@ func.func @muli_extsi_i32(%lhs: i16, %rhs: i16) -> i32 {
   return %r : i32
 }
 
-// This case should not get optimized because of mixed extensions.
+// The mixed extensions mean that we have [-128, 127] * [0, 255], which can
+// be computed exactly in i16.
 //
 // CHECK-LABEL: func.func @muli_mixed_ext_i8
 // CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
 // CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
 // CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[MUL:.+]]  = arith.muli %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT:    return %[[MUL]] : i32
+// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT:    %[[MUL:.+]]  = arith.muli %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[MUL]] : i16 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
 func.func @muli_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
   %a = arith.extsi %lhs : i8 to i32
   %b = arith.extui %rhs : i8 to i32
   %r = arith.muli %a, %b : i32
   return %r : i32
 }
+
+// Can't reduce width here since we need the extra bits
+// CHECK-LABEL: func.func @i32_overflows_to_index
+// CHECK-SAME: (%[[ARG0:.+]]: i32)
+// CHECK: %[[CLAMPED:.+]] = arith.maxsi %[[ARG0]], %{{.*}} : i32
+// CHECK: %[[CAST:.+]] = arith.index_castui %[[CLAMPED]] : i32 to index
+// CHECK: %[[MUL:.+]] = arith.muli %[[CAST]], %{{.*}} : index
+// CHECK: return %[[MUL]] : index
+func.func @i32_overflows_to_index(%arg0: i32) -> index {
+  %c0_i32 = arith.constant 0 : i32
+  %c4 = arith.constant 4 : index
+  %clamped = arith.maxsi %arg0, %c0_i32 : i32
+  %cast = arith.index_castui %clamped : i32 to index
+  %mul = arith.muli %cast, %c4 : index
+  return %mul : index
+}
+
+// Can't reduce width here since we need the extra bits
+// CHECK-LABEL: func.func @i32_overflows_to_i64
+// CHECK-SAME: (%[[ARG0:.+]]: i32)
+// CHECK: %[[CLAMPED:.+]] = arith.maxsi %[[ARG0]], %{{.*}} : i32
+// CHECK: %[[CAST:.+]] = arith.extui %[[CLAMPED]] : i32 to i64
+// CHECK: %[[MUL:.+]] = arith.muli %[[CAST]], %{{.*}} : i64
+// CHECK: return %[[MUL]] : i64
+func.func @i32_overflows_to_i64(%arg0: i32) -> i64 {
+  %c0_i32 = arith.constant 0 : i32
+  %c4_i64 = arith.constant 4 : i64
+  %clamped = arith.maxsi %arg0, %c0_i32 : i32
+  %cast = arith.extui %clamped : i32 to i64
+  %mul = arith.muli %cast, %c4_i64 : i64
+  return %mul : i64
+}
+
+// Motivating example for negatative number support, added as a test case
+// and simplified
+// CHECK-LABEL: func.func @clamp_to_loop_bound_and_id()
+// CHECK: %[[TID:.+]] = test.with_bounds
+// CHECK-SAME: umax = 63
+// CHECK: %[[BOUND:.+]] = test.with_bounds
+// CHECK-SAME: umax = 112
+// CHECK: scf.for %[[ARG0:.+]] = %{{.*}} to %[[BOUND]] step %{{.*}}
+// CHECK-DAG:   %[[BOUND_I8:.+]] = arith.index_castui %[[BOUND]] : index to i8
+// CHECK-DAG:   %[[ARG0_I8:.+]] = arith.index_castui %[[ARG0]] : index to i8
+//     CHECK:   %[[V0_I8:.+]] = arith.subi %[[BOUND_I8]], %[[ARG0_I8]] : i8
+//     CHECK:   %[[V1_I8:.+]] = arith.minsi %[[V0_I8]], %{{.*}} : i8
+//     CHECK:   %[[V1_INDEX:.+]] = arith.index_cast %[[V1_I8]] : i8 to index
+//     CHECK:   %[[V1_I16:.+]] = arith.index_cast %[[V1_INDEX]] : index to i16
+//     CHECK:   %[[TID_I16:.+]] = arith.index_castui %[[TID]] : index to i16
+//     CHECK:   %[[V2_I16:.+]] = arith.subi %[[V1_I16]], %[[TID_I16]] : i16
+//     CHECK:   %[[V3:.+]] = arith.cmpi slt, %[[V2_I16]], %{{.*}} : i16
+//     CHECK:   scf.if %[[V3]]
+func.func @clamp_to_loop_bound_and_id() {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c64 = arith.constant 64 : index
+
+  %tid = test.with_bounds {smin = 0 : index, smax = 63 : index, umin = 0 : index, umax = 63 : index} : index
+  %bound = test.with_bounds {smin = 16 : index, smax = 112 : index, umin = 16 : index, umax = 112 : index} : index
+  scf.for %arg0 = %c16 to %bound step %c64 {
+    %0 = arith.subi %bound, %arg0 : index
+    %1 = arith.minsi %0, %c64 : index
+    %2 = arith.subi %1, %tid : index
+    %3 = arith.cmpi slt, %2, %c0 : index
+    scf.if %3 {
+      vector.print str "sideeffect"
+    }
+  }
+  return
+}

>From da394b2bef50645a0af891966c1e358b2172db42 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Fri, 13 Dec 2024 19:05:46 +0000
Subject: [PATCH 2/2] Review feedback, fix warnings

---
 .../Arith/Transforms/IntRangeOptimizations.cpp  | 17 +++++++++--------
 .../test/Dialect/Arith/int-range-narrowing.mlir |  2 +-
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 7619964304720f..b54a53f5ef70ef 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -53,7 +53,8 @@ static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
   auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
   if (!oldState)
     return;
-  solver.getOrCreateState<IntegerValueRangeLattice>(newVal)->join(*oldState);
+  (void)solver.getOrCreateState<IntegerValueRangeLattice>(newVal)->join(
+      *oldState);
 }
 
 /// Patterned after SCCP
@@ -318,14 +319,15 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
   using OpTraitRewritePattern::OpTraitRewritePattern;
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
+    if (op->getNumResults() == 0)
+      return rewriter.notifyMatchFailure(op, "can't narrow resultless op");
+
     SmallVector<ConstantIntRanges> ranges;
     if (failed(collectRanges(solver, op->getOperands(), ranges)))
       return rewriter.notifyMatchFailure(op, "input without specified range");
     if (failed(collectRanges(solver, op->getResults(), ranges)))
       return rewriter.notifyMatchFailure(op, "output without specified range");
 
-    if (op->getNumResults() == 0)
-      return rewriter.notifyMatchFailure(op, "can't narrow resultless op");
     Type srcType = op->getResult(0).getType();
     if (!llvm::all_equal(op->getResultTypes()))
       return rewriter.notifyMatchFailure(op, "mismatched result types");
@@ -416,10 +418,10 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
 
       Location loc = op->getLoc();
       IRMapping mapping;
-      for (Value arg : op->getOperands()) {
-        Value newArg = doCast(rewriter, loc, arg, targetType, castKind);
-        mapping.map(arg, newArg);
-      }
+      Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
+      Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
+      mapping.map(lhs, lhsCast);
+      mapping.map(rhs, rhsCast);
 
       Operation *newOp = rewriter.clone(*op, mapping);
       copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
@@ -447,7 +449,6 @@ struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
     auto srcOp = op.getIn().template getDefiningOp<CastOp>();
     if (!srcOp)
       return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");
-    ;
 
     Value src = srcOp.getIn();
     if (src.getType() != op.getType())
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index ac0132480e4e18..e16db6293560e0 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -311,7 +311,7 @@ func.func @i32_overflows_to_i64(%arg0: i32) -> i64 {
   return %mul : i64
 }
 
-// Motivating example for negatative number support, added as a test case
+// Motivating example for negative number support, added as a test case
 // and simplified
 // CHECK-LABEL: func.func @clamp_to_loop_bound_and_id()
 // CHECK: %[[TID:.+]] = test.with_bounds



More information about the Mlir-commits mailing list