[Mlir-commits] [mlir] [mlir] Add `arith-int-range-narrowing` pass (PR #112404)

Ivan Butygin llvmlistbot at llvm.org
Sun Nov 3 04:56:13 PST 2024


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/112404

From 065c3fbfba855a419313d2a9e003a49298420a47 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 15 Oct 2024 16:32:32 +0200
Subject: [PATCH 01/19] [mlir] Add `arith-int-range-narrowing` pass

This pass is intended to narrow integer calculations to a specific bitwidth, using `IntegerRangeAnalysis`.
We already have the `arith-int-narrowing` pass, but it mostly does local analysis, while `IntegerRangeAnalysis` analyzes the entire program.
Ideally the two passes should be unified, but that is left as future work.
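
For illustration, here is a condensed version of the `test_addi_i64` test added
below, run with `target-bitwidth=32` (value names are illustrative): an `i64`
addition whose operands are known to fit into 32 unsigned bits is computed in
`i32` and the result is zero-extended back.

  %2 = arith.addi %0, %1 : i64
becomes
  %lhs = arith.trunci %0 : i64 to i32
  %rhs = arith.trunci %1 : i64 to i32
  %add = arith.addi %lhs, %rhs : i32
  %2   = arith.extui %add : i32 to i64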
---
 .../mlir/Dialect/Arith/Transforms/Passes.h    |   6 +
 .../mlir/Dialect/Arith/Transforms/Passes.td   |  19 ++
 .../Transforms/IntRangeOptimizations.cpp      | 260 ++++++++++++++++++
 .../Dialect/Arith/int-range-narrowing.mlir    |  57 ++++
 4 files changed, 342 insertions(+)
 create mode 100644 mlir/test/Dialect/Arith/int-range-narrowing.mlir

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index e866ac518dbbcb..b66bfbb23bc60c 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -77,6 +77,12 @@ void populateUnsignedWhenEquivalentPatterns(RewritePatternSet &patterns,
 /// Create a pass which do optimizations based on integer range analysis.
 std::unique_ptr<Pass> createIntRangeOptimizationsPass();
 
+/// Add patterns for int range based norrowing.
+void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
+                                       DataFlowSolver &solver,
+                                       unsigned targetBitwidth);
+
+// TODO: merge these two narrowing passes.
 /// Add patterns for integer bitwidth narrowing.
 void populateArithIntNarrowingPatterns(RewritePatternSet &patterns,
                                        const ArithIntNarrowingOptions &options);
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 1517f71f1a7c90..8c565f6489638e 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -50,6 +50,25 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
   ];
 }
 
+def ArithIntRangeNarrowing : Pass<"arith-int-range-narrowing"> {
+  let summary = "Reduce integer operation bitwidth based on integer range analysis";
+  let description = [{
+    This pass runs integer range analysis and tries to narrow arith ops to the
+    specified bitwidth based on the analysis results.
+  }];
+
+  let options = [
+    Option<"targetBitwidth", "target-bitwidth", "unsigned",
+           /*default=*/"32", "Target bitwidth this pass will try to narrow to">,
+  ];
+
+  // Explicitly depend on "arith" because this pass could create operations in
+  // `arith` out of thin air in some cases.
+  let dependentDialects = [
+    "::mlir::arith::ArithDialect"
+  ];
+}
+
 def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {
   let summary = "Emulate operations on unsupported floats with extf/truncf";
   let description = [{
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index d494bba081f801..005033dfd5e118 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -15,8 +15,10 @@
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -24,6 +26,9 @@
 namespace mlir::arith {
 #define GEN_PASS_DEF_ARITHINTRANGEOPTS
 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
+
+#define GEN_PASS_DEF_ARITHINTRANGENARROWING
+#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
 } // namespace mlir::arith
 
 using namespace mlir;
@@ -190,6 +195,223 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
   DataFlowSolver &solver;
 };
 
+static Type checkArithType(Type type, unsigned targetBitwidth) {
+  type = getElementTypeOrSelf(type);
+  if (isa<IndexType>(type))
+    return type;
+
+  if (auto intType = dyn_cast<IntegerType>(type))
+    if (intType.getWidth() > targetBitwidth)
+      return type;
+
+  return nullptr;
+}
+
+static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
+  if (op->getNumOperands() == 0 || op->getNumResults() == 0)
+    return nullptr;
+
+  Type type;
+  for (auto range :
+       {ValueRange(op->getOperands()), ValueRange(op->getResults())}) {
+    for (Value val : range) {
+      if (!type) {
+        type = val.getType();
+        continue;
+      } else if (type != val.getType()) {
+        return nullptr;
+      }
+    }
+  }
+
+  return checkArithType(type, targetBitwidth);
+}
+
+static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
+                                                         ValueRange results) {
+  std::optional<ConstantIntRanges> ret;
+  for (Value value : results) {
+    auto *maybeInferredRange =
+        solver.lookupState<IntegerValueRangeLattice>(value);
+    if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
+      return std::nullopt;
+
+    const ConstantIntRanges &inferredRange =
+        maybeInferredRange->getValue().getValue();
+
+    if (!ret) {
+      ret = inferredRange;
+    } else {
+      ret = ret->rangeUnion(inferredRange);
+    }
+  }
+  return ret;
+}
+
+static Type getTargetType(Type srcType, unsigned targetBitwidth) {
+  auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth);
+  if (auto shaped = dyn_cast<ShapedType>(srcType))
+    return shaped.clone(dstType);
+
+  assert(srcType.isIntOrIndex() && "Invalid src type");
+  return dstType;
+}
+
+static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax,
+                       APInt umin, APInt umax) {
+  auto sge = [](APInt val1, APInt val2) -> bool {
+    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+    val1 = val1.sext(width);
+    val2 = val2.sext(width);
+    return val1.sge(val2);
+  };
+  auto sle = [](APInt val1, APInt val2) -> bool {
+    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+    val1 = val1.sext(width);
+    val2 = val2.sext(width);
+    return val1.sle(val2);
+  };
+  auto uge = [](APInt val1, APInt val2) -> bool {
+    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+    val1 = val1.zext(width);
+    val2 = val2.zext(width);
+    return val1.uge(val2);
+  };
+  auto ule = [](APInt val1, APInt val2) -> bool {
+    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+    val1 = val1.zext(width);
+    val2 = val2.zext(width);
+    return val1.ule(val2);
+  };
+  return sge(range.smin(), smin) && sle(range.smax(), smax) &&
+         uge(range.umin(), umin) && ule(range.umax(), umax);
+}
+
+static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
+  Type srcType = src.getType();
+  assert(srcType.isIntOrIndex() && "Invalid src type");
+  assert(dstType.isIntOrIndex() && "Invalid dst type");
+  if (srcType == dstType)
+    return src;
+
+  if (isa<IndexType>(srcType) || isa<IndexType>(dstType))
+    return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
+
+  auto srcInt = cast<IntegerType>(srcType);
+  auto dstInt = cast<IntegerType>(dstType);
+  if (dstInt.getWidth() < srcInt.getWidth()) {
+    return builder.create<arith::TruncIOp>(loc, dstType, src);
+  } else {
+    return builder.create<arith::ExtUIOp>(loc, dstType, src);
+  }
+}
+
+struct NarrowElementwise final
+    : public OpTraitRewritePattern<OpTrait::Elementwise> {
+  NarrowElementwise(MLIRContext *context, DataFlowSolver &s, unsigned target)
+      : OpTraitRewritePattern<OpTrait::Elementwise>(context), solver(s),
+        targetBitwidth(target) {}
+
+  using OpTraitRewritePattern::OpTraitRewritePattern;
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    Type srcType = checkElementwiseOpType(op, targetBitwidth);
+    if (!srcType)
+      return failure();
+
+    std::optional<ConstantIntRanges> range =
+        getOperandsRange(solver, op->getResults());
+    if (!range)
+      return failure();
+
+    // We are truncating op args to the desired bitwidth before the op and then
+    // extending op results back to the original width after.
+    // extui and extsi will produce different results for negative values, so
+    // limit signed range to non-negative values.
+    auto smin = APInt::getZero(targetBitwidth);
+    auto smax = APInt::getSignedMaxValue(targetBitwidth);
+    auto umin = APInt::getMinValue(targetBitwidth);
+    auto umax = APInt::getMaxValue(targetBitwidth);
+    if (!checkRange(*range, smin, smax, umin, umax))
+      return failure();
+
+    Type targetType = getTargetType(srcType, targetBitwidth);
+    if (targetType == srcType)
+      return failure();
+
+    Location loc = op->getLoc();
+    IRMapping mapping;
+    for (Value arg : op->getOperands()) {
+      Value newArg = doCast(rewriter, loc, arg, targetType);
+      mapping.map(arg, newArg);
+    }
+
+    Operation *newOp = rewriter.clone(*op, mapping);
+    rewriter.modifyOpInPlace(newOp, [&]() {
+      for (OpResult res : newOp->getResults()) {
+        res.setType(targetType);
+      }
+    });
+    SmallVector<Value> newResults;
+    for (Value res : newOp->getResults())
+      newResults.emplace_back(doCast(rewriter, loc, res, srcType));
+
+    rewriter.replaceOp(op, newResults);
+    return success();
+  }
+
+private:
+  DataFlowSolver &solver;
+  unsigned targetBitwidth;
+};
+
+struct NarrowCmpi final : public OpRewritePattern<arith::CmpIOp> {
+  NarrowCmpi(MLIRContext *context, PatternBenefit benefit, DataFlowSolver &s,
+             unsigned target)
+      : OpRewritePattern(context, benefit), solver(s), targetBitwidth(target) {}
+
+  LogicalResult matchAndRewrite(arith::CmpIOp op,
+                                PatternRewriter &rewriter) const override {
+    Value lhs = op.getLhs();
+    Value rhs = op.getRhs();
+
+    Type srcType = checkArithType(lhs.getType(), targetBitwidth);
+    if (!srcType)
+      return failure();
+
+    std::optional<ConstantIntRanges> range =
+        getOperandsRange(solver, {lhs, rhs});
+    if (!range)
+      return failure();
+
+    auto smin = APInt::getSignedMinValue(targetBitwidth);
+    auto smax = APInt::getSignedMaxValue(targetBitwidth);
+    auto umin = APInt::getMinValue(targetBitwidth);
+    auto umax = APInt::getMaxValue(targetBitwidth);
+    if (!checkRange(*range, smin, smax, umin, umax))
+      return failure();
+
+    Type targetType = getTargetType(srcType, targetBitwidth);
+    if (targetType == srcType)
+      return failure();
+
+    Location loc = op->getLoc();
+    IRMapping mapping;
+    for (Value arg : op->getOperands()) {
+      Value newArg = doCast(rewriter, loc, arg, targetType);
+      mapping.map(arg, newArg);
+    }
+
+    Operation *newOp = rewriter.clone(*op, mapping);
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+
+private:
+  DataFlowSolver &solver;
+  unsigned targetBitwidth;
+};
+
 struct IntRangeOptimizationsPass
     : public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
 
@@ -214,6 +436,32 @@ struct IntRangeOptimizationsPass
       signalPassFailure();
   }
 };
+
+struct IntRangeNarrowingPass
+    : public arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
+  using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+    DataFlowSolver solver;
+    solver.load<DeadCodeAnalysis>();
+    solver.load<IntegerRangeAnalysis>();
+    if (failed(solver.initializeAndRun(op)))
+      return signalPassFailure();
+
+    DataFlowListener listener(solver);
+
+    RewritePatternSet patterns(ctx);
+    populateIntRangeNarrowingPatterns(patterns, solver, this->targetBitwidth);
+
+    GreedyRewriteConfig config;
+    config.listener = &listener;
+
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+      signalPassFailure();
+  }
+};
 } // namespace
 
 void mlir::arith::populateIntRangeOptimizationsPatterns(
@@ -222,6 +470,18 @@ void mlir::arith::populateIntRangeOptimizationsPatterns(
                DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
 }
 
+void mlir::arith::populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
+                                                    DataFlowSolver &solver,
+                                                    unsigned targetBitwidth) {
+  // Cmpi uses args ranges instead of results, run it with higher benefit,
+  // as its arguments can potentially be replaced.
+  patterns.add<NarrowCmpi>(patterns.getContext(), /*benefit*/ 10, solver,
+                           targetBitwidth);
+
+  patterns.add<NarrowElementwise>(patterns.getContext(), solver,
+                                  targetBitwidth);
+}
+
 std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
   return std::make_unique<IntRangeOptimizationsPass>();
 }
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
new file mode 100644
index 00000000000000..d85cb3384061be
--- /dev/null
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt --arith-int-range-narrowing="target-bitwidth=32"  %s | FileCheck %s
+
+// Do not truncate negative values
+// CHECK-LABEL: func @test_addi_neg
+//       CHECK:  %[[RES:.*]] = arith.addi %{{.*}}, %{{.*}} : index
+//       CHECK:  return %[[RES]] : index
+func.func @test_addi_neg() -> index {
+  %0 = test.with_bounds { umin = 0 : index, umax = 1 : index, smin = 0 : index, smax = 1 : index } : index
+  %1 = test.with_bounds { umin = 0 : index, umax = -1 : index, smin = -1 : index, smax = 0 : index } : index
+  %2 = arith.addi %0, %1 : index
+  return %2 : index
+}
+
+// CHECK-LABEL: func @test_addi
+//       CHECK:  %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : index
+//       CHECK:  %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : index
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i32
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i32
+//       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i32
+//       CHECK:  %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : i32 to index
+//       CHECK:  return %[[RES_CASTED]] : index
+func.func @test_addi() -> index {
+  %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
+  %1 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : index
+  %2 = arith.addi %0, %1 : index
+  return %2 : index
+}
+
+
+// CHECK-LABEL: func @test_addi_i64
+//       CHECK:  %[[A:.*]] = test.with_bounds {smax = 5 : i64, smin = 4 : i64, umax = 5 : i64, umin = 4 : i64} : i64
+//       CHECK:  %[[B:.*]] = test.with_bounds {smax = 7 : i64, smin = 6 : i64, umax = 7 : i64, umin = 6 : i64} : i64
+//       CHECK:  %[[A_CASTED:.*]] = arith.trunci %[[A]] : i64 to i32
+//       CHECK:  %[[B_CASTED:.*]] = arith.trunci %[[B]] : i64 to i32
+//       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i32
+//       CHECK:  %[[RES_CASTED:.*]] = arith.extui %[[RES]] : i32 to i64
+//       CHECK:  return %[[RES_CASTED]] : i64
+func.func @test_addi_i64() -> i64 {
+  %0 = test.with_bounds { umin = 4 : i64, umax = 5 : i64, smin = 4 : i64, smax = 5 : i64 } : i64
+  %1 = test.with_bounds { umin = 6 : i64, umax = 7 : i64, smin = 6 : i64, smax = 7 : i64 } : i64
+  %2 = arith.addi %0, %1 : i64
+  return %2 : i64
+}
+
+// CHECK-LABEL: func @test_cmpi
+//       CHECK:  %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
+//       CHECK:  %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i32
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i32
+//       CHECK:  %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : i32
+//       CHECK:  return %[[RES]] : i1
+func.func @test_cmpi() -> i1 {
+  %0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index
+  %1 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index
+  %2 = arith.cmpi slt, %0, %1 : index
+  return %2 : i1
+}

From ed4920c90d685aff762063692c11381727afdd44 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 15 Oct 2024 18:59:07 +0200
Subject: [PATCH 02/19] typo

---
 mlir/include/mlir/Dialect/Arith/Transforms/Passes.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index b66bfbb23bc60c..da7ded699f21c7 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -77,7 +77,7 @@ void populateUnsignedWhenEquivalentPatterns(RewritePatternSet &patterns,
 /// Create a pass which do optimizations based on integer range analysis.
 std::unique_ptr<Pass> createIntRangeOptimizationsPass();
 
-/// Add patterns for int range based norrowing.
+/// Add patterns for int range based narrowing.
 void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
                                        DataFlowSolver &solver,
                                        unsigned targetBitwidth);

From eb91b30450e72b8ac8b17f198a941b255af01b5d Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 16 Oct 2024 13:05:29 +0200
Subject: [PATCH 03/19] use list of bitwidths instead of 1
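
With this change the pass takes a list of candidate bitwidths instead of a single
target. The patterns try the listed bitwidths in order and use the first one that
can represent the inferred range. For example, the pass can now be given several
allowed bitwidths (the option name is taken from the diff below; `input.mlir` is
just a placeholder for the input file):

  mlir-opt --arith-int-range-narrowing="int-bitwidths-supported=8,16,32" input.mlir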

---
 .../mlir/Dialect/Arith/Transforms/Passes.h    |   2 +-
 .../mlir/Dialect/Arith/Transforms/Passes.td   |   4 +-
 .../Transforms/IntRangeOptimizations.cpp      | 141 ++++++++++--------
 .../Dialect/Arith/int-range-narrowing.mlir    |   2 +-
 4 files changed, 79 insertions(+), 70 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index da7ded699f21c7..b6a87e88c6efbb 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -80,7 +80,7 @@ std::unique_ptr<Pass> createIntRangeOptimizationsPass();
 /// Add patterns for int range based narrowing.
 void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
                                        DataFlowSolver &solver,
-                                       unsigned targetBitwidth);
+                                       ArrayRef<unsigned> bitwidthsSupported);
 
 // TODO: merge these two narrowing passes.
 /// Add patterns for integer bitwidth narrowing.
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 8c565f6489638e..898d74249af618 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -58,8 +58,8 @@ def ArithIntRangeNarrowing : Pass<"arith-int-range-narrowing"> {
   }];
 
   let options = [
-    Option<"targetBitwidth", "target-bitwidth", "unsigned",
-           /*default=*/"32", "Target bitwidth this pass will try to narrow to">,
+    ListOption<"bitwidthsSupported", "int-bitwidths-supported", "unsigned",
+               "Integer bitwidths supported">,
   ];
 
   // Explicitly depend on "arith" because this pass could create operations in
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 005033dfd5e118..8c651076df2e5a 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -308,108 +308,117 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
 
 struct NarrowElementwise final
     : public OpTraitRewritePattern<OpTrait::Elementwise> {
-  NarrowElementwise(MLIRContext *context, DataFlowSolver &s, unsigned target)
+  NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
+                    ArrayRef<unsigned> target)
       : OpTraitRewritePattern<OpTrait::Elementwise>(context), solver(s),
-        targetBitwidth(target) {}
+        targetBitwidths(target) {}
 
   using OpTraitRewritePattern::OpTraitRewritePattern;
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    Type srcType = checkElementwiseOpType(op, targetBitwidth);
-    if (!srcType)
-      return failure();
 
     std::optional<ConstantIntRanges> range =
         getOperandsRange(solver, op->getResults());
     if (!range)
       return failure();
 
-    // We are truncating op args to the desired bitwidth before the op and then
-    // extending op results back to the original width after.
-    // extui and extsi will produce different results for negative values, so
-    // limit signed range to non-negative values.
-    auto smin = APInt::getZero(targetBitwidth);
-    auto smax = APInt::getSignedMaxValue(targetBitwidth);
-    auto umin = APInt::getMinValue(targetBitwidth);
-    auto umax = APInt::getMaxValue(targetBitwidth);
-    if (!checkRange(*range, smin, smax, umin, umax))
-      return failure();
+    for (unsigned targetBitwidth : targetBitwidths) {
+      Type srcType = checkElementwiseOpType(op, targetBitwidth);
+      if (!srcType)
+        continue;
 
-    Type targetType = getTargetType(srcType, targetBitwidth);
-    if (targetType == srcType)
-      return failure();
+      // We are truncating op args to the desired bitwidth before the op and
+      // then extending op results back to the original width after. extui and
+      // extsi will produce different results for negative values, so limit
+      // signed range to non-negative values.
+      auto smin = APInt::getZero(targetBitwidth);
+      auto smax = APInt::getSignedMaxValue(targetBitwidth);
+      auto umin = APInt::getMinValue(targetBitwidth);
+      auto umax = APInt::getMaxValue(targetBitwidth);
+      if (!checkRange(*range, smin, smax, umin, umax))
+        continue;
 
-    Location loc = op->getLoc();
-    IRMapping mapping;
-    for (Value arg : op->getOperands()) {
-      Value newArg = doCast(rewriter, loc, arg, targetType);
-      mapping.map(arg, newArg);
-    }
+      Type targetType = getTargetType(srcType, targetBitwidth);
+      if (targetType == srcType)
+        continue;
 
-    Operation *newOp = rewriter.clone(*op, mapping);
-    rewriter.modifyOpInPlace(newOp, [&]() {
-      for (OpResult res : newOp->getResults()) {
-        res.setType(targetType);
+      Location loc = op->getLoc();
+      IRMapping mapping;
+      for (Value arg : op->getOperands()) {
+        Value newArg = doCast(rewriter, loc, arg, targetType);
+        mapping.map(arg, newArg);
       }
-    });
-    SmallVector<Value> newResults;
-    for (Value res : newOp->getResults())
-      newResults.emplace_back(doCast(rewriter, loc, res, srcType));
 
-    rewriter.replaceOp(op, newResults);
-    return success();
+      Operation *newOp = rewriter.clone(*op, mapping);
+      rewriter.modifyOpInPlace(newOp, [&]() {
+        for (OpResult res : newOp->getResults()) {
+          res.setType(targetType);
+        }
+      });
+      SmallVector<Value> newResults;
+      for (Value res : newOp->getResults())
+        newResults.emplace_back(doCast(rewriter, loc, res, srcType));
+
+      rewriter.replaceOp(op, newResults);
+      return success();
+    }
+    return failure();
   }
 
 private:
   DataFlowSolver &solver;
-  unsigned targetBitwidth;
+  SmallVector<unsigned, 4> targetBitwidths;
 };
 
 struct NarrowCmpi final : public OpRewritePattern<arith::CmpIOp> {
   NarrowCmpi(MLIRContext *context, PatternBenefit benefit, DataFlowSolver &s,
-             unsigned target)
-      : OpRewritePattern(context, benefit), solver(s), targetBitwidth(target) {}
+             ArrayRef<unsigned> target)
+      : OpRewritePattern(context, benefit), solver(s), targetBitwidths(target) {
+  }
 
   LogicalResult matchAndRewrite(arith::CmpIOp op,
                                 PatternRewriter &rewriter) const override {
     Value lhs = op.getLhs();
     Value rhs = op.getRhs();
 
-    Type srcType = checkArithType(lhs.getType(), targetBitwidth);
-    if (!srcType)
-      return failure();
-
     std::optional<ConstantIntRanges> range =
         getOperandsRange(solver, {lhs, rhs});
     if (!range)
       return failure();
 
-    auto smin = APInt::getSignedMinValue(targetBitwidth);
-    auto smax = APInt::getSignedMaxValue(targetBitwidth);
-    auto umin = APInt::getMinValue(targetBitwidth);
-    auto umax = APInt::getMaxValue(targetBitwidth);
-    if (!checkRange(*range, smin, smax, umin, umax))
-      return failure();
+    for (unsigned targetBitwidth : targetBitwidths) {
+      Type srcType = checkArithType(lhs.getType(), targetBitwidth);
+      if (!srcType)
+        continue;
 
-    Type targetType = getTargetType(srcType, targetBitwidth);
-    if (targetType == srcType)
-      return failure();
+      auto smin = APInt::getSignedMinValue(targetBitwidth);
+      auto smax = APInt::getSignedMaxValue(targetBitwidth);
+      auto umin = APInt::getMinValue(targetBitwidth);
+      auto umax = APInt::getMaxValue(targetBitwidth);
+      if (!checkRange(*range, smin, smax, umin, umax))
+        continue;
 
-    Location loc = op->getLoc();
-    IRMapping mapping;
-    for (Value arg : op->getOperands()) {
-      Value newArg = doCast(rewriter, loc, arg, targetType);
-      mapping.map(arg, newArg);
-    }
+      Type targetType = getTargetType(srcType, targetBitwidth);
+      if (targetType == srcType)
+        continue;
 
-    Operation *newOp = rewriter.clone(*op, mapping);
-    rewriter.replaceOp(op, newOp->getResults());
-    return success();
+      Location loc = op->getLoc();
+      IRMapping mapping;
+      for (Value arg : op->getOperands()) {
+        Value newArg = doCast(rewriter, loc, arg, targetType);
+        mapping.map(arg, newArg);
+      }
+
+      Operation *newOp = rewriter.clone(*op, mapping);
+      rewriter.replaceOp(op, newOp->getResults());
+      return success();
+    }
+    return failure();
   }
 
 private:
   DataFlowSolver &solver;
-  unsigned targetBitwidth;
+  SmallVector<unsigned, 4> targetBitwidths;
 };
 
 struct IntRangeOptimizationsPass
@@ -453,7 +462,7 @@ struct IntRangeNarrowingPass
     DataFlowListener listener(solver);
 
     RewritePatternSet patterns(ctx);
-    populateIntRangeNarrowingPatterns(patterns, solver, this->targetBitwidth);
+    populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
 
     GreedyRewriteConfig config;
     config.listener = &listener;
@@ -470,16 +479,16 @@ void mlir::arith::populateIntRangeOptimizationsPatterns(
                DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
 }
 
-void mlir::arith::populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
-                                                    DataFlowSolver &solver,
-                                                    unsigned targetBitwidth) {
+void mlir::arith::populateIntRangeNarrowingPatterns(
+    RewritePatternSet &patterns, DataFlowSolver &solver,
+    ArrayRef<unsigned> bitwidthsSupported) {
   // Cmpi uses args ranges instead of results, run it with higher benefit,
   // as its arguments can potentially be replaced.
   patterns.add<NarrowCmpi>(patterns.getContext(), /*benefit*/ 10, solver,
-                           targetBitwidth);
+                           bitwidthsSupported);
 
   patterns.add<NarrowElementwise>(patterns.getContext(), solver,
-                                  targetBitwidth);
+                                  bitwidthsSupported);
 }
 
 std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index d85cb3384061be..6283284567dce8 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --arith-int-range-narrowing="target-bitwidth=32"  %s | FileCheck %s
+// RUN: mlir-opt --arith-int-range-narrowing="int-bitwidths-supported=32" %s | FileCheck %s
 
 // Do not truncate negative values
 // CHECK-LABEL: func @test_addi_neg

From 170157fe1606a894310f66df012ff4aa25a5a757 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 16 Oct 2024 13:12:00 +0200
Subject: [PATCH 04/19] update tests

---
 .../Dialect/Arith/int-range-narrowing.mlir    | 24 +++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 6283284567dce8..8b73f02fd214a2 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --arith-int-range-narrowing="int-bitwidths-supported=32" %s | FileCheck %s
+// RUN: mlir-opt --arith-int-range-narrowing="int-bitwidths-supported=1,8,16,24,32" %s | FileCheck %s
 
 // Do not truncate negative values
 // CHECK-LABEL: func @test_addi_neg
@@ -14,10 +14,10 @@ func.func @test_addi_neg() -> index {
 // CHECK-LABEL: func @test_addi
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : index
 //       CHECK:  %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : index
-//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i32
-//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i32
-//       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i32
-//       CHECK:  %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : i32 to index
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i8
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i8
+//       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
+//       CHECK:  %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : i8 to index
 //       CHECK:  return %[[RES_CASTED]] : index
 func.func @test_addi() -> index {
   %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
@@ -30,10 +30,10 @@ func.func @test_addi() -> index {
 // CHECK-LABEL: func @test_addi_i64
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 5 : i64, smin = 4 : i64, umax = 5 : i64, umin = 4 : i64} : i64
 //       CHECK:  %[[B:.*]] = test.with_bounds {smax = 7 : i64, smin = 6 : i64, umax = 7 : i64, umin = 6 : i64} : i64
-//       CHECK:  %[[A_CASTED:.*]] = arith.trunci %[[A]] : i64 to i32
-//       CHECK:  %[[B_CASTED:.*]] = arith.trunci %[[B]] : i64 to i32
-//       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i32
-//       CHECK:  %[[RES_CASTED:.*]] = arith.extui %[[RES]] : i32 to i64
+//       CHECK:  %[[A_CASTED:.*]] = arith.trunci %[[A]] : i64 to i8
+//       CHECK:  %[[B_CASTED:.*]] = arith.trunci %[[B]] : i64 to i8
+//       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
+//       CHECK:  %[[RES_CASTED:.*]] = arith.extui %[[RES]] : i8 to i64
 //       CHECK:  return %[[RES_CASTED]] : i64
 func.func @test_addi_i64() -> i64 {
   %0 = test.with_bounds { umin = 4 : i64, umax = 5 : i64, smin = 4 : i64, smax = 5 : i64 } : i64
@@ -45,9 +45,9 @@ func.func @test_addi_i64() -> i64 {
 // CHECK-LABEL: func @test_cmpi
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
 //       CHECK:  %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
-//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i32
-//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i32
-//       CHECK:  %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : i32
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i8
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i8
+//       CHECK:  %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : i8
 //       CHECK:  return %[[RES]] : i1
 func.func @test_cmpi() -> i1 {
   %0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index

From 0af7e9e67cf04d4cdd3f5882a1992ef5566a74a0 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 16 Oct 2024 14:13:06 +0200
Subject: [PATCH 05/19] more tests

---
 .../Dialect/Arith/int-range-narrowing.mlir    | 142 ++++++++++++++++++
 1 file changed, 142 insertions(+)

diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 8b73f02fd214a2..cd0a4c449913e1 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt --arith-int-range-narrowing="int-bitwidths-supported=1,8,16,24,32" %s | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// Some basic tests
+//===----------------------------------------------------------------------===//
+
 // Do not truncate negative values
 // CHECK-LABEL: func @test_addi_neg
 //       CHECK:  %[[RES:.*]] = arith.addi %{{.*}}, %{{.*}} : index
@@ -55,3 +59,141 @@ func.func @test_cmpi() -> i1 {
   %2 = arith.cmpi slt, %0, %1 : index
   return %2 : i1
 }
+
+//===----------------------------------------------------------------------===//
+// arith.addi
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @addi_extui_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT:    %[[ADD:.+]]  = arith.addi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[ADD]] : i16 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
+func.func @addi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extui %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.addi %a, %b : i32
+  return %r : i32
+}
+
+// This case should not get optimized because of mixed extensions.
+//
+// CHECK-LABEL: func.func @addi_mixed_ext_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[ADD:.+]]  = arith.addi %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT:    return %[[ADD]] : i32
+func.func @addi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extsi %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.addi %a, %b : i32
+  return %r : i32
+}
+
+// This case should not get optimized because we cannot reduce the bitwidth
+// below i16, given the pass options set.
+//
+// CHECK-LABEL: func.func @addi_extsi_i16
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i16
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i16
+// CHECK-NEXT:    %[[ADD:.+]]  = arith.addi %[[EXT0]], %[[EXT1]] : i16
+// CHECK-NEXT:    return %[[ADD]] : i16
+func.func @addi_extsi_i16(%lhs: i8, %rhs: i8) -> i16 {
+  %a = arith.extsi %lhs : i8 to i16
+  %b = arith.extsi %rhs : i8 to i16
+  %r = arith.addi %a, %b : i16
+  return %r : i16
+}
+
+//===----------------------------------------------------------------------===//
+// arith.subi
+//===----------------------------------------------------------------------===//
+
+// This pattern should only apply to `arith.subi` ops with sign-extended
+// arguments.
+//
+// CHECK-LABEL: func.func @subi_extui_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[SUB:.+]]  = arith.subi %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT:    return %[[SUB]] : i32
+func.func @subi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extui %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.subi %a, %b : i32
+  return %r : i32
+}
+
+// This case should not get optimized because of mixed extensions.
+//
+// CHECK-LABEL: func.func @subi_mixed_ext_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[ADD:.+]]  = arith.subi %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT:    return %[[ADD]] : i32
+func.func @subi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extsi %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.subi %a, %b : i32
+  return %r : i32
+}
+
+//===----------------------------------------------------------------------===//
+// arith.muli
+//===----------------------------------------------------------------------===//
+
+// TODO: This should be optimized into i16
+// CHECK-LABEL: func.func @muli_extui_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i24
+// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i24
+// CHECK-NEXT:    %[[MUL:.+]]  = arith.muli %[[LHS]], %[[RHS]] : i24
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[MUL]] : i24 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
+func.func @muli_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extui %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.muli %a, %b : i32
+  return %r : i32
+}
+
+// We do not expect this case to be optimized because given n-bit operands,
+// arith.muli produces 2n bits of result.
+//
+// CHECK-LABEL: func.func @muli_extsi_i32
+// CHECK-SAME:    (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16)
+// CHECK-NEXT:    %[[LHS:.+]]  = arith.extsi %[[ARG0]] : i16 to i32
+// CHECK-NEXT:    %[[RHS:.+]]  = arith.extsi %[[ARG1]] : i16 to i32
+// CHECK-NEXT:    %[[RET:.+]]  = arith.muli %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT:    return %[[RET]] : i32
+func.func @muli_extsi_i32(%lhs: i16, %rhs: i16) -> i32 {
+  %a = arith.extsi %lhs : i16 to i32
+  %b = arith.extsi %rhs : i16 to i32
+  %r = arith.muli %a, %b : i32
+  return %r : i32
+}
+
+// This case should not get optimized because of mixed extensions.
+//
+// CHECK-LABEL: func.func @muli_mixed_ext_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[MUL:.+]]  = arith.muli %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT:    return %[[MUL]] : i32
+func.func @muli_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extsi %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.muli %a, %b : i32
+  return %r : i32
+}

From 9616513bab14facd97e1302b07fd2df478912a34 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 16 Oct 2024 18:06:43 +0200
Subject: [PATCH 06/19] remove the old pass

---
 .../mlir/Dialect/Arith/Transforms/Passes.h    |   5 -
 .../mlir/Dialect/Arith/Transforms/Passes.td   |  14 -
 .../Dialect/Arith/Transforms/CMakeLists.txt   |   1 -
 .../Dialect/Arith/Transforms/IntNarrowing.cpp | 790 --------------
 .../Arith/int-narrowing-invalid-options.mlir  |  16 -
 mlir/test/Dialect/Arith/int-narrowing.mlir    | 997 ------------------
 mlir/test/Dialect/Linalg/int-narrowing.mlir   | 147 ---
 7 files changed, 1970 deletions(-)
 delete mode 100644 mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
 delete mode 100644 mlir/test/Dialect/Arith/int-narrowing-invalid-options.mlir
 delete mode 100644 mlir/test/Dialect/Arith/int-narrowing.mlir
 delete mode 100644 mlir/test/Dialect/Linalg/int-narrowing.mlir

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index b6a87e88c6efbb..58dce89fdb5786 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -82,11 +82,6 @@ void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
                                        DataFlowSolver &solver,
                                        ArrayRef<unsigned> bitwidthsSupported);
 
-// TODO: merge these two narrowing passes.
-/// Add patterns for integer bitwidth narrowing.
-void populateArithIntNarrowingPatterns(RewritePatternSet &patterns,
-                                       const ArithIntNarrowingOptions &options);
-
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 898d74249af618..98f90d120fa1c6 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -111,18 +111,4 @@ def ArithEmulateWideInt : Pass<"arith-emulate-wide-int"> {
   let dependentDialects = ["vector::VectorDialect"];
 }
 
-def ArithIntNarrowing : Pass<"arith-int-narrowing"> {
-  let summary = "Reduce integer operation bitwidth";
-  let description = [{
-    Reduce bitwidths of integer types used in arith operations. This pass
-    prefers the narrowest available integer bitwidths that are guaranteed to
-    produce the same results.
-  }];
-  let dependentDialects = ["vector::VectorDialect"];
-  let options = [
-    ListOption<"bitwidthsSupported", "int-bitwidths-supported", "unsigned",
-               "Integer bitwidths supported">,
-  ];
- }
-
 #endif // MLIR_DIALECT_ARITH_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 93a004d31916f5..912853871b7f8d 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -6,7 +6,6 @@ add_mlir_dialect_library(MLIRArithTransforms
   EmulateWideInt.cpp
   EmulateNarrowType.cpp
   ExpandOps.cpp
-  IntNarrowing.cpp
   IntRangeOptimizations.cpp
   ReifyValueBounds.cpp
   UnsignedWhenEquivalent.cpp
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
deleted file mode 100644
index b61218bb7f1af6..00000000000000
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ /dev/null
@@ -1,790 +0,0 @@
-//===- IntNarrowing.cpp - Integer bitwidth reduction optimizations --------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/Transforms/Passes.h"
-
-#include "mlir/Analysis/Presburger/IntegerRelation.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Arith/Transforms/Transforms.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinTypeInterfaces.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/ValueBoundsOpInterface.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include <cassert>
-#include <cstdint>
-
-namespace mlir::arith {
-#define GEN_PASS_DEF_ARITHINTNARROWING
-#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
-} // namespace mlir::arith
-
-namespace mlir::arith {
-namespace {
-//===----------------------------------------------------------------------===//
-// Common Helpers
-//===----------------------------------------------------------------------===//
-
-/// The base for integer bitwidth narrowing patterns.
-template <typename SourceOp>
-struct NarrowingPattern : OpRewritePattern<SourceOp> {
-  NarrowingPattern(MLIRContext *ctx, const ArithIntNarrowingOptions &options,
-                   PatternBenefit benefit = 1)
-      : OpRewritePattern<SourceOp>(ctx, benefit),
-        supportedBitwidths(options.bitwidthsSupported.begin(),
-                           options.bitwidthsSupported.end()) {
-    assert(!supportedBitwidths.empty() && "Invalid options");
-    assert(!llvm::is_contained(supportedBitwidths, 0) && "Invalid bitwidth");
-    llvm::sort(supportedBitwidths);
-  }
-
-  FailureOr<unsigned>
-  getNarrowestCompatibleBitwidth(unsigned bitsRequired) const {
-    for (unsigned candidate : supportedBitwidths)
-      if (candidate >= bitsRequired)
-        return candidate;
-
-    return failure();
-  }
-
-  /// Returns the narrowest supported type that fits `bitsRequired`.
-  FailureOr<Type> getNarrowType(unsigned bitsRequired, Type origTy) const {
-    assert(origTy);
-    FailureOr<unsigned> bestBitwidth =
-        getNarrowestCompatibleBitwidth(bitsRequired);
-    if (failed(bestBitwidth))
-      return failure();
-
-    Type elemTy = getElementTypeOrSelf(origTy);
-    if (!isa<IntegerType>(elemTy))
-      return failure();
-
-    auto newElemTy = IntegerType::get(origTy.getContext(), *bestBitwidth);
-    if (newElemTy == elemTy)
-      return failure();
-
-    if (origTy == elemTy)
-      return newElemTy;
-
-    if (auto shapedTy = dyn_cast<ShapedType>(origTy))
-      if (dyn_cast<IntegerType>(shapedTy.getElementType()))
-        return shapedTy.clone(shapedTy.getShape(), newElemTy);
-
-    return failure();
-  }
-
-private:
-  // Supported integer bitwidths in the ascending order.
-  llvm::SmallVector<unsigned, 6> supportedBitwidths;
-};
-
-/// Returns the integer bitwidth required to represent `type`.
-FailureOr<unsigned> calculateBitsRequired(Type type) {
-  assert(type);
-  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(type)))
-    return intTy.getWidth();
-
-  return failure();
-}
-
-enum class ExtensionKind { Sign, Zero };
-
-/// Wrapper around `arith::ExtSIOp` and `arith::ExtUIOp` ops that abstracts away
-/// the exact op type. Exposes helper functions to query the types, operands,
-/// and the result. This is so that we can handle both extension kinds without
-/// needing to use templates or branching.
-class ExtensionOp {
-public:
-  /// Attemps to create a new extension op from `op`. Returns an extension op
-  /// wrapper when `op` is either `arith.extsi` or `arith.extui`, and failure
-  /// otherwise.
-  static FailureOr<ExtensionOp> from(Operation *op) {
-    if (dyn_cast_or_null<arith::ExtSIOp>(op))
-      return ExtensionOp{op, ExtensionKind::Sign};
-    if (dyn_cast_or_null<arith::ExtUIOp>(op))
-      return ExtensionOp{op, ExtensionKind::Zero};
-
-    return failure();
-  }
-
-  ExtensionOp(const ExtensionOp &) = default;
-  ExtensionOp &operator=(const ExtensionOp &) = default;
-
-  /// Creates a new extension op of the same kind.
-  Operation *recreate(PatternRewriter &rewriter, Location loc, Type newType,
-                      Value in) {
-    if (kind == ExtensionKind::Sign)
-      return rewriter.create<arith::ExtSIOp>(loc, newType, in);
-
-    return rewriter.create<arith::ExtUIOp>(loc, newType, in);
-  }
-
-  /// Replaces `toReplace` with a new extension op of the same kind.
-  void recreateAndReplace(PatternRewriter &rewriter, Operation *toReplace,
-                          Value in) {
-    assert(toReplace->getNumResults() == 1);
-    Type newType = toReplace->getResult(0).getType();
-    Operation *newOp = recreate(rewriter, toReplace->getLoc(), newType, in);
-    rewriter.replaceOp(toReplace, newOp->getResult(0));
-  }
-
-  ExtensionKind getKind() { return kind; }
-
-  Value getResult() { return op->getResult(0); }
-  Value getIn() { return op->getOperand(0); }
-
-  Type getType() { return getResult().getType(); }
-  Type getElementType() { return getElementTypeOrSelf(getType()); }
-  Type getInType() { return getIn().getType(); }
-  Type getInElementType() { return getElementTypeOrSelf(getInType()); }
-
-private:
-  ExtensionOp(Operation *op, ExtensionKind kind) : op(op), kind(kind) {
-    assert(op);
-    assert((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) && "Not an extension op");
-  }
-  Operation *op = nullptr;
-  ExtensionKind kind = {};
-};
-
-/// Returns the integer bitwidth required to represent `value`.
-unsigned calculateBitsRequired(const APInt &value,
-                               ExtensionKind lookThroughExtension) {
-  // For unsigned values, we only need the active bits. As a special case, zero
-  // requires one bit.
-  if (lookThroughExtension == ExtensionKind::Zero)
-    return std::max(value.getActiveBits(), 1u);
-
-  // If a signed value is nonnegative, we need one extra bit for the sign.
-  if (value.isNonNegative())
-    return value.getActiveBits() + 1;
-
-  // For the signed min, we need all the bits.
-  if (value.isMinSignedValue())
-    return value.getBitWidth();
-
-  // For negative values, we need all the non-sign bits and one extra bit for
-  // the sign.
-  return value.getBitWidth() - value.getNumSignBits() + 1;
-}
-
-/// Returns the integer bitwidth required to represent `value`.
-/// Looks through either sign- or zero-extension as specified by
-/// `lookThroughExtension`.
-FailureOr<unsigned> calculateBitsRequired(Value value,
-                                          ExtensionKind lookThroughExtension) {
-  // Handle constants.
-  if (TypedAttr attr; matchPattern(value, m_Constant(&attr))) {
-    if (auto intAttr = dyn_cast<IntegerAttr>(attr))
-      return calculateBitsRequired(intAttr.getValue(), lookThroughExtension);
-
-    if (auto elemsAttr = dyn_cast<DenseElementsAttr>(attr)) {
-      if (elemsAttr.getElementType().isIntOrIndex()) {
-        if (elemsAttr.isSplat())
-          return calculateBitsRequired(elemsAttr.getSplatValue<APInt>(),
-                                       lookThroughExtension);
-
-        unsigned maxBits = 1;
-        for (const APInt &elemValue : elemsAttr.getValues<APInt>())
-          maxBits = std::max(
-              maxBits, calculateBitsRequired(elemValue, lookThroughExtension));
-        return maxBits;
-      }
-    }
-  }
-
-  if (lookThroughExtension == ExtensionKind::Sign) {
-    if (auto sext = value.getDefiningOp<arith::ExtSIOp>())
-      return calculateBitsRequired(sext.getIn().getType());
-  } else if (lookThroughExtension == ExtensionKind::Zero) {
-    if (auto zext = value.getDefiningOp<arith::ExtUIOp>())
-      return calculateBitsRequired(zext.getIn().getType());
-  }
-
-  // If nothing else worked, return the type requirements for this element type.
-  return calculateBitsRequired(value.getType());
-}
-
-/// Base pattern for arith binary ops.
-/// Example:
-/// ```
-///   %lhs = arith.extsi %a : i8 to i32
-///   %rhs = arith.extsi %b : i8 to i32
-///   %r = arith.addi %lhs, %rhs : i32
-/// ==>
-///   %lhs = arith.extsi %a : i8 to i16
-///   %rhs = arith.extsi %b : i8 to i16
-///   %add = arith.addi %lhs, %rhs : i16
-///   %r = arith.extsi %add : i16 to i32
-/// ```
-template <typename BinaryOp>
-struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> {
-  using NarrowingPattern<BinaryOp>::NarrowingPattern;
-
-  /// Returns the number of bits required to represent the full result, assuming
-  /// that both operands are `operandBits`-wide. Derived classes must implement
-  /// this, taking into account `BinaryOp` semantics.
-  virtual unsigned getResultBitsProduced(unsigned operandBits) const = 0;
-
-  /// Customization point for patterns that should only apply with
-  /// zero/sign-extension ops as arguments.
-  virtual bool isSupported(ExtensionOp) const { return true; }
-
-  LogicalResult matchAndRewrite(BinaryOp op,
-                                PatternRewriter &rewriter) const final {
-    Type origTy = op.getType();
-    FailureOr<unsigned> resultBits = calculateBitsRequired(origTy);
-    if (failed(resultBits))
-      return failure();
-
-    // For the optimization to apply, we expect the lhs to be an extension op,
-    // and for the rhs to either be the same extension op or a constant.
-    FailureOr<ExtensionOp> ext = ExtensionOp::from(op.getLhs().getDefiningOp());
-    if (failed(ext) || !isSupported(*ext))
-      return failure();
-
-    FailureOr<unsigned> lhsBitsRequired =
-        calculateBitsRequired(ext->getIn(), ext->getKind());
-    if (failed(lhsBitsRequired) || *lhsBitsRequired >= *resultBits)
-      return failure();
-
-    FailureOr<unsigned> rhsBitsRequired =
-        calculateBitsRequired(op.getRhs(), ext->getKind());
-    if (failed(rhsBitsRequired) || *rhsBitsRequired >= *resultBits)
-      return failure();
-
-    // Negotiate a common bit requirements for both lhs and rhs, accounting for
-    // the result requiring more bits than the operands.
-    unsigned commonBitsRequired =
-        getResultBitsProduced(std::max(*lhsBitsRequired, *rhsBitsRequired));
-    FailureOr<Type> narrowTy = this->getNarrowType(commonBitsRequired, origTy);
-    if (failed(narrowTy) || calculateBitsRequired(*narrowTy) >= *resultBits)
-      return failure();
-
-    Location loc = op.getLoc();
-    Value newLhs =
-        rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getLhs());
-    Value newRhs =
-        rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getRhs());
-    Value newAdd = rewriter.create<BinaryOp>(loc, newLhs, newRhs);
-    ext->recreateAndReplace(rewriter, op, newAdd);
-    return success();
-  }
-};
-
-//===----------------------------------------------------------------------===//
-// AddIOp Pattern
-//===----------------------------------------------------------------------===//
-
-struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> {
-  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
-
-  // Addition may require one extra bit for the result.
-  // Example: `UINT8_MAX + 1 == 255 + 1 == 256`.
-  unsigned getResultBitsProduced(unsigned operandBits) const override {
-    return operandBits + 1;
-  }
-};
-
-//===----------------------------------------------------------------------===//
-// SubIOp Pattern
-//===----------------------------------------------------------------------===//
-
-struct SubIPattern final : BinaryOpNarrowingPattern<arith::SubIOp> {
-  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
-
-  // This optimization only applies to signed arguments.
-  bool isSupported(ExtensionOp ext) const override {
-    return ext.getKind() == ExtensionKind::Sign;
-  }
-
-  // Subtraction may require one extra bit for the result.
-  // Example: `INT8_MAX - (-1) == 127 - (-1) == 128`.
-  unsigned getResultBitsProduced(unsigned operandBits) const override {
-    return operandBits + 1;
-  }
-};
-
-//===----------------------------------------------------------------------===//
-// MulIOp Pattern
-//===----------------------------------------------------------------------===//
-
-struct MulIPattern final : BinaryOpNarrowingPattern<arith::MulIOp> {
-  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
-
-  // Multiplication may require up to double the operand bits.
-  // Example: `UINT8_MAX * UINT8_MAX == 255 * 255 == 65025`.
-  unsigned getResultBitsProduced(unsigned operandBits) const override {
-    return 2 * operandBits;
-  }
-};
-
-//===----------------------------------------------------------------------===//
-// DivSIOp Pattern
-//===----------------------------------------------------------------------===//
-
-struct DivSIPattern final : BinaryOpNarrowingPattern<arith::DivSIOp> {
-  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
-
-  // This optimization only applies to signed arguments.
-  bool isSupported(ExtensionOp ext) const override {
-    return ext.getKind() == ExtensionKind::Sign;
-  }
-
-  // Unlike multiplication, signed division requires only one more result bit.
-  // Example: `INT8_MIN / (-1) == -128 / (-1) == 128`.
-  unsigned getResultBitsProduced(unsigned operandBits) const override {
-    return operandBits + 1;
-  }
-};
-
-//===----------------------------------------------------------------------===//
-// DivUIOp Pattern
-//===----------------------------------------------------------------------===//
-
-struct DivUIPattern final : BinaryOpNarrowingPattern<arith::DivUIOp> {
-  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
-
-  // This optimization only applies to unsigned arguments.
-  bool isSupported(ExtensionOp ext) const override {
-    return ext.getKind() == ExtensionKind::Zero;
-  }
-
-  // Unsigned division does not require any extra result bits.
-  unsigned getResultBitsProduced(unsigned operandBits) const override {
-    return operandBits;
-  }
-};
-
-//===----------------------------------------------------------------------===//
-// Min/Max Patterns
-//===----------------------------------------------------------------------===//
-
-template <typename MinMaxOp, ExtensionKind Kind>
-struct MinMaxPattern final : BinaryOpNarrowingPattern<MinMaxOp> {
-  using BinaryOpNarrowingPattern<MinMaxOp>::BinaryOpNarrowingPattern;
-
-  bool isSupported(ExtensionOp ext) const override {
-    return ext.getKind() == Kind;
-  }
-
-  // Min/max returns one of the arguments and does not require any extra result
-  // bits.
-  unsigned getResultBitsProduced(unsigned operandBits) const override {
-    return operandBits;
-  }
-};
-using MaxSIPattern = MinMaxPattern<arith::MaxSIOp, ExtensionKind::Sign>;
-using MaxUIPattern = MinMaxPattern<arith::MaxUIOp, ExtensionKind::Zero>;
-using MinSIPattern = MinMaxPattern<arith::MinSIOp, ExtensionKind::Sign>;
-using MinUIPattern = MinMaxPattern<arith::MinUIOp, ExtensionKind::Zero>;
-
-//===----------------------------------------------------------------------===//
-// *IToFPOp Patterns
-//===----------------------------------------------------------------------===//
-
-template <typename IToFPOp, ExtensionKind Extension>
-struct IToFPPattern final : NarrowingPattern<IToFPOp> {
-  using NarrowingPattern<IToFPOp>::NarrowingPattern;
-
-  LogicalResult matchAndRewrite(IToFPOp op,
-                                PatternRewriter &rewriter) const override {
-    FailureOr<unsigned> narrowestWidth =
-        calculateBitsRequired(op.getIn(), Extension);
-    if (failed(narrowestWidth))
-      return failure();
-
-    FailureOr<Type> narrowTy =
-        this->getNarrowType(*narrowestWidth, op.getIn().getType());
-    if (failed(narrowTy))
-      return failure();
-
-    Value newIn = rewriter.createOrFold<arith::TruncIOp>(op.getLoc(), *narrowTy,
-                                                         op.getIn());
-    rewriter.replaceOpWithNewOp<IToFPOp>(op, op.getType(), newIn);
-    return success();
-  }
-};
-using SIToFPPattern = IToFPPattern<arith::SIToFPOp, ExtensionKind::Sign>;
-using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;
-
-//===----------------------------------------------------------------------===//
-// Index Cast Patterns
-//===----------------------------------------------------------------------===//
-
-// These rely on the `ValueBounds` interface for index values. For example, we
-// can often statically tell index value bounds of loop induction variables.
-
-template <typename CastOp, ExtensionKind Kind>
-struct IndexCastPattern final : NarrowingPattern<CastOp> {
-  using NarrowingPattern<CastOp>::NarrowingPattern;
-
-  LogicalResult matchAndRewrite(CastOp op,
-                                PatternRewriter &rewriter) const override {
-    Value in = op.getIn();
-    // We only support scalar index -> integer casts.
-    if (!isa<IndexType>(in.getType()))
-      return failure();
-
-    // Check the lower bound in both the signed and unsigned cast case. We
-    // conservatively assume that even unsigned casts may be performed on
-    // negative indices.
-    FailureOr<int64_t> lb = ValueBoundsConstraintSet::computeConstantBound(
-        presburger::BoundType::LB, in);
-    if (failed(lb))
-      return failure();
-
-    FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
-        presburger::BoundType::UB, in,
-        /*stopCondition=*/nullptr, /*closedUB=*/true);
-    if (failed(ub))
-      return failure();
-
-    assert(*lb <= *ub && "Invalid bounds");
-    unsigned lbBitsRequired = calculateBitsRequired(APInt(64, *lb), Kind);
-    unsigned ubBitsRequired = calculateBitsRequired(APInt(64, *ub), Kind);
-    unsigned bitsRequired = std::max(lbBitsRequired, ubBitsRequired);
-
-    IntegerType resultTy = cast<IntegerType>(op.getType());
-    if (resultTy.getWidth() <= bitsRequired)
-      return failure();
-
-    FailureOr<Type> narrowTy = this->getNarrowType(bitsRequired, resultTy);
-    if (failed(narrowTy))
-      return failure();
-
-    Value newCast = rewriter.create<CastOp>(op.getLoc(), *narrowTy, op.getIn());
-
-    if (Kind == ExtensionKind::Sign)
-      rewriter.replaceOpWithNewOp<arith::ExtSIOp>(op, resultTy, newCast);
-    else
-      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, resultTy, newCast);
-    return success();
-  }
-};
-using IndexCastSIPattern =
-    IndexCastPattern<arith::IndexCastOp, ExtensionKind::Sign>;
-using IndexCastUIPattern =
-    IndexCastPattern<arith::IndexCastUIOp, ExtensionKind::Zero>;
-
-//===----------------------------------------------------------------------===//
-// Patterns to Commute Extension Ops
-//===----------------------------------------------------------------------===//
-
-struct ExtensionOverBroadcast final : NarrowingPattern<vector::BroadcastOp> {
-  using NarrowingPattern::NarrowingPattern;
-
-  LogicalResult matchAndRewrite(vector::BroadcastOp op,
-                                PatternRewriter &rewriter) const override {
-    FailureOr<ExtensionOp> ext =
-        ExtensionOp::from(op.getSource().getDefiningOp());
-    if (failed(ext))
-      return failure();
-
-    VectorType origTy = op.getResultVectorType();
-    VectorType newTy =
-        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
-    Value newBroadcast =
-        rewriter.create<vector::BroadcastOp>(op.getLoc(), newTy, ext->getIn());
-    ext->recreateAndReplace(rewriter, op, newBroadcast);
-    return success();
-  }
-};
-
-struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
-  using NarrowingPattern::NarrowingPattern;
-
-  LogicalResult matchAndRewrite(vector::ExtractOp op,
-                                PatternRewriter &rewriter) const override {
-    FailureOr<ExtensionOp> ext =
-        ExtensionOp::from(op.getVector().getDefiningOp());
-    if (failed(ext))
-      return failure();
-
-    Value newExtract = rewriter.create<vector::ExtractOp>(
-        op.getLoc(), ext->getIn(), op.getMixedPosition());
-    ext->recreateAndReplace(rewriter, op, newExtract);
-    return success();
-  }
-};
-
-struct ExtensionOverExtractElement final
-    : NarrowingPattern<vector::ExtractElementOp> {
-  using NarrowingPattern::NarrowingPattern;
-
-  LogicalResult matchAndRewrite(vector::ExtractElementOp op,
-                                PatternRewriter &rewriter) const override {
-    FailureOr<ExtensionOp> ext =
-        ExtensionOp::from(op.getVector().getDefiningOp());
-    if (failed(ext))
-      return failure();
-
-    Value newExtract = rewriter.create<vector::ExtractElementOp>(
-        op.getLoc(), ext->getIn(), op.getPosition());
-    ext->recreateAndReplace(rewriter, op, newExtract);
-    return success();
-  }
-};
-
-struct ExtensionOverExtractStridedSlice final
-    : NarrowingPattern<vector::ExtractStridedSliceOp> {
-  using NarrowingPattern::NarrowingPattern;
-
-  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
-                                PatternRewriter &rewriter) const override {
-    FailureOr<ExtensionOp> ext =
-        ExtensionOp::from(op.getVector().getDefiningOp());
-    if (failed(ext))
-      return failure();
-
-    VectorType origTy = op.getType();
-    VectorType extractTy =
-        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
-    Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
-        op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
-        op.getStrides());
-    ext->recreateAndReplace(rewriter, op, newExtract);
-    return success();
-  }
-};
-
-/// Base pattern for `vector.insert` narrowing patterns.
-template <typename InsertionOp>
-struct ExtensionOverInsertionPattern : NarrowingPattern<InsertionOp> {
-  using NarrowingPattern<InsertionOp>::NarrowingPattern;
-
-  /// Derived classes must provide a function to create the matching insertion
-  /// op based on the original op and new arguments.
-  virtual InsertionOp createInsertionOp(PatternRewriter &rewriter,
-                                        InsertionOp origInsert,
-                                        Value narrowValue,
-                                        Value narrowDest) const = 0;
-
-  LogicalResult matchAndRewrite(InsertionOp op,
-                                PatternRewriter &rewriter) const final {
-    FailureOr<ExtensionOp> ext =
-        ExtensionOp::from(op.getSource().getDefiningOp());
-    if (failed(ext))
-      return failure();
-
-    FailureOr<InsertionOp> newInsert = createNarrowInsert(op, rewriter, *ext);
-    if (failed(newInsert))
-      return failure();
-    ext->recreateAndReplace(rewriter, op, *newInsert);
-    return success();
-  }
-
-  FailureOr<InsertionOp> createNarrowInsert(InsertionOp op,
-                                            PatternRewriter &rewriter,
-                                            ExtensionOp insValue) const {
-    // Calculate the operand and result bitwidths. We can only apply narrowing
-    // when the inserted source value and destination vector require fewer bits
-    // than the result. Because the source and destination may have different
-    // bitwidth requirements, we have to find a common narrow bitwidth that is
-    // greater than or equal to the operand bitwidth requirements and still
-    // narrower than the result.
-    FailureOr<unsigned> origBitsRequired = calculateBitsRequired(op.getType());
-    if (failed(origBitsRequired))
-      return failure();
-
-    // TODO: We could relax this check by disregarding bitwidth requirements of
-    // elements that we know will be replaced by the insertion.
-    FailureOr<unsigned> destBitsRequired =
-        calculateBitsRequired(op.getDest(), insValue.getKind());
-    if (failed(destBitsRequired) || *destBitsRequired >= *origBitsRequired)
-      return failure();
-
-    FailureOr<unsigned> insertedBitsRequired =
-        calculateBitsRequired(insValue.getIn(), insValue.getKind());
-    if (failed(insertedBitsRequired) ||
-        *insertedBitsRequired >= *origBitsRequired)
-      return failure();
-
-    // Find a narrower element type that satisfies the bitwidth requirements of
-    // both the source and the destination values.
-    unsigned newInsertionBits =
-        std::max(*destBitsRequired, *insertedBitsRequired);
-    FailureOr<Type> newVecTy =
-        this->getNarrowType(newInsertionBits, op.getType());
-    if (failed(newVecTy) || *newVecTy == op.getType())
-      return failure();
-
-    FailureOr<Type> newInsertedValueTy =
-        this->getNarrowType(newInsertionBits, insValue.getType());
-    if (failed(newInsertedValueTy))
-      return failure();
-
-    Location loc = op.getLoc();
-    Value narrowValue = rewriter.createOrFold<arith::TruncIOp>(
-        loc, *newInsertedValueTy, insValue.getResult());
-    Value narrowDest =
-        rewriter.createOrFold<arith::TruncIOp>(loc, *newVecTy, op.getDest());
-    return createInsertionOp(rewriter, op, narrowValue, narrowDest);
-  }
-};
-
-struct ExtensionOverInsert final
-    : ExtensionOverInsertionPattern<vector::InsertOp> {
-  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
-
-  vector::InsertOp createInsertionOp(PatternRewriter &rewriter,
-                                     vector::InsertOp origInsert,
-                                     Value narrowValue,
-                                     Value narrowDest) const override {
-    return rewriter.create<vector::InsertOp>(origInsert.getLoc(), narrowValue,
-                                             narrowDest,
-                                             origInsert.getMixedPosition());
-  }
-};
-
-struct ExtensionOverInsertElement final
-    : ExtensionOverInsertionPattern<vector::InsertElementOp> {
-  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
-
-  vector::InsertElementOp createInsertionOp(PatternRewriter &rewriter,
-                                            vector::InsertElementOp origInsert,
-                                            Value narrowValue,
-                                            Value narrowDest) const override {
-    return rewriter.create<vector::InsertElementOp>(
-        origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition());
-  }
-};
-
-struct ExtensionOverInsertStridedSlice final
-    : ExtensionOverInsertionPattern<vector::InsertStridedSliceOp> {
-  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
-
-  vector::InsertStridedSliceOp
-  createInsertionOp(PatternRewriter &rewriter,
-                    vector::InsertStridedSliceOp origInsert, Value narrowValue,
-                    Value narrowDest) const override {
-    return rewriter.create<vector::InsertStridedSliceOp>(
-        origInsert.getLoc(), narrowValue, narrowDest, origInsert.getOffsets(),
-        origInsert.getStrides());
-  }
-};
-
-struct ExtensionOverShapeCast final : NarrowingPattern<vector::ShapeCastOp> {
-  using NarrowingPattern::NarrowingPattern;
-
-  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
-                                PatternRewriter &rewriter) const override {
-    FailureOr<ExtensionOp> ext =
-        ExtensionOp::from(op.getSource().getDefiningOp());
-    if (failed(ext))
-      return failure();
-
-    VectorType origTy = op.getResultVectorType();
-    VectorType newTy =
-        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
-    Value newCast =
-        rewriter.create<vector::ShapeCastOp>(op.getLoc(), newTy, ext->getIn());
-    ext->recreateAndReplace(rewriter, op, newCast);
-    return success();
-  }
-};
-
-struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
-  using NarrowingPattern::NarrowingPattern;
-
-  LogicalResult matchAndRewrite(vector::TransposeOp op,
-                                PatternRewriter &rewriter) const override {
-    FailureOr<ExtensionOp> ext =
-        ExtensionOp::from(op.getVector().getDefiningOp());
-    if (failed(ext))
-      return failure();
-
-    VectorType origTy = op.getResultVectorType();
-    VectorType newTy =
-        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
-    Value newTranspose = rewriter.create<vector::TransposeOp>(
-        op.getLoc(), newTy, ext->getIn(), op.getPermutation());
-    ext->recreateAndReplace(rewriter, op, newTranspose);
-    return success();
-  }
-};
-
-struct ExtensionOverFlatTranspose final
-    : NarrowingPattern<vector::FlatTransposeOp> {
-  using NarrowingPattern::NarrowingPattern;
-
-  LogicalResult matchAndRewrite(vector::FlatTransposeOp op,
-                                PatternRewriter &rewriter) const override {
-    FailureOr<ExtensionOp> ext =
-        ExtensionOp::from(op.getMatrix().getDefiningOp());
-    if (failed(ext))
-      return failure();
-
-    VectorType origTy = op.getType();
-    VectorType newTy =
-        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
-    Value newTranspose = rewriter.create<vector::FlatTransposeOp>(
-        op.getLoc(), newTy, ext->getIn(), op.getRowsAttr(),
-        op.getColumnsAttr());
-    ext->recreateAndReplace(rewriter, op, newTranspose);
-    return success();
-  }
-};
-
-//===----------------------------------------------------------------------===//
-// Pass Definitions
-//===----------------------------------------------------------------------===//
-
-struct ArithIntNarrowingPass final
-    : impl::ArithIntNarrowingBase<ArithIntNarrowingPass> {
-  using ArithIntNarrowingBase::ArithIntNarrowingBase;
-
-  void runOnOperation() override {
-    if (bitwidthsSupported.empty() ||
-        llvm::is_contained(bitwidthsSupported, 0)) {
-      // Invalid pass options.
-      return signalPassFailure();
-    }
-
-    Operation *op = getOperation();
-    MLIRContext *ctx = op->getContext();
-    RewritePatternSet patterns(ctx);
-    populateArithIntNarrowingPatterns(
-        patterns, ArithIntNarrowingOptions{
-                      llvm::to_vector_of<unsigned>(bitwidthsSupported)});
-    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
-      signalPassFailure();
-  }
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// Public API
-//===----------------------------------------------------------------------===//
-
-void populateArithIntNarrowingPatterns(
-    RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) {
-  // Add commute patterns with a higher benefit. This is to expose more
-  // optimization opportunities to narrowing patterns.
-  patterns.add<ExtensionOverBroadcast, ExtensionOverExtract,
-               ExtensionOverExtractElement, ExtensionOverExtractStridedSlice,
-               ExtensionOverInsert, ExtensionOverInsertElement,
-               ExtensionOverInsertStridedSlice, ExtensionOverShapeCast,
-               ExtensionOverTranspose, ExtensionOverFlatTranspose>(
-      patterns.getContext(), options, PatternBenefit(2));
-
-  patterns.add<AddIPattern, SubIPattern, MulIPattern, DivSIPattern,
-               DivUIPattern, MaxSIPattern, MaxUIPattern, MinSIPattern,
-               MinUIPattern, SIToFPPattern, UIToFPPattern, IndexCastSIPattern,
-               IndexCastUIPattern>(patterns.getContext(), options);
-}
-
-} // namespace mlir::arith
diff --git a/mlir/test/Dialect/Arith/int-narrowing-invalid-options.mlir b/mlir/test/Dialect/Arith/int-narrowing-invalid-options.mlir
deleted file mode 100644
index 0e34108973b4c9..00000000000000
--- a/mlir/test/Dialect/Arith/int-narrowing-invalid-options.mlir
+++ /dev/null
@@ -1,16 +0,0 @@
-// RUN: not mlir-opt %s --arith-int-narrowing --mlir-print-ir-after-failure 2>&1 \
-// RUN:   | FileCheck %s
-
-// RUN: not mlir-opt %s --arith-int-narrowing="int-bitwidths-supported=0" \
-// RUN:   --mlir-print-ir-after-failure 2>&1 | FileCheck %s
-
-// Make sure we do not crash on invalid pass options.
-
-// CHECK:       IR Dump After ArithIntNarrowing Failed (arith-int-narrowing)
-// CHECK-LABEL: func.func @addi_extsi_i8
-func.func @addi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extsi %lhs : i8 to i32
-  %b = arith.extsi %rhs : i8 to i32
-  %r = arith.addi %a, %b : i32
-  return %r : i32
-}
diff --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir
deleted file mode 100644
index 153c0a85762628..00000000000000
--- a/mlir/test/Dialect/Arith/int-narrowing.mlir
+++ /dev/null
@@ -1,997 +0,0 @@
-// RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,24,32" \
-// RUN:          --verify-diagnostics %s | FileCheck %s
-
-//===----------------------------------------------------------------------===//
-// arith.addi
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func.func @addi_extsi_i8
-// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
-// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i16
-// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i16
-// CHECK-NEXT:    %[[ADD:.+]]  = arith.addi %[[LHS]], %[[RHS]] : i16
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[ADD]] : i16 to i32
-// CHECK-NEXT:    return %[[RET]] : i32
-func.func @addi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extsi %lhs : i8 to i32
-  %b = arith.extsi %rhs : i8 to i32
-  %r = arith.addi %a, %b : i32
-  return %r : i32
-}
-
-// CHECK-LABEL: func.func @addi_extui_i8
-// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT:    %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
-// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i16
-// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i16
-// CHECK-NEXT:    %[[ADD:.+]]  = arith.addi %[[LHS]], %[[RHS]] : i16
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[ADD]] : i16 to i32
-// CHECK-NEXT:    return %[[RET]] : i32
-func.func @addi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extui %lhs : i8 to i32
-  %b = arith.extui %rhs : i8 to i32
-  %r = arith.addi %a, %b : i32
-  return %r : i32
-}
-
-// arith.addi produces one more bit of result than the operand bitwidth.
-//
-// CHECK-LABEL: func.func @addi_extsi_i24
-// CHECK-SAME:    (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16)
-// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i16 to i32
-// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i16 to i32
-// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i24
-// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i24
-// CHECK-NEXT:    %[[ADD:.+]]  = arith.addi %[[LHS]], %[[RHS]] : i24
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[ADD]] : i24 to i32
-// CHECK-NEXT:    return %[[RET]] : i32
-func.func @addi_extsi_i24(%lhs: i16, %rhs: i16) -> i32 {
-  %a = arith.extsi %lhs : i16 to i32
-  %b = arith.extsi %rhs : i16 to i32
-  %r = arith.addi %a, %b : i32
-  return %r : i32
-}
-
-// This case should not get optimized because of mixed extensions.
-//
-// CHECK-LABEL: func.func @addi_mixed_ext_i8
-// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
-// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[ADD:.+]]  = arith.addi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT:    return %[[ADD]] : i32
-func.func @addi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extsi %lhs : i8 to i32
-  %b = arith.extui %rhs : i8 to i32
-  %r = arith.addi %a, %b : i32
-  return %r : i32
-}
-
-// This case should not get optimized because we cannot reduce the bitwidth
-// below i16, given the pass options set.
-//
-// CHECK-LABEL: func.func @addi_extsi_i16
-// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i16
-// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i16
-// CHECK-NEXT:    %[[ADD:.+]]  = arith.addi %[[EXT0]], %[[EXT1]] : i16
-// CHECK-NEXT:    return %[[ADD]] : i16
-func.func @addi_extsi_i16(%lhs: i8, %rhs: i8) -> i16 {
-  %a = arith.extsi %lhs : i8 to i16
-  %b = arith.extsi %rhs : i8 to i16
-  %r = arith.addi %a, %b : i16
-  return %r : i16
-}
-
-// CHECK-LABEL: func.func @addi_extsi_3xi8_cst
-// CHECK-SAME:    (%[[ARG0:.+]]: vector<3xi8>)
-// CHECK-NEXT:    %[[CST:.+]]  = arith.constant dense<[-1, 127, 42]> : vector<3xi16>
-// CHECK-NEXT:    %[[EXT:.+]]  = arith.extsi %[[ARG0]] : vector<3xi8> to vector<3xi32>
-// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT]] : vector<3xi32> to vector<3xi16>
-// CHECK-NEXT:    %[[ADD:.+]]  = arith.addi %[[LHS]], %[[CST]] : vector<3xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[ADD]] : vector<3xi16> to vector<3xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
-func.func @addi_extsi_3xi8_cst(%lhs: vector<3xi8>) -> vector<3xi32> {
-  %cst = arith.constant dense<[-1, 127, 42]> : vector<3xi32>
-  %a = arith.extsi %lhs : vector<3xi8> to vector<3xi32>
-  %r = arith.addi %a, %cst : vector<3xi32>
-  return %r : vector<3xi32>
-}
-
-//===----------------------------------------------------------------------===//
-// arith.subi
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func.func @subi_extsi_i8
-// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
-// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i16
-// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i16
-// CHECK-NEXT:    %[[SUB:.+]]  = arith.subi %[[LHS]], %[[RHS]] : i16
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[SUB]] : i16 to i32
-// CHECK-NEXT:    return %[[RET]] : i32
-func.func @subi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extsi %lhs : i8 to i32
-  %b = arith.extsi %rhs : i8 to i32
-  %r = arith.subi %a, %b : i32
-  return %r : i32
-}
-
-// This pattern should only apply to `arith.subi` ops with sign-extended
-// arguments.
-//
-// CHECK-LABEL: func.func @subi_extui_i8
-// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT:    %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
-// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[SUB:.+]]  = arith.subi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT:    return %[[SUB]] : i32
-func.func @subi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extui %lhs : i8 to i32
-  %b = arith.extui %rhs : i8 to i32
-  %r = arith.subi %a, %b : i32
-  return %r : i32
-}
-
-// This case should not get optimized because of mixed extensions.
-//
-// CHECK-LABEL: func.func @subi_mixed_ext_i8
-// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
-// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[ADD:.+]]  = arith.subi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT:    return %[[ADD]] : i32
-func.func @subi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extsi %lhs : i8 to i32
-  %b = arith.extui %rhs : i8 to i32
-  %r = arith.subi %a, %b : i32
-  return %r : i32
-}
-
-// arith.subi produces one more bit of result than the operand bitwidth.
-//
-// CHECK-LABEL: func.func @subi_extsi_i24
-// CHECK-SAME:    (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16)
-// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i16 to i32
-// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i16 to i32
-// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i24
-// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i24
-// CHECK-NEXT:    %[[ADD:.+]]  = arith.subi %[[LHS]], %[[RHS]] : i24
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[ADD]] : i24 to i32
-// CHECK-NEXT:    return %[[RET]] : i32
-func.func @subi_extsi_i24(%lhs: i16, %rhs: i16) -> i32 {
-  %a = arith.extsi %lhs : i16 to i32
-  %b = arith.extsi %rhs : i16 to i32
-  %r = arith.subi %a, %b : i32
-  return %r : i32
-}
-
-//===----------------------------------------------------------------------===//
-// arith.muli
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func.func @muli_extsi_i8
-// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
-// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i16
-// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i16
-// CHECK-NEXT:    %[[MUL:.+]]  = arith.muli %[[LHS]], %[[RHS]] : i16
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[MUL]] : i16 to i32
-// CHECK-NEXT:    return %[[RET]] : i32
-func.func @muli_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extsi %lhs : i8 to i32
-  %b = arith.extsi %rhs : i8 to i32
-  %r = arith.muli %a, %b : i32
-  return %r : i32
-}
-
-// CHECK-LABEL: func.func @muli_extui_i8
-// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT:    %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
-// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i16
-// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i16
-// CHECK-NEXT:    %[[MUL:.+]]  = arith.muli %[[LHS]], %[[RHS]] : i16
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[MUL]] : i16 to i32
-// CHECK-NEXT:    return %[[RET]] : i32
-func.func @muli_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extui %lhs : i8 to i32
-  %b = arith.extui %rhs : i8 to i32
-  %r = arith.muli %a, %b : i32
-  return %r : i32
-}
-
-// We do not expect this case to be optimized because given n-bit operands,
-// arith.muli produces 2n bits of result.
-//
-// CHECK-LABEL: func.func @muli_extsi_i32
-// CHECK-SAME:    (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16)
-// CHECK-NEXT:    %[[LHS:.+]]  = arith.extsi %[[ARG0]] : i16 to i32
-// CHECK-NEXT:    %[[RHS:.+]]  = arith.extsi %[[ARG1]] : i16 to i32
-// CHECK-NEXT:    %[[RET:.+]]  = arith.muli %[[LHS]], %[[RHS]] : i32
-// CHECK-NEXT:    return %[[RET]] : i32
-func.func @muli_extsi_i32(%lhs: i16, %rhs: i16) -> i32 {
-  %a = arith.extsi %lhs : i16 to i32
-  %b = arith.extsi %rhs : i16 to i32
-  %r = arith.muli %a, %b : i32
-  return %r : i32
-}
-
-// This case should not get optimized because of mixed extensions.
-//
-// CHECK-LABEL: func.func @muli_mixed_ext_i8
-// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
-// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[MUL:.+]]  = arith.muli %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT:    return %[[MUL]] : i32
-func.func @muli_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extsi %lhs : i8 to i32
-  %b = arith.extui %rhs : i8 to i32
-  %r = arith.muli %a, %b : i32
-  return %r : i32
-}
-
-// CHECK-LABEL: func.func @muli_extsi_3xi8_cst
-// CHECK-SAME:    (%[[ARG0:.+]]: vector<3xi8>)
-// CHECK-NEXT:    %[[CST:.+]]  = arith.constant dense<[-1, 127, 42]> : vector<3xi16>
-// CHECK-NEXT:    %[[EXT:.+]]  = arith.extsi %[[ARG0]] : vector<3xi8> to vector<3xi32>
-// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT]] : vector<3xi32> to vector<3xi16>
-// CHECK-NEXT:    %[[MUL:.+]]  = arith.muli %[[LHS]], %[[CST]] : vector<3xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[MUL]] : vector<3xi16> to vector<3xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
-func.func @muli_extsi_3xi8_cst(%lhs: vector<3xi8>) -> vector<3xi32> {
-  %cst = arith.constant dense<[-1, 127, 42]> : vector<3xi32>
-  %a = arith.extsi %lhs : vector<3xi8> to vector<3xi32>
-  %r = arith.muli %a, %cst : vector<3xi32>
-  return %r : vector<3xi32>
-}
-
-//===----------------------------------------------------------------------===//
-// arith.divsi
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func.func @divsi_extsi_i8
-// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
-// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i16
-// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i16
-// CHECK-NEXT:    %[[SUB:.+]]  = arith.divsi %[[LHS]], %[[RHS]] : i16
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[SUB]] : i16 to i32
-// CHECK-NEXT:    return %[[RET]] : i32
-func.func @divsi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extsi %lhs : i8 to i32
-  %b = arith.extsi %rhs : i8 to i32
-  %r = arith.divsi %a, %b : i32
-  return %r : i32
-}
-
-// This pattern should only apply to `arith.divsi` ops with sign-extended
-// arguments.
-//
-// CHECK-LABEL: func.func @divsi_extui_i8
-// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT:    %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
-// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[SUB:.+]]  = arith.divsi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT:    return %[[SUB]] : i32
-func.func @divsi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extui %lhs : i8 to i32
-  %b = arith.extui %rhs : i8 to i32
-  %r = arith.divsi %a, %b : i32
-  return %r : i32
-}
-
-// arith.divsi produces one more bit of result than the operand bitwidth.
-//
-// CHECK-LABEL: func.func @divsi_extsi_i24
-// CHECK-SAME:    (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16)
-// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i16 to i32
-// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i16 to i32
-// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i24
-// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i24
-// CHECK-NEXT:    %[[ADD:.+]]  = arith.divsi %[[LHS]], %[[RHS]] : i24
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[ADD]] : i24 to i32
-// CHECK-NEXT:    return %[[RET]] : i32
-func.func @divsi_extsi_i24(%lhs: i16, %rhs: i16) -> i32 {
-  %a = arith.extsi %lhs : i16 to i32
-  %b = arith.extsi %rhs : i16 to i32
-  %r = arith.divsi %a, %b : i32
-  return %r : i32
-}
-
-//===----------------------------------------------------------------------===//
-// arith.divui
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func.func @divui_extui_i8
-// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT:    %[[SUB:.+]]  = arith.divui %[[ARG0]], %[[ARG1]] : i8
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[SUB]] : i8 to i32
-// CHECK-NEXT:    return %[[RET]] : i32
-func.func @divui_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extui %lhs : i8 to i32
-  %b = arith.extui %rhs : i8 to i32
-  %r = arith.divui %a, %b : i32
-  return %r : i32
-}
-
-// This pattern should only apply to `arith.divui` ops with zero-extended
-// arguments.
-//
-// CHECK-LABEL: func.func @divui_extsi_i8
-// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
-// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[SUB:.+]]  = arith.divui %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT:    return %[[SUB]] : i32
-func.func @divui_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extsi %lhs : i8 to i32
-  %b = arith.extsi %rhs : i8 to i32
-  %r = arith.divui %a, %b : i32
-  return %r : i32
-}
-
-//===----------------------------------------------------------------------===//
-// arith.*itofp
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func.func @sitofp_extsi_i16
-// CHECK-SAME:    (%[[ARG:.+]]: i16)
-// CHECK-NEXT:    %[[RET:.+]] = arith.sitofp %[[ARG]] : i16 to f16
-// CHECK-NEXT:    return %[[RET]] : f16
-func.func @sitofp_extsi_i16(%a: i16) -> f16 {
-  %b = arith.extsi %a : i16 to i32
-  %f = arith.sitofp %b : i32 to f16
-  return %f : f16
-}
-
-// CHECK-LABEL: func.func @sitofp_extsi_vector_i16
-// CHECK-SAME:    (%[[ARG:.+]]: vector<3xi16>)
-// CHECK-NEXT:    %[[RET:.+]] = arith.sitofp %[[ARG]] : vector<3xi16> to vector<3xf16>
-// CHECK-NEXT:    return %[[RET]] : vector<3xf16>
-func.func @sitofp_extsi_vector_i16(%a: vector<3xi16>) -> vector<3xf16> {
-  %b = arith.extsi %a : vector<3xi16> to vector<3xi32>
-  %f = arith.sitofp %b : vector<3xi32> to vector<3xf16>
-  return %f : vector<3xf16>
-}
-
-// CHECK-LABEL: func.func @sitofp_extsi_tensor_i16
-// CHECK-SAME:    (%[[ARG:.+]]: tensor<3x?xi16>)
-// CHECK-NEXT:    %[[RET:.+]] = arith.sitofp %[[ARG]] : tensor<3x?xi16> to tensor<3x?xf16>
-// CHECK-NEXT:    return %[[RET]] : tensor<3x?xf16>
-func.func @sitofp_extsi_tensor_i16(%a: tensor<3x?xi16>) -> tensor<3x?xf16> {
-  %b = arith.extsi %a : tensor<3x?xi16> to tensor<3x?xi32>
-  %f = arith.sitofp %b : tensor<3x?xi32> to tensor<3x?xf16>
-  return %f : tensor<3x?xf16>
-}
-
-// Narrowing to i64 is not enabled in pass options.
-//
-// CHECK-LABEL: func.func @sitofp_extsi_i64
-// CHECK-SAME:    (%[[ARG:.+]]: i64)
-// CHECK-NEXT:    %[[EXT:.+]] = arith.extsi %[[ARG]] : i64 to i128
-// CHECK-NEXT:    %[[RET:.+]] = arith.sitofp %[[EXT]] : i128 to f32
-// CHECK-NEXT:    return %[[RET]] : f32
-func.func @sitofp_extsi_i64(%a: i64) -> f32 {
-  %b = arith.extsi %a : i64 to i128
-  %f = arith.sitofp %b : i128 to f32
-  return %f : f32
-}
-
-// CHECK-LABEL: func.func @uitofp_extui_i16
-// CHECK-SAME:    (%[[ARG:.+]]: i16)
-// CHECK-NEXT:    %[[RET:.+]] = arith.uitofp %[[ARG]] : i16 to f16
-// CHECK-NEXT:    return %[[RET]] : f16
-func.func @uitofp_extui_i16(%a: i16) -> f16 {
-  %b = arith.extui %a : i16 to i32
-  %f = arith.uitofp %b : i32 to f16
-  return %f : f16
-}
-
-// CHECK-LABEL: func.func @sitofp_extsi_extsi_i8
-// CHECK-SAME:    (%[[ARG:.+]]: i8)
-// CHECK-NEXT:    %[[RET:.+]] = arith.sitofp %[[ARG]] : i8 to f16
-// CHECK-NEXT:    return %[[RET]] : f16
-func.func @sitofp_extsi_extsi_i8(%a: i8) -> f16 {
-  %b = arith.extsi %a : i8 to i16
-  %c = arith.extsi %b : i16 to i32
-  %f = arith.sitofp %c : i32 to f16
-  return %f : f16
-}
-
-// CHECK-LABEL: func.func @uitofp_extui_extui_i8
-// CHECK-SAME:    (%[[ARG:.+]]: i8)
-// CHECK-NEXT:    %[[RET:.+]] = arith.uitofp %[[ARG]] : i8 to f16
-// CHECK-NEXT:    return %[[RET]] : f16
-func.func @uitofp_extui_extui_i8(%a: i8) -> f16 {
-  %b = arith.extui %a : i8 to i16
-  %c = arith.extui %b : i16 to i32
-  %f = arith.uitofp %c : i32 to f16
-  return %f : f16
-}
-
-// CHECK-LABEL: func.func @uitofp_extsi_extui_i8
-// CHECK-SAME:    (%[[ARG:.+]]: i8)
-// CHECK-NEXT:    %[[EXT:.+]] = arith.extsi %[[ARG]] : i8 to i16
-// CHECK-NEXT:    %[[RET:.+]] = arith.uitofp %[[EXT]] : i16 to f16
-// CHECK-NEXT:    return %[[RET]] : f16
-func.func @uitofp_extsi_extui_i8(%a: i8) -> f16 {
-  %b = arith.extsi %a : i8 to i16
-  %c = arith.extui %b : i16 to i32
-  %f = arith.uitofp %c : i32 to f16
-  return %f : f16
-}
-
-// CHECK-LABEL: func.func @uitofp_trunci_extui_i8
-// CHECK-SAME:    (%[[ARG:.+]]: i16)
-// CHECK-NEXT:    %[[TR:.+]]  = arith.trunci %[[ARG]] : i16 to i8
-// CHECK-NEXT:    %[[RET:.+]] = arith.uitofp %[[TR]] : i8 to f16
-// CHECK-NEXT:    return %[[RET]] : f16
-func.func @uitofp_trunci_extui_i8(%a: i16) -> f16 {
-  %b = arith.trunci %a : i16 to i8
-  %c = arith.extui %b : i8 to i32
-  %f = arith.uitofp %c : i32 to f16
-  return %f : f16
-}
-
-// This should not be folded because arith.extui changes the signed
-// range of the number. For example:
-//  extsi -1 : i16 to i32 ==> -1
-//  extui -1 : i16 to i32 ==> U16_MAX
-//
-// CHECK-LABEL: func.func @sitofp_extui_i16
-// CHECK-SAME:    (%[[ARG:.+]]: i16)
-// CHECK-NEXT:    %[[EXT:.+]] = arith.extui %[[ARG]] : i16 to i32
-// CHECK-NEXT:    %[[RET:.+]] = arith.sitofp %[[EXT]] : i32 to f16
-// CHECK-NEXT:    return %[[RET]] : f16
-func.func @sitofp_extui_i16(%a: i16) -> f16 {
-  %b = arith.extui %a : i16 to i32
-  %f = arith.sitofp %b : i32 to f16
-  return %f : f16
-}
-
-// This should not be folded because arith.extsi changes the unsigned
-// range of the number. For example:
-//  extsi -1 : i16 to i32 ==> U32_MAX
-//  extui -1 : i16 to i32 ==> U16_MAX
-//
-// CHECK-LABEL: func.func @uitofp_extsi_i16
-// CHECK-SAME:    (%[[ARG:.+]]: i16)
-// CHECK-NEXT:    %[[EXT:.+]] = arith.extsi %[[ARG]] : i16 to i32
-// CHECK-NEXT:    %[[RET:.+]] = arith.uitofp %[[EXT]] : i32 to f16
-// CHECK-NEXT:    return %[[RET]] : f16
-func.func @uitofp_extsi_i16(%a: i16) -> f16 {
-  %b = arith.extsi %a : i16 to i32
-  %f = arith.uitofp %b : i32 to f16
-  return %f : f16
-}
-
-//===----------------------------------------------------------------------===//
-// arith.maxsi
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func.func @maxsi_extsi_i8
-// CHECK-SAME:    (%[[LHS:.+]]: i8, %[[RHS:.+]]: i8)
-// CHECK-NEXT:    %[[MAX:.+]]  = arith.maxsi %[[LHS]], %[[RHS]] : i8
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[MAX]] : i8 to i32
-// CHECK-NEXT:    return %[[RET]] : i32
-func.func @maxsi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extsi %lhs : i8 to i32
-  %b = arith.extsi %rhs : i8 to i32
-  %r = arith.maxsi %a, %b : i32
-  return %r : i32
-}
-
-// This pattern should only apply to `arith.maxsi` ops with sign-extended
-// arguments.
-//
-// CHECK-LABEL: func.func @maxsi_extui_i8
-// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT:    %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
-// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[MAX:.+]]  = arith.maxsi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT:    return %[[MAX]] : i32
-func.func @maxsi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extui %lhs : i8 to i32
-  %b = arith.extui %rhs : i8 to i32
-  %r = arith.maxsi %a, %b : i32
-  return %r : i32
-}
-
-//===----------------------------------------------------------------------===//
-// arith.maxui
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func.func @maxui_extui_i8
-// CHECK-SAME:    (%[[LHS:.+]]: i8, %[[RHS:.+]]: i8)
-// CHECK-NEXT:    %[[MAX:.+]]  = arith.maxui %[[LHS]], %[[RHS]] : i8
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[MAX]] : i8 to i32
-// CHECK-NEXT:    return %[[RET]] : i32
-func.func @maxui_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extui %lhs : i8 to i32
-  %b = arith.extui %rhs : i8 to i32
-  %r = arith.maxui %a, %b : i32
-  return %r : i32
-}
-
-// This pattern should only apply to `arith.maxui` ops with zero-extended
-// arguments.
-//
-// CHECK-LABEL: func.func @maxui_extsi_i8
-// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
-// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[MAX:.+]]  = arith.maxui %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT:    return %[[MAX]] : i32
-func.func @maxui_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extsi %lhs : i8 to i32
-  %b = arith.extsi %rhs : i8 to i32
-  %r = arith.maxui %a, %b : i32
-  return %r : i32
-}
-
-//===----------------------------------------------------------------------===//
-// arith.minsi
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func.func @minsi_extsi_i8
-// CHECK-SAME:    (%[[LHS:.+]]: i8, %[[RHS:.+]]: i8)
-// CHECK-NEXT:    %[[min:.+]]  = arith.minsi %[[LHS]], %[[RHS]] : i8
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[min]] : i8 to i32
-// CHECK-NEXT:    return %[[RET]] : i32
-func.func @minsi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extsi %lhs : i8 to i32
-  %b = arith.extsi %rhs : i8 to i32
-  %r = arith.minsi %a, %b : i32
-  return %r : i32
-}
-
-// This pattern should only apply to `arith.minsi` ops with sign-extended
-// arguments.
-//
-// CHECK-LABEL: func.func @minsi_extui_i8
-// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT:    %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
-// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[min:.+]]  = arith.minsi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT:    return %[[min]] : i32
-func.func @minsi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extui %lhs : i8 to i32
-  %b = arith.extui %rhs : i8 to i32
-  %r = arith.minsi %a, %b : i32
-  return %r : i32
-}
-
-//===----------------------------------------------------------------------===//
-// arith.minui
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func.func @minui_extui_i8
-// CHECK-SAME:    (%[[LHS:.+]]: i8, %[[RHS:.+]]: i8)
-// CHECK-NEXT:    %[[min:.+]]  = arith.minui %[[LHS]], %[[RHS]] : i8
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[min]] : i8 to i32
-// CHECK-NEXT:    return %[[RET]] : i32
-func.func @minui_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extui %lhs : i8 to i32
-  %b = arith.extui %rhs : i8 to i32
-  %r = arith.minui %a, %b : i32
-  return %r : i32
-}
-
-// This pattern should only apply to `arith.minui` ops with zero-extended
-// arguments.
-//
-// CHECK-LABEL: func.func @minui_extsi_i8
-// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
-// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
-// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
-// CHECK-NEXT:    %[[min:.+]]  = arith.minui %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT:    return %[[min]] : i32
-func.func @minui_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
-  %a = arith.extsi %lhs : i8 to i32
-  %b = arith.extsi %rhs : i8 to i32
-  %r = arith.minui %a, %b : i32
-  return %r : i32
-}
-
-//===----------------------------------------------------------------------===//
-// Commute Extension over Vector Ops
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func.func @extsi_over_extract_3xi16
-// CHECK-SAME:    (%[[ARG:.+]]: vector<3xi16>)
-// CHECK-NEXT:    %[[EXTR:.+]] = vector.extract %[[ARG]][1] : i16 from vector<3xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.sitofp %[[EXTR]] : i16 to f16
-// CHECK-NEXT:    return %[[RET]] : f16
-func.func @extsi_over_extract_3xi16(%a: vector<3xi16>) -> f16 {
-  %b = arith.extsi %a : vector<3xi16> to vector<3xi32>
-  %c = vector.extract %b[1] : i32 from vector<3xi32>
-  %f = arith.sitofp %c : i32 to f16
-  return %f : f16
-}
-
-// CHECK-LABEL: func.func @extui_over_extract_3xi16
-// CHECK-SAME:    (%[[ARG:.+]]: vector<3xi16>)
-// CHECK-NEXT:    %[[EXTR:.+]] = vector.extract %[[ARG]][1] : i16 from vector<3xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.uitofp %[[EXTR]] : i16 to f16
-// CHECK-NEXT:    return %[[RET]] : f16
-func.func @extui_over_extract_3xi16(%a: vector<3xi16>) -> f16 {
-  %b = arith.extui %a : vector<3xi16> to vector<3xi32>
-  %c = vector.extract %b[1] : i32 from vector<3xi32>
-  %f = arith.uitofp %c : i32 to f16
-  return %f : f16
-}
-
-// CHECK-LABEL: func.func @extsi_over_extractelement_3xi16
-// CHECK-SAME:    (%[[ARG:.+]]: vector<3xi16>, %[[POS:.+]]: i32)
-// CHECK-NEXT:    %[[EXTR:.+]] = vector.extractelement %[[ARG]][%[[POS]] : i32] : vector<3xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.sitofp %[[EXTR]] : i16 to f16
-// CHECK-NEXT:    return %[[RET]] : f16
-func.func @extsi_over_extractelement_3xi16(%a: vector<3xi16>, %pos: i32) -> f16 {
-  %b = arith.extsi %a : vector<3xi16> to vector<3xi32>
-  %c = vector.extractelement %b[%pos : i32] : vector<3xi32>
-  %f = arith.sitofp %c : i32 to f16
-  return %f : f16
-}
-
-// CHECK-LABEL: func.func @extui_over_extractelement_3xi16
-// CHECK-SAME:    (%[[ARG:.+]]: vector<3xi16>, %[[POS:.+]]: i32)
-// CHECK-NEXT:    %[[EXTR:.+]] = vector.extractelement %[[ARG]][%[[POS]] : i32] : vector<3xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.uitofp %[[EXTR]] : i16 to f16
-// CHECK-NEXT:    return %[[RET]] : f16
-func.func @extui_over_extractelement_3xi16(%a: vector<3xi16>, %pos: i32) -> f16 {
-  %b = arith.extui %a : vector<3xi16> to vector<3xi32>
-  %c = vector.extractelement %b[%pos : i32] : vector<3xi32>
-  %f = arith.uitofp %c : i32 to f16
-  return %f : f16
-}
-
-// CHECK-LABEL: func.func @extsi_over_extract_strided_slice_1d
-// CHECK-SAME:    (%[[ARG:.+]]: vector<3xi16>)
-// CHECK-NEXT:    %[[EXTR:.+]] = vector.extract_strided_slice %[[ARG]] {offsets = [1], sizes = [2], strides = [1]} : vector<3xi16> to vector<2xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[EXTR]] : vector<2xi16> to vector<2xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<2xi32>
-func.func @extsi_over_extract_strided_slice_1d(%a: vector<3xi16>) -> vector<2xi32> {
-  %b = arith.extsi %a : vector<3xi16> to vector<3xi32>
-  %c = vector.extract_strided_slice %b
-   {offsets = [1], sizes = [2], strides = [1]} : vector<3xi32> to vector<2xi32>
-  return %c : vector<2xi32>
-}
-
-// CHECK-LABEL: func.func @extui_over_extract_strided_slice_1d
-// CHECK-SAME:    (%[[ARG:.+]]: vector<3xi16>)
-// CHECK-NEXT:    %[[EXTR:.+]] = vector.extract_strided_slice %[[ARG]] {offsets = [1], sizes = [2], strides = [1]} : vector<3xi16> to vector<2xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[EXTR]] : vector<2xi16> to vector<2xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<2xi32>
-func.func @extui_over_extract_strided_slice_1d(%a: vector<3xi16>) -> vector<2xi32> {
-  %b = arith.extui %a : vector<3xi16> to vector<3xi32>
-  %c = vector.extract_strided_slice %b
-   {offsets = [1], sizes = [2], strides = [1]} : vector<3xi32> to vector<2xi32>
-  return %c : vector<2xi32>
-}
-
-// CHECK-LABEL: func.func @extsi_over_extract_strided_slice_2d
-// CHECK-SAME:    (%[[ARG:.+]]: vector<2x3xi16>)
-// CHECK-NEXT:    %[[EXTR:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi16> to vector<1x2xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[EXTR]] : vector<1x2xi16> to vector<1x2xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<1x2xi32>
-func.func @extsi_over_extract_strided_slice_2d(%a: vector<2x3xi16>) -> vector<1x2xi32> {
-  %b = arith.extsi %a : vector<2x3xi16> to vector<2x3xi32>
-  %c = vector.extract_strided_slice %b
-   {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi32> to vector<1x2xi32>
-  return %c : vector<1x2xi32>
-}
-
-// CHECK-LABEL: func.func @extui_over_extract_strided_slice_2d
-// CHECK-SAME:    (%[[ARG:.+]]: vector<2x3xi16>)
-// CHECK-NEXT:    %[[EXTR:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi16> to vector<1x2xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[EXTR]] : vector<1x2xi16> to vector<1x2xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<1x2xi32>
-func.func @extui_over_extract_strided_slice_2d(%a: vector<2x3xi16>) -> vector<1x2xi32> {
-  %b = arith.extui %a : vector<2x3xi16> to vector<2x3xi32>
-  %c = vector.extract_strided_slice %b
-   {offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi32> to vector<1x2xi32>
-  return %c : vector<1x2xi32>
-}
-
-// CHECK-LABEL: func.func @extsi_over_insert_3xi16
-// CHECK-SAME:    (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16)
-// CHECK-NEXT:    %[[INS:.+]] = vector.insert %[[ARG1]], %[[ARG0]] [1] : i16 into vector<3xi16>
-// CHECK-NEXT:    %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
-func.func @extsi_over_insert_3xi16(%a: vector<3xi16>, %b: i16) -> vector<3xi32> {
-  %c = arith.extsi %a : vector<3xi16> to vector<3xi32>
-  %d = arith.extsi %b : i16 to i32
-  %e = vector.insert %d, %c [1] : i32 into vector<3xi32>
-  return %e : vector<3xi32>
-}
-
-// CHECK-LABEL: func.func @extui_over_insert_3xi16
-// CHECK-SAME:    (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16)
-// CHECK-NEXT:    %[[INS:.+]] = vector.insert %[[ARG1]], %[[ARG0]] [1] : i16 into vector<3xi16>
-// CHECK-NEXT:    %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
-func.func @extui_over_insert_3xi16(%a: vector<3xi16>, %b: i16) -> vector<3xi32> {
-  %c = arith.extui %a : vector<3xi16> to vector<3xi32>
-  %d = arith.extui %b : i16 to i32
-  %e = vector.insert %d, %c [1] : i32 into vector<3xi32>
-  return %e : vector<3xi32>
-}
-
-// CHECK-LABEL: func.func @extsi_over_insert_3xi16_cst_0
-// CHECK-SAME:    (%[[ARG:.+]]: i16)
-// CHECK-NEXT:    %[[CST:.+]] = arith.constant dense<0> : vector<3xi16>
-// CHECK-NEXT:    %[[INS:.+]] = vector.insert %[[ARG]], %[[CST]] [1] : i16 into vector<3xi16>
-// CHECK-NEXT:    %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
-func.func @extsi_over_insert_3xi16_cst_0(%a: i16) -> vector<3xi32> {
-  %cst = arith.constant dense<0> : vector<3xi32>
-  %d = arith.extsi %a : i16 to i32
-  %e = vector.insert %d, %cst [1] : i32 into vector<3xi32>
-  return %e : vector<3xi32>
-}
-
-// CHECK-LABEL: func.func @extsi_over_insert_3xi8_cst
-// CHECK-SAME:    (%[[ARG:.+]]: i8)
-// CHECK-NEXT:    %[[CST:.+]] = arith.constant dense<[-1, 127, -128]> : vector<3xi8>
-// CHECK-NEXT:    %[[INS:.+]] = vector.insert %[[ARG]], %[[CST]] [1] : i8 into vector<3xi8>
-// CHECK-NEXT:    %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi8> to vector<3xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
-func.func @extsi_over_insert_3xi8_cst(%a: i8) -> vector<3xi32> {
-  %cst = arith.constant dense<[-1, 127, -128]> : vector<3xi32>
-  %d = arith.extsi %a : i8 to i32
-  %e = vector.insert %d, %cst [1] : i32 into vector<3xi32>
-  return %e : vector<3xi32>
-}
-
-// CHECK-LABEL: func.func @extui_over_insert_3xi8_cst
-// CHECK-SAME:    (%[[ARG:.+]]: i8)
-// CHECK-NEXT:    %[[CST:.+]] = arith.constant dense<[1, 127, -1]> : vector<3xi8>
-// CHECK-NEXT:    %[[INS:.+]] = vector.insert %[[ARG]], %[[CST]] [1] : i8 into vector<3xi8>
-// CHECK-NEXT:    %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi8> to vector<3xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
-func.func @extui_over_insert_3xi8_cst(%a: i8) -> vector<3xi32> {
-  %cst = arith.constant dense<[1, 127, 255]> : vector<3xi32>
-  %d = arith.extui %a : i8 to i32
-  %e = vector.insert %d, %cst [1] : i32 into vector<3xi32>
-  return %e : vector<3xi32>
-}
-
-// CHECK-LABEL: func.func @extsi_over_insert_3xi16_cst_i16
-// CHECK-SAME:    (%[[ARG:.+]]: i8)
-// CHECK-NEXT:    %[[CST:.+]]  = arith.constant dense<[-1, 128, 0]> : vector<3xi16>
-// CHECK-NEXT:    %[[SRCE:.+]] = arith.extsi %[[ARG]] : i8 to i32
-// CHECK-NEXT:    %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16
-// CHECK-NEXT:    %[[INS:.+]]  = vector.insert %[[SRCT]], %[[CST]] [1] : i16 into vector<3xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
-func.func @extsi_over_insert_3xi16_cst_i16(%a: i8) -> vector<3xi32> {
-  %cst = arith.constant dense<[-1, 128, 0]> : vector<3xi32>
-  %d = arith.extsi %a : i8 to i32
-  %e = vector.insert %d, %cst [1] : i32 into vector<3xi32>
-  return %e : vector<3xi32>
-}
-
-// CHECK-LABEL: func.func @extui_over_insert_3xi16_cst_i16
-// CHECK-SAME:    (%[[ARG:.+]]: i8)
-// CHECK-NEXT:    %[[CST:.+]]  = arith.constant dense<[1, 256, 0]> : vector<3xi16>
-// CHECK-NEXT:    %[[SRCE:.+]] = arith.extui %[[ARG]] : i8 to i32
-// CHECK-NEXT:    %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16
-// CHECK-NEXT:    %[[INS:.+]]  = vector.insert %[[SRCT]], %[[CST]] [1] : i16 into vector<3xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
-func.func @extui_over_insert_3xi16_cst_i16(%a: i8) -> vector<3xi32> {
-  %cst = arith.constant dense<[1, 256, 0]> : vector<3xi32>
-  %d = arith.extui %a : i8 to i32
-  %e = vector.insert %d, %cst [1] : i32 into vector<3xi32>
-  return %e : vector<3xi32>
-}
-
-// CHECK-LABEL: func.func @extsi_over_insertelement_3xi16
-// CHECK-SAME:    (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16, %[[POS:.+]]: i32)
-// CHECK-NEXT:    %[[INS:.+]] = vector.insertelement %[[ARG1]], %[[ARG0]][%[[POS]] : i32] : vector<3xi16>
-// CHECK-NEXT:    %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
-func.func @extsi_over_insertelement_3xi16(%a: vector<3xi16>, %b: i16, %pos: i32) -> vector<3xi32> {
-  %c = arith.extsi %a : vector<3xi16> to vector<3xi32>
-  %d = arith.extsi %b : i16 to i32
-  %e = vector.insertelement %d, %c[%pos : i32] : vector<3xi32>
-  return %e : vector<3xi32>
-}
-
-// CHECK-LABEL: func.func @extui_over_insertelement_3xi16
-// CHECK-SAME:    (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16, %[[POS:.+]]: i32)
-// CHECK-NEXT:    %[[INS:.+]] = vector.insertelement %[[ARG1]], %[[ARG0]][%[[POS]] : i32] : vector<3xi16>
-// CHECK-NEXT:    %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
-func.func @extui_over_insertelement_3xi16(%a: vector<3xi16>, %b: i16, %pos: i32) -> vector<3xi32> {
-  %c = arith.extui %a : vector<3xi16> to vector<3xi32>
-  %d = arith.extui %b : i16 to i32
-  %e = vector.insertelement %d, %c[%pos : i32] : vector<3xi32>
-  return %e : vector<3xi32>
-}
-
-// CHECK-LABEL: func.func @extsi_over_insertelement_3xi16_cst_i16
-// CHECK-SAME:    (%[[ARG:.+]]: i8, %[[POS:.+]]: i32)
-// CHECK-NEXT:    %[[CST:.+]]  = arith.constant dense<[-1, 128, 0]> : vector<3xi16>
-// CHECK-NEXT:    %[[SRCE:.+]] = arith.extsi %[[ARG]] : i8 to i32
-// CHECK-NEXT:    %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16
-// CHECK-NEXT:    %[[INS:.+]] = vector.insertelement %[[SRCT]], %[[CST]][%[[POS]] : i32] : vector<3xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
-func.func @extsi_over_insertelement_3xi16_cst_i16(%a: i8, %pos: i32) -> vector<3xi32> {
-  %cst = arith.constant dense<[-1, 128, 0]> : vector<3xi32>
-  %d = arith.extsi %a : i8 to i32
-  %e = vector.insertelement %d, %cst[%pos : i32] : vector<3xi32>
-  return %e : vector<3xi32>
-}
-
-// CHECK-LABEL: func.func @extui_over_insertelement_3xi16_cst_i16
-// CHECK-SAME:    (%[[ARG:.+]]: i8, %[[POS:.+]]: i32)
-// CHECK-NEXT:    %[[CST:.+]]  = arith.constant dense<[1, 256, 0]> : vector<3xi16>
-// CHECK-NEXT:    %[[SRCE:.+]] = arith.extui %[[ARG]] : i8 to i32
-// CHECK-NEXT:    %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16
-// CHECK-NEXT:    %[[INS:.+]] = vector.insertelement %[[SRCT]], %[[CST]][%[[POS]] : i32] : vector<3xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
-func.func @extui_over_insertelement_3xi16_cst_i16(%a: i8, %pos: i32) -> vector<3xi32> {
-  %cst = arith.constant dense<[1, 256, 0]> : vector<3xi32>
-  %d = arith.extui %a : i8 to i32
-  %e = vector.insertelement %d, %cst[%pos : i32] : vector<3xi32>
-  return %e : vector<3xi32>
-}
-
-// CHECK-LABEL: func.func @extsi_over_insert_strided_slice_1d
-// CHECK-SAME:    (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: vector<2xi16>)
-// CHECK-NEXT:    %[[INS:.+]] = vector.insert_strided_slice %[[ARG1]], %[[ARG0]]
-// CHECK-SAME:                    {offsets = [1], strides = [1]} : vector<2xi16> into vector<3xi16>
-// CHECK-NEXT:    %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
-func.func @extsi_over_insert_strided_slice_1d(%a: vector<3xi16>, %b: vector<2xi16>) -> vector<3xi32> {
-  %c = arith.extsi %a : vector<3xi16> to vector<3xi32>
-  %d = arith.extsi %b : vector<2xi16> to vector<2xi32>
-  %e = vector.insert_strided_slice %d, %c {offsets = [1], strides = [1]} : vector<2xi32> into vector<3xi32>
-  return %e : vector<3xi32>
-}
-
-// CHECK-LABEL: func.func @extui_over_insert_strided_slice_1d
-// CHECK-SAME:    (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: vector<2xi16>)
-// CHECK-NEXT:    %[[INS:.+]] = vector.insert_strided_slice %[[ARG1]], %[[ARG0]]
-// CHECK-SAME:                    {offsets = [1], strides = [1]} : vector<2xi16> into vector<3xi16>
-// CHECK-NEXT:    %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
-func.func @extui_over_insert_strided_slice_1d(%a: vector<3xi16>, %b: vector<2xi16>) -> vector<3xi32> {
-  %c = arith.extui %a : vector<3xi16> to vector<3xi32>
-  %d = arith.extui %b : vector<2xi16> to vector<2xi32>
-  %e = vector.insert_strided_slice %d, %c {offsets = [1], strides = [1]} : vector<2xi32> into vector<3xi32>
-  return %e : vector<3xi32>
-}
-
-// CHECK-LABEL: func.func @extsi_over_insert_strided_slice_cst_2d
-// CHECK-SAME:    (%[[ARG:.+]]: vector<1x2xi8>)
-// CHECK-NEXT:    %[[CST:.+]]  = arith.constant
-// CHECK-SAME{LITERAL}:            dense<[[-1, 128, 0], [-129, 42, 1337]]> : vector<2x3xi16>
-// CHECK-NEXT:    %[[SRCE:.+]] = arith.extsi %[[ARG]] : vector<1x2xi8> to vector<1x2xi32>
-// CHECK-NEXT:    %[[SRCT:.+]] = arith.trunci %[[SRCE]] : vector<1x2xi32> to vector<1x2xi16>
-// CHECK-NEXT:    %[[INS:.+]] = vector.insert_strided_slice %[[SRCT]], %[[CST]]
-// CHECK-SAME:                    {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi16> into vector<2x3xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[INS]] : vector<2x3xi16> to vector<2x3xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<2x3xi32>
-func.func @extsi_over_insert_strided_slice_cst_2d(%a: vector<1x2xi8>) -> vector<2x3xi32> {
-  %cst = arith.constant dense<[[-1, 128, 0], [-129, 42, 1337]]> : vector<2x3xi32>
-  %d = arith.extsi %a : vector<1x2xi8> to vector<1x2xi32>
-  %e = vector.insert_strided_slice %d, %cst {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi32> into vector<2x3xi32>
-  return %e : vector<2x3xi32>
-}
-
-// CHECK-LABEL: func.func @extui_over_insert_strided_slice_cst_2d
-// CHECK-SAME:    (%[[ARG:.+]]: vector<1x2xi8>)
-// CHECK-NEXT:    %[[CST:.+]]  = arith.constant
-// CHECK-SAME{LITERAL}:            dense<[[1, 128, 0], [256, 42, 1337]]> : vector<2x3xi16>
-// CHECK-NEXT:    %[[SRCE:.+]] = arith.extui %[[ARG]] : vector<1x2xi8> to vector<1x2xi32>
-// CHECK-NEXT:    %[[SRCT:.+]] = arith.trunci %[[SRCE]] : vector<1x2xi32> to vector<1x2xi16>
-// CHECK-NEXT:    %[[INS:.+]] = vector.insert_strided_slice %[[SRCT]], %[[CST]]
-// CHECK-SAME:                    {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi16> into vector<2x3xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[INS]] : vector<2x3xi16> to vector<2x3xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<2x3xi32>
-func.func @extui_over_insert_strided_slice_cst_2d(%a: vector<1x2xi8>) -> vector<2x3xi32> {
-  %cst = arith.constant dense<[[1, 128, 0], [256, 42, 1337]]> : vector<2x3xi32>
-  %d = arith.extui %a : vector<1x2xi8> to vector<1x2xi32>
-  %e = vector.insert_strided_slice %d, %cst {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi32> into vector<2x3xi32>
-  return %e : vector<2x3xi32>
-}
-
-// CHECK-LABEL: func.func @extsi_over_broadcast_3xi16
-// CHECK-SAME:    (%[[ARG:.+]]: i16)
-// CHECK-NEXT:    %[[BCST:.+]] = vector.broadcast %[[ARG]] : i16 to vector<3xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[BCST]] : vector<3xi16> to vector<3xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<3xi32>
-func.func @extsi_over_broadcast_3xi16(%a: i16) -> vector<3xi32> {
-  %b = arith.extsi %a : i16 to i32
-  %r = vector.broadcast %b : i32 to vector<3xi32>
-  return %r : vector<3xi32>
-}
-
-// CHECK-LABEL: func.func @extui_over_broadcast_2x3xi16
-// CHECK-SAME:    (%[[ARG:.+]]: vector<3xi16>)
-// CHECK-NEXT:    %[[BCST:.+]] = vector.broadcast %[[ARG]] : vector<3xi16> to vector<2x3xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[BCST]] : vector<2x3xi16> to vector<2x3xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<2x3xi32>
-func.func @extui_over_broadcast_2x3xi16(%a: vector<3xi16>) -> vector<2x3xi32> {
-  %b = arith.extui %a : vector<3xi16> to vector<3xi32>
-  %r = vector.broadcast %b : vector<3xi32> to vector<2x3xi32>
-  return %r : vector<2x3xi32>
-}
-
-// CHECK-LABEL: func.func @extsi_over_shape_cast_2x3xi16
-// CHECK-SAME:    (%[[ARG:.+]]: vector<2x3xi16>)
-// CHECK-NEXT:    %[[CAST:.+]] = vector.shape_cast %[[ARG]] : vector<2x3xi16> to vector<3x2xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[CAST]] : vector<3x2xi16> to vector<3x2xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<3x2xi32>
-func.func @extsi_over_shape_cast_2x3xi16(%a: vector<2x3xi16>) -> vector<3x2xi32> {
-  %b = arith.extsi %a : vector<2x3xi16> to vector<2x3xi32>
-  %r = vector.shape_cast %b : vector<2x3xi32> to vector<3x2xi32>
-  return %r : vector<3x2xi32>
-}
-
-// CHECK-LABEL: func.func @extui_over_shape_cast_5x2x3xi16
-// CHECK-SAME:    (%[[ARG:.+]]: vector<5x2x3xi16>)
-// CHECK-NEXT:    %[[CAST:.+]] = vector.shape_cast %[[ARG]] : vector<5x2x3xi16> to vector<2x3x5xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[CAST]] : vector<2x3x5xi16> to vector<2x3x5xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<2x3x5xi32>
-func.func @extui_over_shape_cast_5x2x3xi16(%a: vector<5x2x3xi16>) -> vector<2x3x5xi32> {
-  %b = arith.extui %a : vector<5x2x3xi16> to vector<5x2x3xi32>
-  %r = vector.shape_cast %b : vector<5x2x3xi32> to vector<2x3x5xi32>
-  return %r : vector<2x3x5xi32>
-}
-
-// CHECK-LABEL: func.func @extsi_over_transpose_2x3xi16
-// CHECK-SAME:    (%[[ARG:.+]]: vector<2x3xi16>)
-// CHECK-NEXT:    %[[TRAN:.+]] = vector.transpose %[[ARG]], [1, 0] : vector<2x3xi16> to vector<3x2xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[TRAN]] : vector<3x2xi16> to vector<3x2xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<3x2xi32>
-func.func @extsi_over_transpose_2x3xi16(%a: vector<2x3xi16>) -> vector<3x2xi32> {
-  %b = arith.extsi %a : vector<2x3xi16> to vector<2x3xi32>
-  %r = vector.transpose %b, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
-  return %r : vector<3x2xi32>
-}
-
-// CHECK-LABEL: func.func @extui_over_transpose_5x2x3xi16
-// CHECK-SAME:    (%[[ARG:.+]]: vector<5x2x3xi16>)
-// CHECK-NEXT:    %[[TRAN:.+]] = vector.transpose %[[ARG]], [1, 2, 0] : vector<5x2x3xi16> to vector<2x3x5xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[TRAN]] : vector<2x3x5xi16> to vector<2x3x5xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<2x3x5xi32>
-func.func @extui_over_transpose_5x2x3xi16(%a: vector<5x2x3xi16>) -> vector<2x3x5xi32> {
-  %b = arith.extui %a : vector<5x2x3xi16> to vector<5x2x3xi32>
-  %r = vector.transpose %b, [1, 2, 0] : vector<5x2x3xi32> to vector<2x3x5xi32>
-  return %r : vector<2x3x5xi32>
-}
-
-// CHECK-LABEL: func.func @extsi_over_flat_transpose_16xi16
-// CHECK-SAME:    (%[[ARG:.+]]: vector<16xi16>)
-// CHECK-NEXT:    %[[TRAN:.+]] = vector.flat_transpose %[[ARG]] {columns = 4 : i32, rows = 4 : i32} : vector<16xi16> -> vector<16xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[TRAN]] : vector<16xi16> to vector<16xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<16xi32>
-func.func @extsi_over_flat_transpose_16xi16(%a: vector<16xi16>) -> vector<16xi32> {
-  %b = arith.extsi %a : vector<16xi16> to vector<16xi32>
-  %r = vector.flat_transpose %b {columns = 4 : i32, rows = 4 : i32} : vector<16xi32> -> vector<16xi32>
-  return %r : vector<16xi32>
-}
-
-// CHECK-LABEL: func.func @extui_over_flat_transpose_16xi16
-// CHECK-SAME:    (%[[ARG:.+]]: vector<16xi16>)
-// CHECK-NEXT:    %[[TRAN:.+]] = vector.flat_transpose %[[ARG]] {columns = 8 : i32, rows = 2 : i32} : vector<16xi16> -> vector<16xi16>
-// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[TRAN]] : vector<16xi16> to vector<16xi32>
-// CHECK-NEXT:    return %[[RET]] : vector<16xi32>
-func.func @extui_over_flat_transpose_16xi16(%a: vector<16xi16>) -> vector<16xi32> {
-  %b = arith.extui %a : vector<16xi16> to vector<16xi32>
-  %r = vector.flat_transpose %b {columns = 8 : i32, rows = 2 : i32} : vector<16xi32> -> vector<16xi32>
-  return %r : vector<16xi32>
-}
diff --git a/mlir/test/Dialect/Linalg/int-narrowing.mlir b/mlir/test/Dialect/Linalg/int-narrowing.mlir
deleted file mode 100644
index 8063d504597a39..00000000000000
--- a/mlir/test/Dialect/Linalg/int-narrowing.mlir
+++ /dev/null
@@ -1,147 +0,0 @@
-// RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,32" \
-// RUN:          --verify-diagnostics %s | FileCheck %s
-
-// Check that we can calculate `linalg.index` value bounds and use them to
-// optimize index casts.
-
-//===----------------------------------------------------------------------===//
-// arith.index_cast
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func @linalg_indexcast_dim_0_i8
-// CHECK:         %[[IDX:.+]] = linalg.index 0 : index
-// CHECK-NEXT:    %[[INT:.+]] = arith.index_cast %[[IDX]] : index to i8
-// CHECK-NEXT:    %[[FP:.+]]  = arith.sitofp %[[INT]] : i8 to f16
-// CHECK-NEXT:    linalg.yield %[[FP]] : f16
-func.func @linalg_indexcast_dim_0_i8(%arg0: tensor<f16>) -> tensor<128xf16> {
-  %init = tensor.empty() : tensor<128xf16>
-  %res = linalg.generic {
-      indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
-      iterator_types = ["parallel"]
-    }
-    ins(%arg0 : tensor<f16>)
-    outs(%init : tensor<128xf16>) {
-  ^bb0(%in: f16, %out: f16):
-    %idx = linalg.index 0 : index
-    %int = arith.index_cast %idx : index to i64
-    %fp = arith.sitofp %int : i64 to f16
-    linalg.yield %fp : f16
-  } -> tensor<128xf16>
-
-  return %res : tensor<128xf16>
-}
-
-// CHECK-LABEL: func @linalg_indexcast_dim_1_i16
-// CHECK:         %[[IDX:.+]] = linalg.index 1 : index
-// CHECK-NEXT:    %[[INT:.+]] = arith.index_cast %[[IDX]] : index to i16
-// CHECK-NEXT:    %[[FP:.+]]  = arith.sitofp %[[INT]] : i16 to f16
-// CHECK-NEXT:    linalg.yield %[[FP]] : f16
-func.func @linalg_indexcast_dim_1_i16(%arg0: tensor<f16>, %arg1: tensor<?x129xf16>) -> tensor<?x129xf16> {
-  %res = linalg.generic {
-      indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>],
-      iterator_types = ["parallel", "parallel"]
-    }
-    ins(%arg0 : tensor<f16>)
-    outs(%arg1 : tensor<?x129xf16>) {
-  ^bb0(%in: f16, %out: f16):
-    %idx = linalg.index 1 : index
-    %int = arith.index_cast %idx : index to i64
-    %fp = arith.sitofp %int : i64 to f16
-    linalg.yield %fp : f16
-  } -> tensor<?x129xf16>
-
-  return %res : tensor<?x129xf16>
-}
-
-// CHECK-LABEL: func @linalg_indexcast_dynamic_dim_i64
-// CHECK:         %[[IDX:.+]] = linalg.index 0 : index
-// CHECK-NEXT:    %[[INT:.+]] = arith.index_cast %[[IDX]] : index to i64
-// CHECK-NEXT:    %[[FP:.+]]  = arith.sitofp %[[INT]] : i64 to f16
-// CHECK-NEXT:    linalg.yield %[[FP]] : f16
-func.func @linalg_indexcast_dynamic_dim_i64(%arg0: tensor<f16>, %arg1: tensor<?xf16>) -> tensor<?xf16> {
-  %res = linalg.generic {
-      indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
-      iterator_types = ["parallel"]
-    }
-    ins(%arg0 : tensor<f16>)
-    outs(%arg1 : tensor<?xf16>) {
-  ^bb0(%in: f16, %out: f16):
-    %idx = linalg.index 0 : index
-    %int = arith.index_cast %idx : index to i64
-    %fp = arith.sitofp %int : i64 to f16
-    linalg.yield %fp : f16
-  } -> tensor<?xf16>
-
-  return %res : tensor<?xf16>
-}
-
-//===----------------------------------------------------------------------===//
-// arith.index_castui
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func @linalg_indexcastui_dim_0_i8
-// CHECK:         %[[IDX:.+]] = linalg.index 0 : index
-// CHECK-NEXT:    %[[INT:.+]] = arith.index_castui %[[IDX]] : index to i8
-// CHECK-NEXT:    %[[FP:.+]]  = arith.uitofp %[[INT]] : i8 to f16
-// CHECK-NEXT:    linalg.yield %[[FP]] : f16
-func.func @linalg_indexcastui_dim_0_i8(%arg0: tensor<f16>) -> tensor<256xf16> {
-  %init = tensor.empty() : tensor<256xf16>
-  %res = linalg.generic {
-      indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
-      iterator_types = ["parallel"]
-    }
-    ins(%arg0 : tensor<f16>)
-    outs(%init : tensor<256xf16>) {
-  ^bb0(%in: f16, %out: f16):
-    %idx = linalg.index 0 : index
-    %int = arith.index_castui %idx : index to i64
-    %fp = arith.uitofp %int : i64 to f16
-    linalg.yield %fp : f16
-  } -> tensor<256xf16>
-
-  return %res : tensor<256xf16>
-}
-
-// CHECK-LABEL: func @linalg_indexcastui_dim_1_i16
-// CHECK:         %[[IDX:.+]] = linalg.index 1 : index
-// CHECK-NEXT:    %[[INT:.+]] = arith.index_castui %[[IDX]] : index to i16
-// CHECK-NEXT:    %[[FP:.+]]  = arith.uitofp %[[INT]] : i16 to f16
-// CHECK-NEXT:    linalg.yield %[[FP]] : f16
-func.func @linalg_indexcastui_dim_1_i16(%arg0: tensor<f16>, %arg1: tensor<?x257xf16>) -> tensor<?x257xf16> {
-  %res = linalg.generic {
-      indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>],
-      iterator_types = ["parallel", "parallel"]
-    }
-    ins(%arg0 : tensor<f16>)
-    outs(%arg1 : tensor<?x257xf16>) {
-  ^bb0(%in: f16, %out: f16):
-    %idx = linalg.index 1 : index
-    %int = arith.index_castui %idx : index to i64
-    %fp = arith.uitofp %int : i64 to f16
-    linalg.yield %fp : f16
-  } -> tensor<?x257xf16>
-
-  return %res : tensor<?x257xf16>
-}
-
-// CHECK-LABEL: func @linalg_indexcastui_dynamic_dim_i64
-// CHECK:         %[[IDX:.+]] = linalg.index 0 : index
-// CHECK-NEXT:    %[[INT:.+]] = arith.index_castui %[[IDX]] : index to i64
-// CHECK-NEXT:    %[[FP:.+]]  = arith.uitofp %[[INT]] : i64 to f16
-// CHECK-NEXT:    linalg.yield %[[FP]] : f16
-func.func @linalg_indexcastui_dynamic_dim_i64(%arg0: tensor<f16>, %arg1: tensor<?xf16>) -> tensor<?xf16> {
-  %res = linalg.generic {
-      indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
-      iterator_types = ["parallel"]
-    }
-    ins(%arg0 : tensor<f16>)
-    outs(%arg1 : tensor<?xf16>) {
-  ^bb0(%in: f16, %out: f16):
-    %idx = linalg.index 0 : index
-    %int = arith.index_castui %idx : index to i64
-    %fp = arith.uitofp %int : i64 to f16
-    linalg.yield %fp : f16
-  } -> tensor<?xf16>
-
-  return %res : tensor<?xf16>
-}

>From b4275530a289255ee9147e679384654813d2e642 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 16 Oct 2024 22:37:26 +0200
Subject: [PATCH 07/19] nits

---
 .../Transforms/IntRangeOptimizations.cpp      | 28 +++++++++----------
 1 file changed, 13 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 8c651076df2e5a..e026632d0201d7 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -218,9 +218,10 @@ static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
       if (!type) {
         type = val.getType();
         continue;
-      } else if (type != val.getType()) {
-        return nullptr;
       }
+
+      if (type != val.getType())
+        return nullptr;
     }
   }
 
@@ -301,13 +302,11 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
   auto dstInt = cast<IntegerType>(dstType);
   if (dstInt.getWidth() < srcInt.getWidth()) {
     return builder.create<arith::TruncIOp>(loc, dstType, src);
-  } else {
-    return builder.create<arith::ExtUIOp>(loc, dstType, src);
   }
+  return builder.create<arith::ExtUIOp>(loc, dstType, src);
 }
 
-struct NarrowElementwise final
-    : public OpTraitRewritePattern<OpTrait::Elementwise> {
+struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
   NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
                     ArrayRef<unsigned> target)
       : OpTraitRewritePattern<OpTrait::Elementwise>(context), solver(s),
@@ -316,7 +315,6 @@ struct NarrowElementwise final
   using OpTraitRewritePattern::OpTraitRewritePattern;
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-
     std::optional<ConstantIntRanges> range =
         getOperandsRange(solver, op->getResults());
     if (!range)
@@ -370,8 +368,8 @@ struct NarrowElementwise final
   SmallVector<unsigned, 4> targetBitwidths;
 };
 
-struct NarrowCmpi final : public OpRewritePattern<arith::CmpIOp> {
-  NarrowCmpi(MLIRContext *context, PatternBenefit benefit, DataFlowSolver &s,
+struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
+  NarrowCmpI(MLIRContext *context, PatternBenefit benefit, DataFlowSolver &s,
              ArrayRef<unsigned> target)
       : OpRewritePattern(context, benefit), solver(s), targetBitwidths(target) {
   }
@@ -421,8 +419,8 @@ struct NarrowCmpi final : public OpRewritePattern<arith::CmpIOp> {
   SmallVector<unsigned, 4> targetBitwidths;
 };
 
-struct IntRangeOptimizationsPass
-    : public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
+struct IntRangeOptimizationsPass final
+    : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
 
   void runOnOperation() override {
     Operation *op = getOperation();
@@ -446,8 +444,8 @@ struct IntRangeOptimizationsPass
   }
 };
 
-struct IntRangeNarrowingPass
-    : public arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
+struct IntRangeNarrowingPass final
+    : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
   using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
 
   void runOnOperation() override {
@@ -482,9 +480,9 @@ void mlir::arith::populateIntRangeOptimizationsPatterns(
 void mlir::arith::populateIntRangeNarrowingPatterns(
     RewritePatternSet &patterns, DataFlowSolver &solver,
     ArrayRef<unsigned> bitwidthsSupported) {
-  // Cmpi uses args ranges instead of results, run it with higher benefit,
+  // CmpI uses args ranges instead of results, run it with higher benefit,
   // as its argumens can be potentially replaced.
-  patterns.add<NarrowCmpi>(patterns.getContext(), /*benefit*/ 10, solver,
+  patterns.add<NarrowCmpI>(patterns.getContext(), /*benefit*/ 10, solver,
                            bitwidthsSupported);
 
   patterns.add<NarrowElementwise>(patterns.getContext(), solver,

>From 1f8358b7933776aa9f98774b3b08348c3af4e408 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 16 Oct 2024 23:01:37 +0200
Subject: [PATCH 08/19] comments

---
 .../Transforms/IntRangeOptimizations.cpp      | 22 +++++++++++++------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index e026632d0201d7..43ea7ee9a85b0e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -195,7 +195,8 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
   DataFlowSolver &solver;
 };
 
-static Type checkArithType(Type type, unsigned targetBitwidth) {
+/// Check if `type` is index or integer type with `getWidth() > targetBitwidth`.
+static Type checkIntType(Type type, unsigned targetBitwidth) {
   type = getElementTypeOrSelf(type);
   if (isa<IndexType>(type))
     return type;
@@ -207,6 +208,9 @@ static Type checkArithType(Type type, unsigned targetBitwidth) {
   return nullptr;
 }
 
+/// Check if op has the same type for all operands and results and this type
+/// is suitable for truncation.
+/// Returns args type or empty.
 static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
   if (op->getNumOperands() == 0 || op->getNumResults() == 0)
     return nullptr;
@@ -225,13 +229,14 @@ static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
     }
   }
 
-  return checkArithType(type, targetBitwidth);
+  return checkIntType(type, targetBitwidth);
 }
 
+/// Return the union of all operands' value ranges.
 static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
-                                                         ValueRange results) {
+                                                         ValueRange operands) {
   std::optional<ConstantIntRanges> ret;
-  for (Value value : results) {
+  for (Value value : operands) {
     auto *maybeInferredRange =
         solver.lookupState<IntegerValueRangeLattice>(value);
     if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
@@ -249,6 +254,8 @@ static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
   return ret;
 }
 
+/// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
+/// return shaped type as well.
 static Type getTargetType(Type srcType, unsigned targetBitwidth) {
   auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth);
   if (auto shaped = dyn_cast<ShapedType>(srcType))
@@ -258,6 +265,7 @@ static Type getTargetType(Type srcType, unsigned targetBitwidth) {
   return dstType;
 }
 
+/// Check privided `range` is inside `smin, smax, umin, umax` bounds.
 static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax,
                        APInt umin, APInt umax) {
   auto sge = [](APInt val1, APInt val2) -> bool {
@@ -300,9 +308,9 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
 
   auto srcInt = cast<IntegerType>(srcType);
   auto dstInt = cast<IntegerType>(dstType);
-  if (dstInt.getWidth() < srcInt.getWidth()) {
+  if (dstInt.getWidth() < srcInt.getWidth())
     return builder.create<arith::TruncIOp>(loc, dstType, src);
-  }
+
   return builder.create<arith::ExtUIOp>(loc, dstType, src);
 }
 
@@ -385,7 +393,7 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
       return failure();
 
     for (unsigned targetBitwidth : targetBitwidths) {
-      Type srcType = checkArithType(lhs.getType(), targetBitwidth);
+      Type srcType = checkIntType(lhs.getType(), targetBitwidth);
       if (!srcType)
         continue;
 

>From f7b5485ee2cc888716ba1593fdbb52eee185aea8 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 1 Nov 2024 23:43:16 +0100
Subject: [PATCH 09/19] add vector support

---
 .../Transforms/IntRangeOptimizations.cpp      | 20 +++++++------
 .../Dialect/Arith/int-range-narrowing.mlir    | 28 +++++++++++++++++++
 2 files changed, 40 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 43ea7ee9a85b0e..45e870eac180d0 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -197,11 +197,11 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
 
 /// Check if `type` is index or integer type with `getWidth() > targetBitwidth`.
 static Type checkIntType(Type type, unsigned targetBitwidth) {
-  type = getElementTypeOrSelf(type);
-  if (isa<IndexType>(type))
+  Type elemType = getElementTypeOrSelf(type);
+  if (isa<IndexType>(elemType))
     return type;
 
-  if (auto intType = dyn_cast<IntegerType>(type))
+  if (auto intType = dyn_cast<IntegerType>(elemType))
     if (intType.getWidth() > targetBitwidth)
       return type;
 
@@ -298,16 +298,20 @@ static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax,
 
 static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
   Type srcType = src.getType();
-  assert(srcType.isIntOrIndex() && "Invalid src type");
-  assert(dstType.isIntOrIndex() && "Invalid dst type");
+  assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
+         "Mixing vector and non-vector types");
+  Type srcElemType = getElementTypeOrSelf(srcType);
+  Type dstElemType = getElementTypeOrSelf(dstType);
+  assert(srcElemType.isIntOrIndex() && "Invalid src type");
+  assert(dstElemType.isIntOrIndex() && "Invalid dst type");
   if (srcType == dstType)
     return src;
 
-  if (isa<IndexType>(srcType) || isa<IndexType>(dstType))
+  if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType))
     return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
 
-  auto srcInt = cast<IntegerType>(srcType);
-  auto dstInt = cast<IntegerType>(dstType);
+  auto srcInt = cast<IntegerType>(srcElemType);
+  auto dstInt = cast<IntegerType>(dstElemType);
   if (dstInt.getWidth() < srcInt.getWidth())
     return builder.create<arith::TruncIOp>(loc, dstType, src);
 
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index cd0a4c449913e1..1378fb1c3c98c9 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -30,6 +30,20 @@ func.func @test_addi() -> index {
   return %2 : index
 }
 
+// CHECK-LABEL: func @test_addi_vec
+//       CHECK:  %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : vector<4xindex>
+//       CHECK:  %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : vector<4xindex>
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : vector<4xindex> to vector<4xi8>
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : vector<4xindex> to vector<4xi8>
+//       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : vector<4xi8>
+//       CHECK:  %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : vector<4xi8> to vector<4xindex>
+//       CHECK:  return %[[RES_CASTED]] : vector<4xindex>
+func.func @test_addi_vec() -> vector<4xindex> {
+  %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<4xindex>
+  %1 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex>
+  %2 = arith.addi %0, %1 : vector<4xindex>
+  return %2 : vector<4xindex>
+}
 
 // CHECK-LABEL: func @test_addi_i64
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 5 : i64, smin = 4 : i64, umax = 5 : i64, umin = 4 : i64} : i64
@@ -60,6 +74,20 @@ func.func @test_cmpi() -> i1 {
   return %2 : i1
 }
 
+// CHECK-LABEL: func @test_cmpi_vec
+//       CHECK:  %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : vector<4xindex>
+//       CHECK:  %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : vector<4xindex>
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : vector<4xindex> to vector<4xi8>
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : vector<4xindex> to vector<4xi8>
+//       CHECK:  %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : vector<4xi8>
+//       CHECK:  return %[[RES]] : vector<4xi1>
+func.func @test_cmpi_vec() -> vector<4xi1> {
+  %0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : vector<4xindex>
+  %1 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : vector<4xindex>
+  %2 = arith.cmpi slt, %0, %1 : vector<4xindex>
+  return %2 : vector<4xi1>
+}
+
 //===----------------------------------------------------------------------===//
 // arith.addi
 //===----------------------------------------------------------------------===//

>From 976f7c72678b9e53a1d65d1e47bbf20f945b15d4 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 3 Nov 2024 12:23:40 +0100
Subject: [PATCH 10/19] remove benefit

---
 .../Arith/Transforms/IntRangeOptimizations.cpp    | 15 ++++-----------
 1 file changed, 4 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 45e870eac180d0..80199ce9b618de 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -381,10 +381,8 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
 };
 
 struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
-  NarrowCmpI(MLIRContext *context, PatternBenefit benefit, DataFlowSolver &s,
-             ArrayRef<unsigned> target)
-      : OpRewritePattern(context, benefit), solver(s), targetBitwidths(target) {
-  }
+  NarrowCmpI(MLIRContext *context, DataFlowSolver &s, ArrayRef<unsigned> target)
+      : OpRewritePattern(context), solver(s), targetBitwidths(target) {}
 
   LogicalResult matchAndRewrite(arith::CmpIOp op,
                                 PatternRewriter &rewriter) const override {
@@ -492,13 +490,8 @@ void mlir::arith::populateIntRangeOptimizationsPatterns(
 void mlir::arith::populateIntRangeNarrowingPatterns(
     RewritePatternSet &patterns, DataFlowSolver &solver,
     ArrayRef<unsigned> bitwidthsSupported) {
-  // CmpI uses args ranges instead of results, run it with higher benefit,
-  // as its argumens can be potentially replaced.
-  patterns.add<NarrowCmpI>(patterns.getContext(), /*benefit*/ 10, solver,
-                           bitwidthsSupported);
-
-  patterns.add<NarrowElementwise>(patterns.getContext(), solver,
-                                  bitwidthsSupported);
+  patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
+                                              bitwidthsSupported);
 }
 
 std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {

>From 98a0da4b85095928f9caaaaee803e353b90723c6 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 3 Nov 2024 12:34:11 +0100
Subject: [PATCH 11/19] style fixes

---
 .../Arith/Transforms/IntRangeOptimizations.cpp        | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 80199ce9b618de..692b40830704eb 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -245,11 +245,7 @@ static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
     const ConstantIntRanges &inferredRange =
         maybeInferredRange->getValue().getValue();
 
-    if (!ret) {
-      ret = inferredRange;
-    } else {
-      ret = ret->rangeUnion(inferredRange);
-    }
+    ret = (ret ? ret->rangeUnion(inferredRange) : inferredRange);
   }
   return ret;
 }
@@ -265,7 +261,7 @@ static Type getTargetType(Type srcType, unsigned targetBitwidth) {
   return dstType;
 }
 
-/// Check privided `range` is inside `smin, smax, umin, umax` bounds.
+/// Check provided `range` is inside `smin, smax, umin, umax` bounds.
 static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax,
                        APInt umin, APInt umax) {
   auto sge = [](APInt val1, APInt val2) -> bool {
@@ -321,8 +317,7 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
 struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
   NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
                     ArrayRef<unsigned> target)
-      : OpTraitRewritePattern<OpTrait::Elementwise>(context), solver(s),
-        targetBitwidths(target) {}
+      : OpTraitRewritePattern(context), solver(s), targetBitwidths(target) {}
 
   using OpTraitRewritePattern::OpTraitRewritePattern;
   LogicalResult matchAndRewrite(Operation *op,

>From cb1741f27161f8041fed847f7c9f405d6204dc0a Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 3 Nov 2024 12:37:08 +0100
Subject: [PATCH 12/19] llvm concat

---
 .../Arith/Transforms/IntRangeOptimizations.cpp  | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 692b40830704eb..d2e36f9ef82391 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -216,17 +216,14 @@ static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
     return nullptr;
 
   Type type;
-  for (auto range :
-       {ValueRange(op->getOperands()), ValueRange(op->getResults())}) {
-    for (Value val : range) {
-      if (!type) {
-        type = val.getType();
-        continue;
-      }
-
-      if (type != val.getType())
-        return nullptr;
+  for (Value val : llvm::concat<Value>(op->getOperands(), op->getResults())) {
+    if (!type) {
+      type = val.getType();
+      continue;
     }
+
+    if (type != val.getType())
+      return nullptr;
   }
 
   return checkIntType(type, targetBitwidth);

>From 581760a562afb1ab4e20366bcf7663dbd58412dc Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 3 Nov 2024 13:10:26 +0100
Subject: [PATCH 13/19] checkIntType refac

---
 .../Transforms/IntRangeOptimizations.cpp      | 24 +++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index d2e36f9ef82391..79fdd16575e8f0 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -196,24 +196,23 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
 };
 
 /// Check if `type` is index or integer type with `getWidth() > targetBitwidth`.
-static Type checkIntType(Type type, unsigned targetBitwidth) {
+static bool checkIntType(Type type, unsigned targetBitwidth) {
   Type elemType = getElementTypeOrSelf(type);
   if (isa<IndexType>(elemType))
-    return type;
+    return true;
 
   if (auto intType = dyn_cast<IntegerType>(elemType))
     if (intType.getWidth() > targetBitwidth)
-      return type;
+      return true;
 
-  return nullptr;
+  return false;
 }
 
 /// Check if op has the same type for all operands and results and this type
 /// is suitable for truncation.
-/// Returns args type or empty.
-static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
+static bool checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
   if (op->getNumOperands() == 0 || op->getNumResults() == 0)
-    return nullptr;
+    return false;
 
   Type type;
   for (Value val : llvm::concat<Value>(op->getOperands(), op->getResults())) {
@@ -223,7 +222,7 @@ static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
     }
 
     if (type != val.getType())
-      return nullptr;
+      return false;
   }
 
   return checkIntType(type, targetBitwidth);
@@ -325,10 +324,11 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
       return failure();
 
     for (unsigned targetBitwidth : targetBitwidths) {
-      Type srcType = checkElementwiseOpType(op, targetBitwidth);
-      if (!srcType)
+      if (!checkElementwiseOpType(op, targetBitwidth))
         continue;
 
+      Type srcType = op->getResult(0).getType();
+
       // We are truncating op args to the desired bitwidth before the op and
       // then extending op results back to the original width after. extui and
       // exti will produce different results for negative values, so limit
@@ -387,8 +387,8 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
       return failure();
 
     for (unsigned targetBitwidth : targetBitwidths) {
-      Type srcType = checkIntType(lhs.getType(), targetBitwidth);
-      if (!srcType)
+      Type srcType = lhs.getType();
+      if (!checkIntType(srcType, targetBitwidth))
         continue;
 
       auto smin = APInt::getSignedMinValue(targetBitwidth);
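
As a side note on the extension-direction caveat in the elementwise pattern's
comments above (truncate the args before the op, extend the results back
after): here is a hypothetical MLIR snippet, not part of the patch, showing
why extui and extsi disagree for negative values. The function and value names
are made up for illustration.

func.func @ext_example(%v: i8) -> (i32, i32) {
  // For %v = -1 (bit pattern 0xFF):
  %s = arith.extsi %v : i8 to i32   // sign-extend:  -1  (0xFFFFFFFF)
  %u = arith.extui %v : i8 to i32   // zero-extend:  255 (0x000000FF)
  return %s, %u : i32, i32
}

Since doCast widens with arith.extui, the patterns only narrow when the
inferred range satisfies both the signed and the unsigned bounds of the target
bitwidth, so extending back reproduces the original value.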

>From fdcbb5f2cd083676cda570850b84870d1e8534f4 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 3 Nov 2024 13:19:08 +0100
Subject: [PATCH 14/19] add test

---
 .../Dialect/Arith/int-range-narrowing.mlir    | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 1378fb1c3c98c9..054d8bba9f7b15 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -88,6 +88,27 @@ func.func @test_cmpi_vec() -> vector<4xi1> {
   return %2 : vector<4xi1>
 }
 
+// CHECK-LABEL: func @test_add_cmpi
+//       CHECK:  %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
+//       CHECK:  %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
+//       CHECK:  %[[C:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i8
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i8
+//       CHECK:  %[[RES1:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
+//       CHECK:  %[[RES1_CASTED1:.*]] = arith.index_castui %[[RES1]] : i8 to index
+//       CHECK:  %[[C_CASTED:.*]] = arith.index_castui %[[C]] : index to i8
+//       CHECK:  %[[RES1_CASTED2:.*]] = arith.index_castui %[[RES1_CASTED1]] : index to i8
+//       CHECK:  %[[RES2:.*]] = arith.cmpi slt, %[[C_CASTED]], %[[RES1_CASTED2]] : i8
+//       CHECK:  return %[[RES2]] : i1
+func.func @test_add_cmpi() -> i1 {
+  %0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index
+  %1 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index
+  %3 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index
+  %4 = arith.addi %0, %1 : index
+  %5 = arith.cmpi slt, %3, %4 : index
+  return %5 : i1
+}
+
 //===----------------------------------------------------------------------===//
 // arith.addi
 //===----------------------------------------------------------------------===//

>From c2c9e39b8b0d64b981041923a04162a9f9649692 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 3 Nov 2024 13:30:50 +0100
Subject: [PATCH 15/19] add comment

---
 mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 79fdd16575e8f0..f5b2d0d2837965 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -465,6 +465,9 @@ struct IntRangeNarrowingPass final
     populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
 
     GreedyRewriteConfig config;
+    // We specifically need bottom-up traversal as cmpi pattern needs range
+    // data, attched to it's original arguments.
+    config.useTopDownTraversal = false;
     config.listener = &listener;
 
     if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))

>From 1d88642fc8584e7b6fb92bdcca0eb3206c9d331b Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 3 Nov 2024 13:47:10 +0100
Subject: [PATCH 16/19] fold index cast chain

---
 .../Transforms/IntRangeOptimizations.cpp      | 30 +++++++++++++++++++
 .../Dialect/Arith/int-range-narrowing.mlir    |  4 +--
 2 files changed, 31 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index f5b2d0d2837965..06406b0852c5d6 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -421,6 +421,35 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
   SmallVector<unsigned, 4> targetBitwidths;
 };
 
+/// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
+/// This pattern assumes all passed `targetBitwidths` are not wider than index
+/// type.
+struct FoldIndexCastChain final : OpRewritePattern<arith::IndexCastUIOp> {
+  FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
+      : OpRewritePattern(context), targetBitwidths(target) {}
+
+  LogicalResult matchAndRewrite(arith::IndexCastUIOp op,
+                                PatternRewriter &rewriter) const override {
+    auto srcOp = op.getIn().getDefiningOp<arith::IndexCastUIOp>();
+    if (!srcOp)
+      return failure();
+
+    Value src = srcOp.getIn();
+    if (src.getType() != op.getType())
+      return failure();
+
+    auto intType = dyn_cast<IntegerType>(op.getType());
+    if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
+      return failure();
+
+    rewriter.replaceOp(op, src);
+    return success();
+  }
+
+private:
+  SmallVector<unsigned, 4> targetBitwidths;
+};
+
 struct IntRangeOptimizationsPass final
     : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
 
@@ -487,6 +516,7 @@ void mlir::arith::populateIntRangeNarrowingPatterns(
     ArrayRef<unsigned> bitwidthsSupported) {
   patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
                                               bitwidthsSupported);
+  patterns.add<FoldIndexCastChain>(patterns.getContext(), bitwidthsSupported);
 }
 
 std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 054d8bba9f7b15..5ad89805a1b456 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -95,10 +95,8 @@ func.func @test_cmpi_vec() -> vector<4xi1> {
 //       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i8
 //       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i8
 //       CHECK:  %[[RES1:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
-//       CHECK:  %[[RES1_CASTED1:.*]] = arith.index_castui %[[RES1]] : i8 to index
 //       CHECK:  %[[C_CASTED:.*]] = arith.index_castui %[[C]] : index to i8
-//       CHECK:  %[[RES1_CASTED2:.*]] = arith.index_castui %[[RES1_CASTED1]] : index to i8
-//       CHECK:  %[[RES2:.*]] = arith.cmpi slt, %[[C_CASTED]], %[[RES1_CASTED2]] : i8
+//       CHECK:  %[[RES2:.*]] = arith.cmpi slt, %[[C_CASTED]], %[[RES1]] : i8
 //       CHECK:  return %[[RES2]] : i1
 func.func @test_add_cmpi() -> i1 {
   %0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index

>From 08297538bee33e479dccc0f8902b893758e5a712 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 3 Nov 2024 13:51:49 +0100
Subject: [PATCH 17/19] test

---
 .../Dialect/Arith/int-range-narrowing.mlir    | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 5ad89805a1b456..8893f299177ceb 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -107,6 +107,25 @@ func.func @test_add_cmpi() -> i1 {
   return %5 : i1
 }
 
+// CHECK-LABEL: func @test_add_cmpi_i64
+//       CHECK:  %[[A:.*]] = test.with_bounds {smax = 10 : i64, smin = 0 : i64, umax = 10 : i64, umin = 0 : i64} : i64
+//       CHECK:  %[[B:.*]] = test.with_bounds {smax = 10 : i64, smin = 0 : i64, umax = 10 : i64, umin = 0 : i64} : i64
+//       CHECK:  %[[C:.*]] = test.with_bounds {smax = 10 : i64, smin = 0 : i64, umax = 10 : i64, umin = 0 : i64} : i64
+//       CHECK:  %[[A_CASTED:.*]] = arith.trunci %[[A]] : i64 to i8
+//       CHECK:  %[[B_CASTED:.*]] = arith.trunci %[[B]] : i64 to i8
+//       CHECK:  %[[RES1:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
+//       CHECK:  %[[C_CASTED:.*]] = arith.trunci %[[C]] : i64 to i8
+//       CHECK:  %[[RES2:.*]] = arith.cmpi slt, %[[C_CASTED]], %[[RES1]] : i8
+//       CHECK:  return %[[RES2]] : i1
+func.func @test_add_cmpi_i64() -> i1 {
+  %0 = test.with_bounds { umin = 0 : i64, umax = 10 : i64, smin = 0 : i64, smax = 10 : i64 } : i64
+  %1 = test.with_bounds { umin = 0 : i64, umax = 10 : i64, smin = 0 : i64, smax = 10 : i64 } : i64
+  %3 = test.with_bounds { umin = 0 : i64, umax = 10 : i64, smin = 0 : i64, smax = 10 : i64 } : i64
+  %4 = arith.addi %0, %1 : i64
+  %5 = arith.cmpi slt, %3, %4 : i64
+  return %5 : i1
+}
+
 //===----------------------------------------------------------------------===//
 // arith.addi
 //===----------------------------------------------------------------------===//

>From 324b8bdc84311a093a48b6fca8c80b178ce92167 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 3 Nov 2024 13:51:56 +0100
Subject: [PATCH 18/19] fix comment

---
 mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 06406b0852c5d6..450d3972bb99df 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -495,7 +495,7 @@ struct IntRangeNarrowingPass final
 
     GreedyRewriteConfig config;
     // We specifically need bottom-up traversal as cmpi pattern needs range
-    // data, attched to it's original arguments.
+    // data, attached to its original argument values.
     config.useTopDownTraversal = false;
     config.listener = &listener;
 

>From 262e991e59ece054971a8b05a64ee6e27b6c20ae Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 3 Nov 2024 13:55:22 +0100
Subject: [PATCH 19/19] update pass desc

---
 mlir/include/mlir/Dialect/Arith/Transforms/Passes.td | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 98f90d120fa1c6..1d37314885d932 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -55,6 +55,9 @@ def ArithIntRangeNarrowing : Pass<"arith-int-range-narrowing"> {
   let description = [{
     This pass runs integer range analysis and tries to narrow arith ops to the
     specified bitwidth based on its results.
+
+    `bitwidthsSupported` values are assumed to be no wider than the `index` type.
+    TODO: get index width from DLTI.
   }];
 
   let options = [
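
To make the description above concrete, here is a rough before/after sketch
based on the test_addi cases added earlier in this series (illustrative only;
the exact pass invocation and option spelling are not shown here).

func.func @test_addi() -> index {
  // Both operands are known to fit into 8 bits.
  %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
  %1 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : index
  %2 = arith.addi %0, %1 : index
  return %2 : index
}

With 8 among the supported bitwidths, the addition is expected to be narrowed
roughly to:

  %lhs = arith.index_castui %0 : index to i8
  %rhs = arith.index_castui %1 : index to i8
  %sum = arith.addi %lhs, %rhs : i8
  %res = arith.index_castui %sum : i8 to index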


