[Mlir-commits] [mlir] [mlir] Add `arith-int-range-narrowing` pass (PR #112404)

Ivan Butygin llvmlistbot at llvm.org
Wed Oct 16 05:13:48 PDT 2024


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/112404

>From 584e8e5fc5ade2df915f02d85ff119b676b77d8b Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 15 Oct 2024 16:32:32 +0200
Subject: [PATCH 1/5] [mlir] Add `arith-int-range-narrowing` pass

This pass is intended to narrow integer calculations to a specific bitwidth, using `IntegerRangeAnalysis`.
We already have the `arith-int-narrowing` pass, but it mostly does only local analysis, while `IntegerRangeAnalysis` analyzes the entire program.
The two passes should ideally be unified eventually, but that is left as future work.
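
As a rough sketch of the intended transformation (based on the `test_addi` test added in this PR; `test.with_bounds` is a test-dialect op used only to seed the range analysis), with `target-bitwidth=32` the IR

  %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
  %1 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : index
  %2 = arith.addi %0, %1 : index

becomes, roughly,

  %2 = arith.index_castui %0 : index to i32
  %3 = arith.index_castui %1 : index to i32
  %4 = arith.addi %2, %3 : i32
  %5 = arith.index_castui %4 : i32 to index

with the final cast replacing all uses of the original `index` addition.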
---
 .../mlir/Dialect/Arith/Transforms/Passes.h    |   6 +
 .../mlir/Dialect/Arith/Transforms/Passes.td   |  19 ++
 .../Transforms/IntRangeOptimizations.cpp      | 260 ++++++++++++++++++
 .../Dialect/Arith/int-range-narrowing.mlir    |  57 ++++
 4 files changed, 342 insertions(+)
 create mode 100644 mlir/test/Dialect/Arith/int-range-narrowing.mlir

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index aee64475171a43..beeefea497d9b2 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -73,6 +73,12 @@ void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
 /// Create a pass which do optimizations based on integer range analysis.
 std::unique_ptr<Pass> createIntRangeOptimizationsPass();
 
+/// Add patterns for int range based norrowing.
+void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
+                                       DataFlowSolver &solver,
+                                       unsigned targetBitwidth);
+
+// TODO: merge these two narrowing passes.
 /// Add patterns for integer bitwidth narrowing.
 void populateArithIntNarrowingPatterns(RewritePatternSet &patterns,
                                        const ArithIntNarrowingOptions &options);
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 1517f71f1a7c90..8c565f6489638e 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -50,6 +50,25 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
   ];
 }
 
+def ArithIntRangeNarrowing : Pass<"arith-int-range-narrowing"> {
+  let summary = "Reduce integer operations bitwidth based on integer range analysis";
+  let description = [{
+    This pass runs integer range analysis and tries to narrow arith ops to the
+    specified bitwidth based on the analysis results.
+  }];
+
+  let options = [
+    Option<"targetBitwidth", "target-bitwidth", "unsigned",
+           /*default=*/"32", "Target bitwidth this pass will try to narrow to">,
+  ];
+
+  // Explicitly depend on "arith" because this pass could create operations in
+  // `arith` out of thin air in some cases.
+  let dependentDialects = [
+    "::mlir::arith::ArithDialect"
+  ];
+}
+
 def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {
   let summary = "Emulate operations on unsupported floats with extf/truncf";
   let description = [{
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 521138c1f6f4cd..4e7b44c939e796 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -15,8 +15,10 @@
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -24,6 +26,9 @@
 namespace mlir::arith {
 #define GEN_PASS_DEF_ARITHINTRANGEOPTS
 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
+
+#define GEN_PASS_DEF_ARITHINTRANGENARROWING
+#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
 } // namespace mlir::arith
 
 using namespace mlir;
@@ -184,6 +189,223 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
   DataFlowSolver &solver;
 };
 
+static Type checkArithType(Type type, unsigned targetBitwidth) {
+  type = getElementTypeOrSelf(type);
+  if (isa<IndexType>(type))
+    return type;
+
+  if (auto intType = dyn_cast<IntegerType>(type))
+    if (intType.getWidth() > targetBitwidth)
+      return type;
+
+  return nullptr;
+}
+
+static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
+  if (op->getNumOperands() == 0 || op->getNumResults() == 0)
+    return nullptr;
+
+  Type type;
+  for (auto range :
+       {ValueRange(op->getOperands()), ValueRange(op->getResults())}) {
+    for (Value val : range) {
+      if (!type) {
+        type = val.getType();
+        continue;
+      } else if (type != val.getType()) {
+        return nullptr;
+      }
+    }
+  }
+
+  return checkArithType(type, targetBitwidth);
+}
+
+static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
+                                                         ValueRange results) {
+  std::optional<ConstantIntRanges> ret;
+  for (Value value : results) {
+    auto *maybeInferredRange =
+        solver.lookupState<IntegerValueRangeLattice>(value);
+    if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
+      return std::nullopt;
+
+    const ConstantIntRanges &inferredRange =
+        maybeInferredRange->getValue().getValue();
+
+    if (!ret) {
+      ret = inferredRange;
+    } else {
+      ret = ret->rangeUnion(inferredRange);
+    }
+  }
+  return ret;
+}
+
+static Type getTargetType(Type srcType, unsigned targetBitwidth) {
+  auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth);
+  if (auto shaped = dyn_cast<ShapedType>(srcType))
+    return shaped.clone(dstType);
+
+  assert(srcType.isIntOrIndex() && "Invalid src type");
+  return dstType;
+}
+
+static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax,
+                       APInt umin, APInt umax) {
+  auto sge = [](APInt val1, APInt val2) -> bool {
+    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+    val1 = val1.sext(width);
+    val2 = val2.sext(width);
+    return val1.sge(val2);
+  };
+  auto sle = [](APInt val1, APInt val2) -> bool {
+    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+    val1 = val1.sext(width);
+    val2 = val2.sext(width);
+    return val1.sle(val2);
+  };
+  auto uge = [](APInt val1, APInt val2) -> bool {
+    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+    val1 = val1.zext(width);
+    val2 = val2.zext(width);
+    return val1.uge(val2);
+  };
+  auto ule = [](APInt val1, APInt val2) -> bool {
+    unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+    val1 = val1.zext(width);
+    val2 = val2.zext(width);
+    return val1.ule(val2);
+  };
+  return sge(range.smin(), smin) && sle(range.smax(), smax) &&
+         uge(range.umin(), umin) && ule(range.umax(), umax);
+}
+
+static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
+  Type srcType = src.getType();
+  assert(srcType.isIntOrIndex() && "Invalid src type");
+  assert(dstType.isIntOrIndex() && "Invalid dst type");
+  if (srcType == dstType)
+    return src;
+
+  if (isa<IndexType>(srcType) || isa<IndexType>(dstType))
+    return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
+
+  auto srcInt = cast<IntegerType>(srcType);
+  auto dstInt = cast<IntegerType>(dstType);
+  if (dstInt.getWidth() < srcInt.getWidth()) {
+    return builder.create<arith::TruncIOp>(loc, dstType, src);
+  } else {
+    return builder.create<arith::ExtUIOp>(loc, dstType, src);
+  }
+}
+
+struct NarrowElementwise final
+    : public OpTraitRewritePattern<OpTrait::Elementwise> {
+  NarrowElementwise(MLIRContext *context, DataFlowSolver &s, unsigned target)
+      : OpTraitRewritePattern<OpTrait::Elementwise>(context), solver(s),
+        targetBitwidth(target) {}
+
+  using OpTraitRewritePattern::OpTraitRewritePattern;
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    Type srcType = checkElementwiseOpType(op, targetBitwidth);
+    if (!srcType)
+      return failure();
+
+    std::optional<ConstantIntRanges> range =
+        getOperandsRange(solver, op->getResults());
+    if (!range)
+      return failure();
+
+    // We are truncating op args to the desired bitwidth before the op and then
+    // extending op results back to the original width after.
+    // extui and extsi will produce different results for negative values, so
+    // limit signed range to non-negative values.
+    auto smin = APInt::getZero(targetBitwidth);
+    auto smax = APInt::getSignedMaxValue(targetBitwidth);
+    auto umin = APInt::getMinValue(targetBitwidth);
+    auto umax = APInt::getMaxValue(targetBitwidth);
+    if (!checkRange(*range, smin, smax, umin, umax))
+      return failure();
+
+    Type targetType = getTargetType(srcType, targetBitwidth);
+    if (targetType == srcType)
+      return failure();
+
+    Location loc = op->getLoc();
+    IRMapping mapping;
+    for (Value arg : op->getOperands()) {
+      Value newArg = doCast(rewriter, loc, arg, targetType);
+      mapping.map(arg, newArg);
+    }
+
+    Operation *newOp = rewriter.clone(*op, mapping);
+    rewriter.modifyOpInPlace(newOp, [&]() {
+      for (OpResult res : newOp->getResults()) {
+        res.setType(targetType);
+      }
+    });
+    SmallVector<Value> newResults;
+    for (Value res : newOp->getResults())
+      newResults.emplace_back(doCast(rewriter, loc, res, srcType));
+
+    rewriter.replaceOp(op, newResults);
+    return success();
+  }
+
+private:
+  DataFlowSolver &solver;
+  unsigned targetBitwidth;
+};
+
+struct NarrowCmpi final : public OpRewritePattern<arith::CmpIOp> {
+  NarrowCmpi(MLIRContext *context, PatternBenefit benefit, DataFlowSolver &s,
+             unsigned target)
+      : OpRewritePattern(context, benefit), solver(s), targetBitwidth(target) {}
+
+  LogicalResult matchAndRewrite(arith::CmpIOp op,
+                                PatternRewriter &rewriter) const override {
+    Value lhs = op.getLhs();
+    Value rhs = op.getRhs();
+
+    Type srcType = checkArithType(lhs.getType(), targetBitwidth);
+    if (!srcType)
+      return failure();
+
+    std::optional<ConstantIntRanges> range =
+        getOperandsRange(solver, {lhs, rhs});
+    if (!range)
+      return failure();
+
+    auto smin = APInt::getSignedMinValue(targetBitwidth);
+    auto smax = APInt::getSignedMaxValue(targetBitwidth);
+    auto umin = APInt::getMinValue(targetBitwidth);
+    auto umax = APInt::getMaxValue(targetBitwidth);
+    if (!checkRange(*range, smin, smax, umin, umax))
+      return failure();
+
+    Type targetType = getTargetType(srcType, targetBitwidth);
+    if (targetType == srcType)
+      return failure();
+
+    Location loc = op->getLoc();
+    IRMapping mapping;
+    for (Value arg : op->getOperands()) {
+      Value newArg = doCast(rewriter, loc, arg, targetType);
+      mapping.map(arg, newArg);
+    }
+
+    Operation *newOp = rewriter.clone(*op, mapping);
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+
+private:
+  DataFlowSolver &solver;
+  unsigned targetBitwidth;
+};
+
 struct IntRangeOptimizationsPass
     : public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
 
@@ -208,6 +430,32 @@ struct IntRangeOptimizationsPass
       signalPassFailure();
   }
 };
+
+struct IntRangeNarrowingPass
+    : public arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
+  using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+    DataFlowSolver solver;
+    solver.load<DeadCodeAnalysis>();
+    solver.load<IntegerRangeAnalysis>();
+    if (failed(solver.initializeAndRun(op)))
+      return signalPassFailure();
+
+    DataFlowListener listener(solver);
+
+    RewritePatternSet patterns(ctx);
+    populateIntRangeNarrowingPatterns(patterns, solver, this->targetBitwidth);
+
+    GreedyRewriteConfig config;
+    config.listener = &listener;
+
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+      signalPassFailure();
+  }
+};
 } // namespace
 
 void mlir::arith::populateIntRangeOptimizationsPatterns(
@@ -216,6 +464,18 @@ void mlir::arith::populateIntRangeOptimizationsPatterns(
                DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
 }
 
+void mlir::arith::populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
+                                                    DataFlowSolver &solver,
+                                                    unsigned targetBitwidth) {
+  // Cmpi uses argument ranges instead of results; run it with a higher
+  // benefit, as its arguments can potentially be replaced.
+  patterns.add<NarrowCmpi>(patterns.getContext(), /*benefit*/ 10, solver,
+                           targetBitwidth);
+
+  patterns.add<NarrowElementwise>(patterns.getContext(), solver,
+                                  targetBitwidth);
+}
+
 std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
   return std::make_unique<IntRangeOptimizationsPass>();
 }
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
new file mode 100644
index 00000000000000..d85cb3384061be
--- /dev/null
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt --arith-int-range-narrowing="target-bitwidth=32"  %s | FileCheck %s
+
+// Do not truncate negative values
+// CHECK-LABEL: func @test_addi_neg
+//       CHECK:  %[[RES:.*]] = arith.addi %{{.*}}, %{{.*}} : index
+//       CHECK:  return %[[RES]] : index
+func.func @test_addi_neg() -> index {
+  %0 = test.with_bounds { umin = 0 : index, umax = 1 : index, smin = 0 : index, smax = 1 : index } : index
+  %1 = test.with_bounds { umin = 0 : index, umax = -1 : index, smin = -1 : index, smax = 0 : index } : index
+  %2 = arith.addi %0, %1 : index
+  return %2 : index
+}
+
+// CHECK-LABEL: func @test_addi
+//       CHECK:  %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : index
+//       CHECK:  %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : index
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i32
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i32
+//       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i32
+//       CHECK:  %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : i32 to index
+//       CHECK:  return %[[RES_CASTED]] : index
+func.func @test_addi() -> index {
+  %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
+  %1 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : index
+  %2 = arith.addi %0, %1 : index
+  return %2 : index
+}
+
+
+// CHECK-LABEL: func @test_addi_i64
+//       CHECK:  %[[A:.*]] = test.with_bounds {smax = 5 : i64, smin = 4 : i64, umax = 5 : i64, umin = 4 : i64} : i64
+//       CHECK:  %[[B:.*]] = test.with_bounds {smax = 7 : i64, smin = 6 : i64, umax = 7 : i64, umin = 6 : i64} : i64
+//       CHECK:  %[[A_CASTED:.*]] = arith.trunci %[[A]] : i64 to i32
+//       CHECK:  %[[B_CASTED:.*]] = arith.trunci %[[B]] : i64 to i32
+//       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i32
+//       CHECK:  %[[RES_CASTED:.*]] = arith.extui %[[RES]] : i32 to i64
+//       CHECK:  return %[[RES_CASTED]] : i64
+func.func @test_addi_i64() -> i64 {
+  %0 = test.with_bounds { umin = 4 : i64, umax = 5 : i64, smin = 4 : i64, smax = 5 : i64 } : i64
+  %1 = test.with_bounds { umin = 6 : i64, umax = 7 : i64, smin = 6 : i64, smax = 7 : i64 } : i64
+  %2 = arith.addi %0, %1 : i64
+  return %2 : i64
+}
+
+// CHECK-LABEL: func @test_cmpi
+//       CHECK:  %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
+//       CHECK:  %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i32
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i32
+//       CHECK:  %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : i32
+//       CHECK:  return %[[RES]] : i1
+func.func @test_cmpi() -> i1 {
+  %0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index
+  %1 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index
+  %2 = arith.cmpi slt, %0, %1 : index
+  return %2 : i1
+}

>From 696b8a5fc822e9beeee4e7b4e18ac12a27f10c06 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 15 Oct 2024 18:59:07 +0200
Subject: [PATCH 2/5] typo

---
 mlir/include/mlir/Dialect/Arith/Transforms/Passes.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index beeefea497d9b2..ce8e819708b190 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -73,7 +73,7 @@ void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
 /// Create a pass which do optimizations based on integer range analysis.
 std::unique_ptr<Pass> createIntRangeOptimizationsPass();
 
-/// Add patterns for int range based norrowing.
+/// Add patterns for int range based narrowing.
 void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
                                        DataFlowSolver &solver,
                                        unsigned targetBitwidth);

>From 106beda0062847cc8ae6d624b825081e2702b032 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 16 Oct 2024 13:05:29 +0200
Subject: [PATCH 3/5] use list of bitwidths instead of 1
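
Switch the pass from a single `target-bitwidth` option to a `ListOption` of
supported integer bitwidths, so the patterns try each listed width in order
and use the first one that fits the inferred range. Usage sketch (mirroring
the updated RUN line in the test; `input.mlir` is a placeholder input file):

  mlir-opt --arith-int-range-narrowing="int-bitwidths-supported=32" input.mlir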

---
 .../mlir/Dialect/Arith/Transforms/Passes.h    |   2 +-
 .../mlir/Dialect/Arith/Transforms/Passes.td   |   4 +-
 .../Transforms/IntRangeOptimizations.cpp      | 141 ++++++++++--------
 .../Dialect/Arith/int-range-narrowing.mlir    |   2 +-
 4 files changed, 79 insertions(+), 70 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index ce8e819708b190..4d386827571dfa 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -76,7 +76,7 @@ std::unique_ptr<Pass> createIntRangeOptimizationsPass();
 /// Add patterns for int range based narrowing.
 void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
                                        DataFlowSolver &solver,
-                                       unsigned targetBitwidth);
+                                       ArrayRef<unsigned> bitwidthsSupported);
 
 // TODO: merge these two narrowing passes.
 /// Add patterns for integer bitwidth narrowing.
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 8c565f6489638e..898d74249af618 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -58,8 +58,8 @@ def ArithIntRangeNarrowing : Pass<"arith-int-range-narrowing"> {
   }];
 
   let options = [
-    Option<"targetBitwidth", "target-bitwidth", "unsigned",
-           /*default=*/"32", "Target bitwidth this pass will try to narrow to">,
+    ListOption<"bitwidthsSupported", "int-bitwidths-supported", "unsigned",
+               "Integer bitwidths supported">,
   ];
 
   // Explicitly depend on "arith" because this pass could create operations in
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 4e7b44c939e796..861602f2032172 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -302,108 +302,117 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
 
 struct NarrowElementwise final
     : public OpTraitRewritePattern<OpTrait::Elementwise> {
-  NarrowElementwise(MLIRContext *context, DataFlowSolver &s, unsigned target)
+  NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
+                    ArrayRef<unsigned> target)
       : OpTraitRewritePattern<OpTrait::Elementwise>(context), solver(s),
-        targetBitwidth(target) {}
+        targetBitwidths(target) {}
 
   using OpTraitRewritePattern::OpTraitRewritePattern;
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    Type srcType = checkElementwiseOpType(op, targetBitwidth);
-    if (!srcType)
-      return failure();
 
     std::optional<ConstantIntRanges> range =
         getOperandsRange(solver, op->getResults());
     if (!range)
       return failure();
 
-    // We are truncating op args to the desired bitwidth before the op and then
-    // extending op results back to the original width after.
-    // extui and extsi will produce different results for negative values, so
-    // limit signed range to non-negative values.
-    auto smin = APInt::getZero(targetBitwidth);
-    auto smax = APInt::getSignedMaxValue(targetBitwidth);
-    auto umin = APInt::getMinValue(targetBitwidth);
-    auto umax = APInt::getMaxValue(targetBitwidth);
-    if (!checkRange(*range, smin, smax, umin, umax))
-      return failure();
+    for (unsigned targetBitwidth : targetBitwidths) {
+      Type srcType = checkElementwiseOpType(op, targetBitwidth);
+      if (!srcType)
+        continue;
 
-    Type targetType = getTargetType(srcType, targetBitwidth);
-    if (targetType == srcType)
-      return failure();
+      // We are truncating op args to the desired bitwidth before the op and
+      // then extending op results back to the original width after. extui and
+      // extsi will produce different results for negative values, so limit
+      // signed range to non-negative values.
+      auto smin = APInt::getZero(targetBitwidth);
+      auto smax = APInt::getSignedMaxValue(targetBitwidth);
+      auto umin = APInt::getMinValue(targetBitwidth);
+      auto umax = APInt::getMaxValue(targetBitwidth);
+      if (!checkRange(*range, smin, smax, umin, umax))
+        continue;
 
-    Location loc = op->getLoc();
-    IRMapping mapping;
-    for (Value arg : op->getOperands()) {
-      Value newArg = doCast(rewriter, loc, arg, targetType);
-      mapping.map(arg, newArg);
-    }
+      Type targetType = getTargetType(srcType, targetBitwidth);
+      if (targetType == srcType)
+        continue;
 
-    Operation *newOp = rewriter.clone(*op, mapping);
-    rewriter.modifyOpInPlace(newOp, [&]() {
-      for (OpResult res : newOp->getResults()) {
-        res.setType(targetType);
+      Location loc = op->getLoc();
+      IRMapping mapping;
+      for (Value arg : op->getOperands()) {
+        Value newArg = doCast(rewriter, loc, arg, targetType);
+        mapping.map(arg, newArg);
       }
-    });
-    SmallVector<Value> newResults;
-    for (Value res : newOp->getResults())
-      newResults.emplace_back(doCast(rewriter, loc, res, srcType));
 
-    rewriter.replaceOp(op, newResults);
-    return success();
+      Operation *newOp = rewriter.clone(*op, mapping);
+      rewriter.modifyOpInPlace(newOp, [&]() {
+        for (OpResult res : newOp->getResults()) {
+          res.setType(targetType);
+        }
+      });
+      SmallVector<Value> newResults;
+      for (Value res : newOp->getResults())
+        newResults.emplace_back(doCast(rewriter, loc, res, srcType));
+
+      rewriter.replaceOp(op, newResults);
+      return success();
+    }
+    return failure();
   }
 
 private:
   DataFlowSolver &solver;
-  unsigned targetBitwidth;
+  SmallVector<unsigned, 4> targetBitwidths;
 };
 
 struct NarrowCmpi final : public OpRewritePattern<arith::CmpIOp> {
   NarrowCmpi(MLIRContext *context, PatternBenefit benefit, DataFlowSolver &s,
-             unsigned target)
-      : OpRewritePattern(context, benefit), solver(s), targetBitwidth(target) {}
+             ArrayRef<unsigned> target)
+      : OpRewritePattern(context, benefit), solver(s), targetBitwidths(target) {
+  }
 
   LogicalResult matchAndRewrite(arith::CmpIOp op,
                                 PatternRewriter &rewriter) const override {
     Value lhs = op.getLhs();
     Value rhs = op.getRhs();
 
-    Type srcType = checkArithType(lhs.getType(), targetBitwidth);
-    if (!srcType)
-      return failure();
-
     std::optional<ConstantIntRanges> range =
         getOperandsRange(solver, {lhs, rhs});
     if (!range)
       return failure();
 
-    auto smin = APInt::getSignedMinValue(targetBitwidth);
-    auto smax = APInt::getSignedMaxValue(targetBitwidth);
-    auto umin = APInt::getMinValue(targetBitwidth);
-    auto umax = APInt::getMaxValue(targetBitwidth);
-    if (!checkRange(*range, smin, smax, umin, umax))
-      return failure();
+    for (unsigned targetBitwidth : targetBitwidths) {
+      Type srcType = checkArithType(lhs.getType(), targetBitwidth);
+      if (!srcType)
+        continue;
 
-    Type targetType = getTargetType(srcType, targetBitwidth);
-    if (targetType == srcType)
-      return failure();
+      auto smin = APInt::getSignedMinValue(targetBitwidth);
+      auto smax = APInt::getSignedMaxValue(targetBitwidth);
+      auto umin = APInt::getMinValue(targetBitwidth);
+      auto umax = APInt::getMaxValue(targetBitwidth);
+      if (!checkRange(*range, smin, smax, umin, umax))
+        continue;
 
-    Location loc = op->getLoc();
-    IRMapping mapping;
-    for (Value arg : op->getOperands()) {
-      Value newArg = doCast(rewriter, loc, arg, targetType);
-      mapping.map(arg, newArg);
-    }
+      Type targetType = getTargetType(srcType, targetBitwidth);
+      if (targetType == srcType)
+        continue;
 
-    Operation *newOp = rewriter.clone(*op, mapping);
-    rewriter.replaceOp(op, newOp->getResults());
-    return success();
+      Location loc = op->getLoc();
+      IRMapping mapping;
+      for (Value arg : op->getOperands()) {
+        Value newArg = doCast(rewriter, loc, arg, targetType);
+        mapping.map(arg, newArg);
+      }
+
+      Operation *newOp = rewriter.clone(*op, mapping);
+      rewriter.replaceOp(op, newOp->getResults());
+      return success();
+    }
+    return failure();
   }
 
 private:
   DataFlowSolver &solver;
-  unsigned targetBitwidth;
+  SmallVector<unsigned, 4> targetBitwidths;
 };
 
 struct IntRangeOptimizationsPass
@@ -447,7 +456,7 @@ struct IntRangeNarrowingPass
     DataFlowListener listener(solver);
 
     RewritePatternSet patterns(ctx);
-    populateIntRangeNarrowingPatterns(patterns, solver, this->targetBitwidth);
+    populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
 
     GreedyRewriteConfig config;
     config.listener = &listener;
@@ -464,16 +473,16 @@ void mlir::arith::populateIntRangeOptimizationsPatterns(
                DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
 }
 
-void mlir::arith::populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
-                                                    DataFlowSolver &solver,
-                                                    unsigned targetBitwidth) {
+void mlir::arith::populateIntRangeNarrowingPatterns(
+    RewritePatternSet &patterns, DataFlowSolver &solver,
+    ArrayRef<unsigned> bitwidthsSupported) {
   // Cmpi uses argument ranges instead of results; run it with a higher
   // benefit, as its arguments can potentially be replaced.
   patterns.add<NarrowCmpi>(patterns.getContext(), /*benefit*/ 10, solver,
-                           targetBitwidth);
+                           bitwidthsSupported);
 
   patterns.add<NarrowElementwise>(patterns.getContext(), solver,
-                                  targetBitwidth);
+                                  bitwidthsSupported);
 }
 
 std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index d85cb3384061be..6283284567dce8 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --arith-int-range-narrowing="target-bitwidth=32"  %s | FileCheck %s
+// RUN: mlir-opt --arith-int-range-narrowing="int-bitwidths-supported=32" %s | FileCheck %s
 
 // Do not truncate negative values
 // CHECK-LABEL: func @test_addi_neg

>From b7053b25a2840cbff6878b7c252c97ff02497591 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 16 Oct 2024 13:12:00 +0200
Subject: [PATCH 4/5] update tests

---
 .../Dialect/Arith/int-range-narrowing.mlir    | 24 +++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 6283284567dce8..8b73f02fd214a2 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --arith-int-range-narrowing="int-bitwidths-supported=32" %s | FileCheck %s
+// RUN: mlir-opt --arith-int-range-narrowing="int-bitwidths-supported=1,8,16,24,32" %s | FileCheck %s
 
 // Do not truncate negative values
 // CHECK-LABEL: func @test_addi_neg
@@ -14,10 +14,10 @@ func.func @test_addi_neg() -> index {
 // CHECK-LABEL: func @test_addi
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : index
 //       CHECK:  %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : index
-//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i32
-//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i32
-//       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i32
-//       CHECK:  %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : i32 to index
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i8
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i8
+//       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
+//       CHECK:  %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : i8 to index
 //       CHECK:  return %[[RES_CASTED]] : index
 func.func @test_addi() -> index {
   %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
@@ -30,10 +30,10 @@ func.func @test_addi() -> index {
 // CHECK-LABEL: func @test_addi_i64
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 5 : i64, smin = 4 : i64, umax = 5 : i64, umin = 4 : i64} : i64
 //       CHECK:  %[[B:.*]] = test.with_bounds {smax = 7 : i64, smin = 6 : i64, umax = 7 : i64, umin = 6 : i64} : i64
-//       CHECK:  %[[A_CASTED:.*]] = arith.trunci %[[A]] : i64 to i32
-//       CHECK:  %[[B_CASTED:.*]] = arith.trunci %[[B]] : i64 to i32
-//       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i32
-//       CHECK:  %[[RES_CASTED:.*]] = arith.extui %[[RES]] : i32 to i64
+//       CHECK:  %[[A_CASTED:.*]] = arith.trunci %[[A]] : i64 to i8
+//       CHECK:  %[[B_CASTED:.*]] = arith.trunci %[[B]] : i64 to i8
+//       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i8
+//       CHECK:  %[[RES_CASTED:.*]] = arith.extui %[[RES]] : i8 to i64
 //       CHECK:  return %[[RES_CASTED]] : i64
 func.func @test_addi_i64() -> i64 {
   %0 = test.with_bounds { umin = 4 : i64, umax = 5 : i64, smin = 4 : i64, smax = 5 : i64 } : i64
@@ -45,9 +45,9 @@ func.func @test_addi_i64() -> i64 {
 // CHECK-LABEL: func @test_cmpi
 //       CHECK:  %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
 //       CHECK:  %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
-//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i32
-//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i32
-//       CHECK:  %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : i32
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i8
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i8
+//       CHECK:  %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : i8
 //       CHECK:  return %[[RES]] : i1
 func.func @test_cmpi() -> i1 {
   %0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index

>From b5f6014f35be7f17bcf9688d46ab209ad288529b Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 16 Oct 2024 14:13:06 +0200
Subject: [PATCH 5/5] more tests

---
 .../Dialect/Arith/int-range-narrowing.mlir    | 142 ++++++++++++++++++
 1 file changed, 142 insertions(+)

diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 8b73f02fd214a2..cd0a4c449913e1 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt --arith-int-range-narrowing="int-bitwidths-supported=1,8,16,24,32" %s | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// Some basic tests
+//===----------------------------------------------------------------------===//
+
 // Do not truncate negative values
 // CHECK-LABEL: func @test_addi_neg
 //       CHECK:  %[[RES:.*]] = arith.addi %{{.*}}, %{{.*}} : index
@@ -55,3 +59,141 @@ func.func @test_cmpi() -> i1 {
   %2 = arith.cmpi slt, %0, %1 : index
   return %2 : i1
 }
+
+//===----------------------------------------------------------------------===//
+// arith.addi
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @addi_extui_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT:    %[[ADD:.+]]  = arith.addi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[ADD]] : i16 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
+func.func @addi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extui %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.addi %a, %b : i32
+  return %r : i32
+}
+
+// This case should not get optimized because of mixed extensions.
+//
+// CHECK-LABEL: func.func @addi_mixed_ext_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[ADD:.+]]  = arith.addi %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT:    return %[[ADD]] : i32
+func.func @addi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extsi %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.addi %a, %b : i32
+  return %r : i32
+}
+
+// This case should not get optimized because we cannot reduce the bitwidth
+// below i16, given the pass options set.
+//
+// CHECK-LABEL: func.func @addi_extsi_i16
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i16
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i16
+// CHECK-NEXT:    %[[ADD:.+]]  = arith.addi %[[EXT0]], %[[EXT1]] : i16
+// CHECK-NEXT:    return %[[ADD]] : i16
+func.func @addi_extsi_i16(%lhs: i8, %rhs: i8) -> i16 {
+  %a = arith.extsi %lhs : i8 to i16
+  %b = arith.extsi %rhs : i8 to i16
+  %r = arith.addi %a, %b : i16
+  return %r : i16
+}
+
+//===----------------------------------------------------------------------===//
+// arith.subi
+//===----------------------------------------------------------------------===//
+
+// This pattern should only apply to `arith.subi` ops with sign-extended
+// arguments.
+//
+// CHECK-LABEL: func.func @subi_extui_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[SUB:.+]]  = arith.subi %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT:    return %[[SUB]] : i32
+func.func @subi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extui %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.subi %a, %b : i32
+  return %r : i32
+}
+
+// This case should not get optimized because of mixed extensions.
+//
+// CHECK-LABEL: func.func @subi_mixed_ext_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[ADD:.+]]  = arith.subi %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT:    return %[[ADD]] : i32
+func.func @subi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extsi %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.subi %a, %b : i32
+  return %r : i32
+}
+
+//===----------------------------------------------------------------------===//
+// arith.muli
+//===----------------------------------------------------------------------===//
+
+// TODO: This should be optimized into i16
+// CHECK-LABEL: func.func @muli_extui_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i24
+// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i24
+// CHECK-NEXT:    %[[MUL:.+]]  = arith.muli %[[LHS]], %[[RHS]] : i24
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[MUL]] : i24 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
+func.func @muli_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extui %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.muli %a, %b : i32
+  return %r : i32
+}
+
+// We do not expect this case to be optimized because, given n-bit operands,
+// arith.muli may produce up to 2n bits of result.
+//
+// CHECK-LABEL: func.func @muli_extsi_i32
+// CHECK-SAME:    (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16)
+// CHECK-NEXT:    %[[LHS:.+]]  = arith.extsi %[[ARG0]] : i16 to i32
+// CHECK-NEXT:    %[[RHS:.+]]  = arith.extsi %[[ARG1]] : i16 to i32
+// CHECK-NEXT:    %[[RET:.+]]  = arith.muli %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT:    return %[[RET]] : i32
+func.func @muli_extsi_i32(%lhs: i16, %rhs: i16) -> i32 {
+  %a = arith.extsi %lhs : i16 to i32
+  %b = arith.extsi %rhs : i16 to i32
+  %r = arith.muli %a, %b : i32
+  return %r : i32
+}
+
+// This case should not get optimized because of mixed extensions.
+//
+// CHECK-LABEL: func.func @muli_mixed_ext_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[MUL:.+]]  = arith.muli %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT:    return %[[MUL]] : i32
+func.func @muli_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extsi %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.muli %a, %b : i32
+  return %r : i32
+}


