[Mlir-commits] [mlir] 9bf7930 - [mlir][Arith] Let integer range narrowing handle negative values (#119642)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 13 15:38:49 PST 2024


Author: Krzysztof Drewniak
Date: 2024-12-13T17:38:46-06:00
New Revision: 9bf79308b893e8998e7efd752835636038c2db4f

URL: https://github.com/llvm/llvm-project/commit/9bf79308b893e8998e7efd752835636038c2db4f
DIFF: https://github.com/llvm/llvm-project/commit/9bf79308b893e8998e7efd752835636038c2db4f.diff

LOG: [mlir][Arith] Let integer range narrowing handle negative values (#119642)

Update integer range narrowing to handle negative values.

The previous restriction to only narrowing known-non-negative values
wasn't needed, as both the signed and unsigned ranges represent bounds
on the values of each variable in the program ... except that one might
be more accurate than the other. So, if either the signed or unsigned
interpretation of the inputs and outputs allows for integer narrowing,
the narrowing is permitted.

This commit also updates the integer optimization rewrites to preserve
the state of constant-like operations and those that are narrowed so that
rewrites of other operations don't lose that range information.

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
    mlir/test/Dialect/Arith/int-range-narrowing.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index efc4db7e4c9961..b54a53f5ef70ef 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -46,6 +46,17 @@ static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
   return inferredRange.getConstantValue();
 }
 
+static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
+                             Value newVal) {
+  assert(oldVal.getType() == newVal.getType() &&
+         "Can't copy integer ranges between different types");
+  auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
+  if (!oldState)
+    return;
+  (void)solver.getOrCreateState<IntegerValueRangeLattice>(newVal)->join(
+      *oldState);
+}
+
 /// Patterned after SCCP
 static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
                                               PatternRewriter &rewriter,
@@ -80,6 +91,7 @@ static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
   if (!constOp)
     return failure();
 
+  copyIntegerRange(solver, value, constOp->getResult(0));
   rewriter.replaceAllUsesWith(value, constOp->getResult(0));
   return success();
 }
@@ -195,56 +207,21 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
   DataFlowSolver &solver;
 };
 
-/// Check if `type` is index or integer type with `getWidth() > targetBitwidth`.
-static LogicalResult checkIntType(Type type, unsigned targetBitwidth) {
-  Type elemType = getElementTypeOrSelf(type);
-  if (isa<IndexType>(elemType))
-    return success();
-
-  if (auto intType = dyn_cast<IntegerType>(elemType))
-    if (intType.getWidth() > targetBitwidth)
-      return success();
-
-  return failure();
-}
-
-/// Check if op have same type for all operands and results and this type
-/// is suitable for truncation.
-static LogicalResult checkElementwiseOpType(Operation *op,
-                                            unsigned targetBitwidth) {
-  if (op->getNumOperands() == 0 || op->getNumResults() == 0)
-    return failure();
-
-  Type type;
-  for (Value val : llvm::concat<Value>(op->getOperands(), op->getResults())) {
-    if (!type) {
-      type = val.getType();
-      continue;
-    }
-
-    if (type != val.getType())
-      return failure();
-  }
-
-  return checkIntType(type, targetBitwidth);
-}
-
-/// Return union of all operands values ranges.
-static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
-                                                         ValueRange operands) {
-  std::optional<ConstantIntRanges> ret;
-  for (Value value : operands) {
+/// Gather ranges for all the values in `values`. Appends to the existing
+/// vector.
+static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
+                                   SmallVectorImpl<ConstantIntRanges> &ranges) {
+  for (Value val : values) {
     auto *maybeInferredRange =
-        solver.lookupState<IntegerValueRangeLattice>(value);
+        solver.lookupState<IntegerValueRangeLattice>(val);
     if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
-      return std::nullopt;
+      return failure();
 
     const ConstantIntRanges &inferredRange =
         maybeInferredRange->getValue().getValue();
-
-    ret = (ret ? ret->rangeUnion(inferredRange) : inferredRange);
+    ranges.push_back(inferredRange);
   }
-  return ret;
+  return success();
 }
 
 /// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
@@ -258,41 +235,59 @@ static Type getTargetType(Type srcType, unsigned targetBitwidth) {
   return dstType;
 }
 
-/// Check provided `range` is inside `smin, smax, umin, umax` bounds.
-static LogicalResult checkRange(const ConstantIntRanges &range, APInt smin,
-                                APInt smax, APInt umin, APInt umax) {
-  auto sge = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.sext(width);
-    val2 = val2.sext(width);
-    return val1.sge(val2);
-  };
-  auto sle = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.sext(width);
-    val2 = val2.sext(width);
-    return val1.sle(val2);
-  };
-  auto uge = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.zext(width);
-    val2 = val2.zext(width);
-    return val1.uge(val2);
-  };
-  auto ule = [](APInt val1, APInt val2) -> bool {
-    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
-    val1 = val1.zext(width);
-    val2 = val2.zext(width);
-    return val1.ule(val2);
-  };
-  return success(sge(range.smin(), smin) && sle(range.smax(), smax) &&
-                 uge(range.umin(), umin) && ule(range.umax(), umax));
+namespace {
+// Enum for tracking which type of truncation should be performed
+// to narrow an operation, if any.
+enum class CastKind : uint8_t { None, Signed, Unsigned, Both };
+} // namespace
+
+/// If the values within `range` can be represented using only `width` bits,
+/// return the kind of truncation needed to preserve that property.
+///
+/// This check relies on the fact that the signed and unsigned ranges are both
+/// always correct, but that one might be an approximation of the other,
+/// so we want to use the correct truncation operation.
+static CastKind checkTruncatability(const ConstantIntRanges &range,
+                                    unsigned targetWidth) {
+  unsigned srcWidth = range.smin().getBitWidth();
+  if (srcWidth <= targetWidth)
+    return CastKind::None;
+  unsigned removedWidth = srcWidth - targetWidth;
+  // The sign bits need to extend into the sign bit of the target width. For
+  // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign
+  // bits.
+  bool canTruncateSigned =
+      range.smin().getNumSignBits() >= (removedWidth + 1) &&
+      range.smax().getNumSignBits() >= (removedWidth + 1);
+  bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth &&
+                             range.umax().countLeadingZeros() >= removedWidth;
+  if (canTruncateSigned && canTruncateUnsigned)
+    return CastKind::Both;
+  if (canTruncateSigned)
+    return CastKind::Signed;
+  if (canTruncateUnsigned)
+    return CastKind::Unsigned;
+  return CastKind::None;
+}
+
+static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
+  if (lhs == CastKind::None || rhs == CastKind::None)
+    return CastKind::None;
+  if (lhs == CastKind::Both)
+    return rhs;
+  if (rhs == CastKind::Both)
+    return lhs;
+  if (lhs == rhs)
+    return lhs;
+  return CastKind::None;
 }
 
-static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
+static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
+                    CastKind castKind) {
   Type srcType = src.getType();
   assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
          "Mixing vector and non-vector types");
+  assert(castKind != CastKind::None && "Can't cast when casting isn't allowed");
   Type srcElemType = getElementTypeOrSelf(srcType);
   Type dstElemType = getElementTypeOrSelf(dstType);
   assert(srcElemType.isIntOrIndex() && "Invalid src type");
@@ -300,14 +295,19 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
   if (srcType == dstType)
     return src;
 
-  if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType))
+  if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
+    if (castKind == CastKind::Signed)
+      return builder.create<arith::IndexCastOp>(loc, dstType, src);
     return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
+  }
 
   auto srcInt = cast<IntegerType>(srcElemType);
   auto dstInt = cast<IntegerType>(dstElemType);
   if (dstInt.getWidth() < srcInt.getWidth())
     return builder.create<arith::TruncIOp>(loc, dstType, src);
 
+  if (castKind == CastKind::Signed)
+    return builder.create<arith::ExtSIOp>(loc, dstType, src);
   return builder.create<arith::ExtUIOp>(loc, dstType, src);
 }
 
@@ -319,36 +319,47 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
   using OpTraitRewritePattern::OpTraitRewritePattern;
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    std::optional<ConstantIntRanges> range =
-        getOperandsRange(solver, op->getResults());
-    if (!range)
-      return failure();
+    if (op->getNumResults() == 0)
+      return rewriter.notifyMatchFailure(op, "can't narrow resultless op");
+
+    SmallVector<ConstantIntRanges> ranges;
+    if (failed(collectRanges(solver, op->getOperands(), ranges)))
+      return rewriter.notifyMatchFailure(op, "input without specified range");
+    if (failed(collectRanges(solver, op->getResults(), ranges)))
+      return rewriter.notifyMatchFailure(op, "output without specified range");
+
+    Type srcType = op->getResult(0).getType();
+    if (!llvm::all_equal(op->getResultTypes()))
+      return rewriter.notifyMatchFailure(op, "mismatched result types");
+    if (op->getNumOperands() == 0 ||
+        !llvm::all_of(op->getOperandTypes(),
+                      [=](Type t) { return t == srcType; }))
+      return rewriter.notifyMatchFailure(
+          op, "no operands or operand types don't match result type");
 
     for (unsigned targetBitwidth : targetBitwidths) {
-      if (failed(checkElementwiseOpType(op, targetBitwidth)))
-        continue;
-
-      Type srcType = op->getResult(0).getType();
-
-      // We are truncating op args to the desired bitwidth before the op and
-      // then extending op results back to the original width after. extui and
-      // exti will produce different results for negative values, so limit
-      // signed range to non-negative values.
-      auto smin = APInt::getZero(targetBitwidth);
-      auto smax = APInt::getSignedMaxValue(targetBitwidth);
-      auto umin = APInt::getMinValue(targetBitwidth);
-      auto umax = APInt::getMaxValue(targetBitwidth);
-      if (failed(checkRange(*range, smin, smax, umin, umax)))
+      CastKind castKind = CastKind::Both;
+      for (const ConstantIntRanges &range : ranges) {
+        castKind = mergeCastKinds(castKind,
+                                  checkTruncatability(range, targetBitwidth));
+        if (castKind == CastKind::None)
+          break;
+      }
+      if (castKind == CastKind::None)
         continue;
-
       Type targetType = getTargetType(srcType, targetBitwidth);
       if (targetType == srcType)
         continue;
 
       Location loc = op->getLoc();
       IRMapping mapping;
-      for (Value arg : op->getOperands()) {
-        Value newArg = doCast(rewriter, loc, arg, targetType);
+      for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) {
+        CastKind argCastKind = castKind;
+        // When dealing with `index` values, preserve non-negativity in the
+        // index_casts since we can't recover this in unsigned when equivalent.
+        if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
+          argCastKind = CastKind::Both;
+        Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
         mapping.map(arg, newArg);
       }
 
@@ -359,8 +370,12 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
         }
       });
       SmallVector<Value> newResults;
-      for (Value res : newOp->getResults())
-        newResults.emplace_back(doCast(rewriter, loc, res, srcType));
+      for (auto [newRes, oldRes] :
+           llvm::zip_equal(newOp->getResults(), op->getResults())) {
+        Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
+        copyIntegerRange(solver, oldRes, castBack);
+        newResults.push_back(castBack);
+      }
 
       rewriter.replaceOp(op, newResults);
       return success();
@@ -382,21 +397,19 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
     Value lhs = op.getLhs();
     Value rhs = op.getRhs();
 
-    std::optional<ConstantIntRanges> range =
-        getOperandsRange(solver, {lhs, rhs});
-    if (!range)
+    SmallVector<ConstantIntRanges> ranges;
+    if (failed(collectRanges(solver, op.getOperands(), ranges)))
       return failure();
+    const ConstantIntRanges &lhsRange = ranges[0];
+    const ConstantIntRanges &rhsRange = ranges[1];
 
+    Type srcType = lhs.getType();
     for (unsigned targetBitwidth : targetBitwidths) {
-      Type srcType = lhs.getType();
-      if (failed(checkIntType(srcType, targetBitwidth)))
-        continue;
-
-      auto smin = APInt::getSignedMinValue(targetBitwidth);
-      auto smax = APInt::getSignedMaxValue(targetBitwidth);
-      auto umin = APInt::getMinValue(targetBitwidth);
-      auto umax = APInt::getMaxValue(targetBitwidth);
-      if (failed(checkRange(*range, smin, smax, umin, umax)))
+      CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
+      CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
+      CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
+      // Note: this includes target width > src width.
+      if (castKind == CastKind::None)
         continue;
 
       Type targetType = getTargetType(srcType, targetBitwidth);
@@ -405,12 +418,13 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
 
       Location loc = op->getLoc();
       IRMapping mapping;
-      for (Value arg : op->getOperands()) {
-        Value newArg = doCast(rewriter, loc, arg, targetType);
-        mapping.map(arg, newArg);
-      }
+      Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
+      Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
+      mapping.map(lhs, lhsCast);
+      mapping.map(rhs, rhsCast);
 
       Operation *newOp = rewriter.clone(*op, mapping);
+      copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
       rewriter.replaceOp(op, newOp->getResults());
       return success();
     }
@@ -425,19 +439,23 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
 /// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
 /// This pattern assumes all passed `targetBitwidths` are not wider than index
 /// type.
-struct FoldIndexCastChain final : OpRewritePattern<arith::IndexCastUIOp> {
+template <typename CastOp>
+struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
   FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
-      : OpRewritePattern(context), targetBitwidths(target) {}
+      : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}
 
-  LogicalResult matchAndRewrite(arith::IndexCastUIOp op,
+  LogicalResult matchAndRewrite(CastOp op,
                                 PatternRewriter &rewriter) const override {
-    auto srcOp = op.getIn().getDefiningOp<arith::IndexCastUIOp>();
+    auto srcOp = op.getIn().template getDefiningOp<CastOp>();
     if (!srcOp)
-      return failure();
+      return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");
 
     Value src = srcOp.getIn();
     if (src.getType() != op.getType())
-      return failure();
+      return rewriter.notifyMatchFailure(op, "outer types don't match");
+
+    if (!srcOp.getType().isIndex())
+      return rewriter.notifyMatchFailure(op, "intermediate type isn't index");
 
     auto intType = dyn_cast<IntegerType>(op.getType());
     if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
@@ -517,7 +535,9 @@ void mlir::arith::populateIntRangeNarrowingPatterns(
     ArrayRef<unsigned> bitwidthsSupported) {
   patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
                                               bitwidthsSupported);
-  patterns.add<FoldIndexCastChain>(patterns.getContext(), bitwidthsSupported);
+  patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
+               FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
+                                                       bitwidthsSupported);
 }
 
 std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {

diff  --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 8893f299177ceb..e16db6293560e0 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -4,9 +4,14 @@
 // Some basic tests
 //===----------------------------------------------------------------------===//
 
-// Do not truncate negative values
+// Truncate possibly-negative values in a signed way
 // CHECK-LABEL: func @test_addi_neg
-//       CHECK:  %[[RES:.*]] = arith.addi %{{.*}}, %{{.*}} : index
+//       CHECK:  %[[POS:.*]] = test.with_bounds {smax = 1 : index, smin = 0 : index, umax = 1 : index, umin = 0 : index} : index
+//       CHECK:  %[[NEG:.*]] = test.with_bounds {smax = 0 : index, smin = -1 : index, umax = -1 : index, umin = 0 : index} : index
+//       CHECK:  %[[POS_I8:.*]] = arith.index_castui %[[POS]] : index to i8
+//       CHECK:  %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] : index to i8
+//       CHECK:  %[[RES_I8:.*]] = arith.addi %[[POS_I8]], %[[NEG_I8]] : i8
+//       CHECK:  %[[RES:.*]] = arith.index_cast %[[RES_I8]] : i8 to index
 //       CHECK:  return %[[RES]] : index
 func.func @test_addi_neg() -> index {
   %0 = test.with_bounds { umin = 0 : index, umax = 1 : index, smin = 0 : index, smax = 1 : index } : index
@@ -146,14 +151,18 @@ func.func @addi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
   return %r : i32
 }
 
-// This case should not get optimized because of mixed extensions.
+// This can be optimized to i16 since we're dealing in [-128, 127] + [0, 255],
+// which is [-128, 382]
 //
 // CHECK-LABEL: func.func @addi_mixed_ext_i8
 // CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
 // CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
 // CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[ADD:.+]]  = arith.addi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT:    return %[[ADD]] : i32
+// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT:    %[[ADD:.+]]  = arith.addi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[ADD]] : i16 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
 func.func @addi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
   %a = arith.extsi %lhs : i8 to i32
   %b = arith.extui %rhs : i8 to i32
@@ -181,15 +190,15 @@ func.func @addi_extsi_i16(%lhs: i8, %rhs: i8) -> i16 {
 // arith.subi
 //===----------------------------------------------------------------------===//
 
-// This patterns should only apply to `arith.subi` ops with sign-extended
-// arguments.
-//
 // CHECK-LABEL: func.func @subi_extui_i8
 // CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
 // CHECK-NEXT:    %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
 // CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[SUB:.+]]  = arith.subi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT:    return %[[SUB]] : i32
+// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT:    %[[SUB:.+]]  = arith.subi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[SUB]] : i16 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
 func.func @subi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
   %a = arith.extui %lhs : i8 to i32
   %b = arith.extui %rhs : i8 to i32
@@ -197,14 +206,17 @@ func.func @subi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
   return %r : i32
 }
 
-// This case should not get optimized because of mixed extensions.
+// Despite the mixed sign and zero extensions, we can optimize here
 //
 // CHECK-LABEL: func.func @subi_mixed_ext_i8
 // CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
 // CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
 // CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[ADD:.+]]  = arith.subi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT:    return %[[ADD]] : i32
+// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT:    %[[ADD:.+]]  = arith.subi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[ADD]] : i16 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
 func.func @subi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
   %a = arith.extsi %lhs : i8 to i32
   %b = arith.extui %rhs : i8 to i32
@@ -216,15 +228,14 @@ func.func @subi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
 // arith.muli
 //===----------------------------------------------------------------------===//
 
-// TODO: This should be optimized into i16
 // CHECK-LABEL: func.func @muli_extui_i8
 // CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
 // CHECK-NEXT:    %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
 // CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i24
-// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i24
-// CHECK-NEXT:    %[[MUL:.+]]  = arith.muli %[[LHS]], %[[RHS]] : i24
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[MUL]] : i24 to i32
+// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT:    %[[MUL:.+]]  = arith.muli %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[MUL]] : i16 to i32
 // CHECK-NEXT:    return %[[RET]] : i32
 func.func @muli_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
   %a = arith.extui %lhs : i8 to i32
@@ -249,17 +260,90 @@ func.func @muli_extsi_i32(%lhs: i16, %rhs: i16) -> i32 {
   return %r : i32
 }
 
-// This case should not get optimized because of mixed extensions.
+// The mixed extensions mean that we have [-128, 127] * [0, 255], which can
+// be computed exactly in i16.
 //
 // CHECK-LABEL: func.func @muli_mixed_ext_i8
 // CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
 // CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
 // CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[MUL:.+]]  = arith.muli %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT:    return %[[MUL]] : i32
+// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT:    %[[MUL:.+]]  = arith.muli %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[MUL]] : i16 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
 func.func @muli_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
   %a = arith.extsi %lhs : i8 to i32
   %b = arith.extui %rhs : i8 to i32
   %r = arith.muli %a, %b : i32
   return %r : i32
 }
+
+// Can't reduce width here since we need the extra bits
+// CHECK-LABEL: func.func @i32_overflows_to_index
+// CHECK-SAME: (%[[ARG0:.+]]: i32)
+// CHECK: %[[CLAMPED:.+]] = arith.maxsi %[[ARG0]], %{{.*}} : i32
+// CHECK: %[[CAST:.+]] = arith.index_castui %[[CLAMPED]] : i32 to index
+// CHECK: %[[MUL:.+]] = arith.muli %[[CAST]], %{{.*}} : index
+// CHECK: return %[[MUL]] : index
+func.func @i32_overflows_to_index(%arg0: i32) -> index {
+  %c0_i32 = arith.constant 0 : i32
+  %c4 = arith.constant 4 : index
+  %clamped = arith.maxsi %arg0, %c0_i32 : i32
+  %cast = arith.index_castui %clamped : i32 to index
+  %mul = arith.muli %cast, %c4 : index
+  return %mul : index
+}
+
+// Can't reduce width here since we need the extra bits
+// CHECK-LABEL: func.func @i32_overflows_to_i64
+// CHECK-SAME: (%[[ARG0:.+]]: i32)
+// CHECK: %[[CLAMPED:.+]] = arith.maxsi %[[ARG0]], %{{.*}} : i32
+// CHECK: %[[CAST:.+]] = arith.extui %[[CLAMPED]] : i32 to i64
+// CHECK: %[[MUL:.+]] = arith.muli %[[CAST]], %{{.*}} : i64
+// CHECK: return %[[MUL]] : i64
+func.func @i32_overflows_to_i64(%arg0: i32) -> i64 {
+  %c0_i32 = arith.constant 0 : i32
+  %c4_i64 = arith.constant 4 : i64
+  %clamped = arith.maxsi %arg0, %c0_i32 : i32
+  %cast = arith.extui %clamped : i32 to i64
+  %mul = arith.muli %cast, %c4_i64 : i64
+  return %mul : i64
+}
+
+// Motivating example for negative number support, added as a test case
+// and simplified
+// CHECK-LABEL: func.func @clamp_to_loop_bound_and_id()
+// CHECK: %[[TID:.+]] = test.with_bounds
+// CHECK-SAME: umax = 63
+// CHECK: %[[BOUND:.+]] = test.with_bounds
+// CHECK-SAME: umax = 112
+// CHECK: scf.for %[[ARG0:.+]] = %{{.*}} to %[[BOUND]] step %{{.*}}
+// CHECK-DAG:   %[[BOUND_I8:.+]] = arith.index_castui %[[BOUND]] : index to i8
+// CHECK-DAG:   %[[ARG0_I8:.+]] = arith.index_castui %[[ARG0]] : index to i8
+//     CHECK:   %[[V0_I8:.+]] = arith.subi %[[BOUND_I8]], %[[ARG0_I8]] : i8
+//     CHECK:   %[[V1_I8:.+]] = arith.minsi %[[V0_I8]], %{{.*}} : i8
+//     CHECK:   %[[V1_INDEX:.+]] = arith.index_cast %[[V1_I8]] : i8 to index
+//     CHECK:   %[[V1_I16:.+]] = arith.index_cast %[[V1_INDEX]] : index to i16
+//     CHECK:   %[[TID_I16:.+]] = arith.index_castui %[[TID]] : index to i16
+//     CHECK:   %[[V2_I16:.+]] = arith.subi %[[V1_I16]], %[[TID_I16]] : i16
+//     CHECK:   %[[V3:.+]] = arith.cmpi slt, %[[V2_I16]], %{{.*}} : i16
+//     CHECK:   scf.if %[[V3]]
+func.func @clamp_to_loop_bound_and_id() {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c64 = arith.constant 64 : index
+
+  %tid = test.with_bounds {smin = 0 : index, smax = 63 : index, umin = 0 : index, umax = 63 : index} : index
+  %bound = test.with_bounds {smin = 16 : index, smax = 112 : index, umin = 16 : index, umax = 112 : index} : index
+  scf.for %arg0 = %c16 to %bound step %c64 {
+    %0 = arith.subi %bound, %arg0 : index
+    %1 = arith.minsi %0, %c64 : index
+    %2 = arith.subi %1, %tid : index
+    %3 = arith.cmpi slt, %2, %c0 : index
+    scf.if %3 {
+      vector.print str "sideeffect"
+    }
+  }
+  return
+}


        


More information about the Mlir-commits mailing list