[Mlir-commits] [mlir] [mlir] Add `arith-int-range-narrowing` pass (PR #112404)

Ivan Butygin llvmlistbot at llvm.org
Tue Oct 15 10:00:32 PDT 2024


https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/112404

This pass is intended to narrow integer calculations to a specific bitwidth, using `IntegerRangeAnalysis`.
We already have the `arith-int-narrowing` pass, but it performs mostly local analysis, while `IntegerRangeAnalysis` analyzes the entire program. Ideally they should be unified, but that is a task for the future.

>From 584e8e5fc5ade2df915f02d85ff119b676b77d8b Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 15 Oct 2024 16:32:32 +0200
Subject: [PATCH 1/2] [mlir] Add `arith-int-range-narrowing` pass

This pass is intended to narrow integer calculations to a specific bitwidth, using `IntegerRangeAnalysis`.
We already have the `arith-int-narrowing` pass, but it performs mostly local analysis, while `IntegerRangeAnalysis` analyzes the entire program.
Ideally they should be unified, but that is a task for the future.
---
 .../mlir/Dialect/Arith/Transforms/Passes.h    |   6 +
 .../mlir/Dialect/Arith/Transforms/Passes.td   |  19 ++
 .../Transforms/IntRangeOptimizations.cpp      | 260 ++++++++++++++++++
 .../Dialect/Arith/int-range-narrowing.mlir    |  57 ++++
 4 files changed, 342 insertions(+)
 create mode 100644 mlir/test/Dialect/Arith/int-range-narrowing.mlir

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index aee64475171a43..beeefea497d9b2 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -73,6 +73,12 @@ void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
 /// Create a pass which do optimizations based on integer range analysis.
 std::unique_ptr<Pass> createIntRangeOptimizationsPass();
 
+/// Add patterns for int range based norrowing.
+void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
+                                       DataFlowSolver &solver,
+                                       unsigned targetBitwidth);
+
+// TODO: merge these two narrowing passes.
 /// Add patterns for integer bitwidth narrowing.
 void populateArithIntNarrowingPatterns(RewritePatternSet &patterns,
                                        const ArithIntNarrowingOptions &options);
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index 1517f71f1a7c90..8c565f6489638e 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -50,6 +50,25 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
   ];
 }
 
+def ArithIntRangeNarrowing : Pass<"arith-int-range-narrowing"> {
+  let summary = "Reduce integer operations bitwidth based on integer range analysis";
+  let description = [{
+    This pass runs integer range analysis and tries to narrow arith ops to the
+    specified bitwidth based on its results.
+  }];
+
+  let options = [
+    Option<"targetBitwidth", "target-bitwidth", "unsigned",
+           /*default=*/"32", "Target bitwidth this pass will try to narrow to">,
+  ];
+
+  // Explicitly depend on "arith": the pass creates `arith` cast ops (e.g.
+  // `arith.trunci`, `arith.extui`) even for inputs without `arith` ops.
+  let dependentDialects = [
+    "::mlir::arith::ArithDialect"
+  ];
+}
+
 def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {
   let summary = "Emulate operations on unsupported floats with extf/truncf";
   let description = [{
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 521138c1f6f4cd..4e7b44c939e796 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -15,8 +15,10 @@
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -24,6 +26,9 @@
 namespace mlir::arith {
 #define GEN_PASS_DEF_ARITHINTRANGEOPTS
 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
+
+#define GEN_PASS_DEF_ARITHINTRANGENARROWING
+#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
 } // namespace mlir::arith
 
 using namespace mlir;
@@ -184,6 +189,223 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
   DataFlowSolver &solver;
 };
 
+/// Returns `type` if its element type is `index` or a signless integer wider
+/// than `targetBitwidth`, null otherwise (also for non-vector/tensor shapes).
+static Type checkArithType(Type type, unsigned targetBitwidth) {
+  if (isa<ShapedType>(type) && !isa<VectorType, TensorType>(type))
+    return nullptr;
+
+  Type elemType = getElementTypeOrSelf(type);
+  if (isa<IndexType>(elemType))
+    return type;
+
+  if (auto intType = dyn_cast<IntegerType>(elemType))
+    if (intType.isSignless() && intType.getWidth() > targetBitwidth)
+      return type;
+
+  return nullptr;
+}
+
+/// Returns the type shared by all operands and results of `op` if it is a
+/// narrowing candidate (see `checkArithType`), null otherwise.
+static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
+  if (op->getNumOperands() == 0 || op->getNumResults() == 0)
+    return nullptr;
+
+  Type type = op->getResult(0).getType();
+  auto hasSameType = [&](Type other) { return other == type; };
+  if (!llvm::all_of(op->getOperandTypes(), hasSameType) ||
+      !llvm::all_of(op->getResultTypes(), hasSameType))
+    return nullptr;
+
+  return checkArithType(type, targetBitwidth);
+}
+
+/// Returns the union of the ranges inferred by `solver` for `values`, or
+/// std::nullopt if `values` is empty or any value has no known range.
+static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
+                                                         ValueRange values) {
+  std::optional<ConstantIntRanges> ret;
+  for (Value value : values) {
+    auto *maybeInferredRange =
+        solver.lookupState<IntegerValueRangeLattice>(value);
+    if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
+      return std::nullopt;
+
+    const ConstantIntRanges &inferredRange =
+        maybeInferredRange->getValue().getValue();
+    ret = ret ? ret->rangeUnion(inferredRange) : inferredRange;
+  }
+  return ret;
+}
+
+/// Returns `srcType` with its (element) type replaced by `i<targetBitwidth>`.
+static Type getTargetType(Type srcType, unsigned targetBitwidth) {
+  auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth);
+  if (auto shaped = dyn_cast<ShapedType>(srcType))
+    return shaped.clone(dstType);
+
+  assert(srcType.isIntOrIndex() && "Invalid src type");
+  return dstType;
+}
+
+/// Returns true if `range` lies in [smin, smax] as signed and [umin, umax] as
+/// unsigned values, extending both sides of each comparison to a common width.
+static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax,
+                       APInt umin, APInt umax) {
+  auto widen = [](APInt &lhs, APInt &rhs, bool isSigned) {
+    unsigned width = std::max(lhs.getBitWidth(), rhs.getBitWidth());
+    lhs = isSigned ? lhs.sext(width) : lhs.zext(width);
+    rhs = isSigned ? rhs.sext(width) : rhs.zext(width);
+  };
+  APInt rangeSMin = range.smin(), rangeSMax = range.smax();
+  APInt rangeUMin = range.umin(), rangeUMax = range.umax();
+  widen(rangeSMin, smin, /*isSigned=*/true);
+  widen(rangeSMax, smax, /*isSigned=*/true);
+  widen(rangeUMin, umin, /*isSigned=*/false);
+  widen(rangeUMax, umax, /*isSigned=*/false);
+  return rangeSMin.sge(smin) && rangeSMax.sle(smax) && rangeUMin.uge(umin) &&
+         rangeUMax.ule(umax);
+}
+
+/// Casts `src` to `dstType` (same shape, integer or index elements). Widening
+/// zero-extends, so it only preserves values non-negative in the narrow type.
+static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
+  Type srcType = src.getType();
+  assert(isa<ShapedType>(srcType) == isa<ShapedType>(dstType) &&
+         "Mixing shaped and non-shaped types");
+  Type srcElemType = getElementTypeOrSelf(srcType);
+  Type dstElemType = getElementTypeOrSelf(dstType);
+  assert(srcElemType.isIntOrIndex() && "Invalid src type");
+  assert(dstElemType.isIntOrIndex() && "Invalid dst type");
+  if (srcType == dstType)
+    return src;
+
+  if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType))
+    return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
+
+  auto srcInt = cast<IntegerType>(srcElemType);
+  auto dstInt = cast<IntegerType>(dstElemType);
+  if (dstInt.getWidth() < srcInt.getWidth())
+    return builder.create<arith::TruncIOp>(loc, dstType, src);
+
+  return builder.create<arith::ExtUIOp>(loc, dstType, src);
+}
+
+/// Narrows an elementwise op whose operands and results share one `index` or
+/// wide integer (vector/tensor) type by truncating its operands, cloning it
+/// with narrow results and zero-extending those back. It is sound only if all
+/// operands *and* results are in [0, 2^(targetBitwidth-1) - 1]; checking
+/// results alone is not enough: `arith.minui %a, %b` with %a = 2^32, %b = 5
+/// yields 5, but the truncated %a would be 0. This pattern records no ranges
+/// for values it creates; ops using them are skipped unless `solver` has one.
+struct NarrowElementwise final
+    : public OpTraitRewritePattern<OpTrait::Elementwise> {
+  NarrowElementwise(MLIRContext *context, DataFlowSolver &s, unsigned target)
+      : OpTraitRewritePattern<OpTrait::Elementwise>(context), solver(s),
+        targetBitwidth(target) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    Type srcType = checkElementwiseOpType(op, targetBitwidth);
+    if (!srcType)
+      return failure();
+
+    std::optional<ConstantIntRanges> operandsRange =
+        getOperandsRange(solver, op->getOperands());
+    std::optional<ConstantIntRanges> resultsRange =
+        getOperandsRange(solver, op->getResults());
+    if (!operandsRange || !resultsRange)
+      return failure();
+
+    // Operands are truncated and results zero-extended, so both must fit; as
+    // extui and extsi differ on negative values, require non-negative values.
+    auto smin = APInt::getZero(targetBitwidth);
+    auto smax = APInt::getSignedMaxValue(targetBitwidth);
+    auto umin = APInt::getMinValue(targetBitwidth);
+    auto umax = APInt::getMaxValue(targetBitwidth);
+    if (!checkRange(*operandsRange, smin, smax, umin, umax) ||
+        !checkRange(*resultsRange, smin, smax, umin, umax))
+      return failure();
+
+    Type targetType = getTargetType(srcType, targetBitwidth);
+    if (targetType == srcType)
+      return failure();
+
+    // Cast each distinct operand only once (e.g. for `arith.addi %x, %x`).
+    Location loc = op->getLoc();
+    IRMapping mapping;
+    for (Value arg : op->getOperands())
+      if (!mapping.contains(arg))
+        mapping.map(arg, doCast(rewriter, loc, arg, targetType));
+
+    Operation *newOp = rewriter.clone(*op, mapping);
+    rewriter.modifyOpInPlace(newOp, [&]() {
+      for (OpResult res : newOp->getResults())
+        res.setType(targetType);
+    });
+    SmallVector<Value> newResults;
+    for (Value res : newOp->getResults())
+      newResults.emplace_back(doCast(rewriter, loc, res, srcType));
+
+    rewriter.replaceOp(op, newResults);
+    return success();
+  }
+
+private:
+  DataFlowSolver &solver;
+  unsigned targetBitwidth;
+};
+
+/// Narrows `arith.cmpi` by truncating both operands to `targetBitwidth` when
+/// their ranges fit. Requiring both signed and unsigned fit is conservative:
+/// for source types wider than the target, only non-negative values pass.
+struct NarrowCmpi final : public OpRewritePattern<arith::CmpIOp> {
+  NarrowCmpi(MLIRContext *context, PatternBenefit benefit, DataFlowSolver &s,
+             unsigned target)
+      : OpRewritePattern(context, benefit), solver(s), targetBitwidth(target) {}
+
+  LogicalResult matchAndRewrite(arith::CmpIOp op,
+                                PatternRewriter &rewriter) const override {
+    Value lhs = op.getLhs();
+    Value rhs = op.getRhs();
+
+    Type srcType = checkArithType(lhs.getType(), targetBitwidth);
+    if (!srcType)
+      return failure();
+
+    std::optional<ConstantIntRanges> range =
+        getOperandsRange(solver, {lhs, rhs});
+    if (!range)
+      return failure();
+
+    auto smin = APInt::getSignedMinValue(targetBitwidth);
+    auto smax = APInt::getSignedMaxValue(targetBitwidth);
+    auto umin = APInt::getMinValue(targetBitwidth);
+    auto umax = APInt::getMaxValue(targetBitwidth);
+    if (!checkRange(*range, smin, smax, umin, umax))
+      return failure();
+
+    Type targetType = getTargetType(srcType, targetBitwidth);
+    if (targetType == srcType)
+      return failure();
+
+    Location loc = op->getLoc();
+    IRMapping mapping;
+    for (Value arg : op->getOperands())
+      if (!mapping.contains(arg))
+        mapping.map(arg, doCast(rewriter, loc, arg, targetType));
+
+    Operation *newOp = rewriter.clone(*op, mapping);
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+
+private:
+  DataFlowSolver &solver;
+  unsigned targetBitwidth;
+};
+
 struct IntRangeOptimizationsPass
     : public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
 
@@ -208,6 +430,32 @@ struct IntRangeOptimizationsPass
       signalPassFailure();
   }
 };
+
+struct IntRangeNarrowingPass
+    : public arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
+  using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+    DataFlowSolver solver;
+    solver.load<DeadCodeAnalysis>();
+    solver.load<IntegerRangeAnalysis>();
+    if (failed(solver.initializeAndRun(op)))
+      return signalPassFailure();
+
+    DataFlowListener listener(solver);
+
+    RewritePatternSet patterns(ctx);
+    populateIntRangeNarrowingPatterns(patterns, solver, this->targetBitwidth);
+
+    GreedyRewriteConfig config;
+    config.listener = &listener;
+
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
+      signalPassFailure();
+  }
+};
 } // namespace
 
 void mlir::arith::populateIntRangeOptimizationsPatterns(
@@ -216,6 +464,18 @@ void mlir::arith::populateIntRangeOptimizationsPatterns(
                DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
 }
 
+void mlir::arith::populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
+                                                    DataFlowSolver &solver,
+                                                    unsigned targetBitwidth) {
+  // NarrowCmpi checks operand (not result) ranges. NOTE(review): benefits only
+  // order patterns matched on the same op; confirm this one has an effect.
+  patterns.add<NarrowCmpi>(patterns.getContext(), /*benefit*/ 10, solver,
+                           targetBitwidth);
+
+  patterns.add<NarrowElementwise>(patterns.getContext(), solver,
+                                  targetBitwidth);
+}
+
 std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
   return std::make_unique<IntRangeOptimizationsPass>();
 }
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
new file mode 100644
index 00000000000000..d85cb3384061be
--- /dev/null
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt --arith-int-range-narrowing="target-bitwidth=32"  %s | FileCheck %s
+
+// Do not truncate negative values
+// CHECK-LABEL: func @test_addi_neg
+//       CHECK:  %[[RES:.*]] = arith.addi %{{.*}}, %{{.*}} : index
+//       CHECK:  return %[[RES]] : index
+func.func @test_addi_neg() -> index {
+  %0 = test.with_bounds { umin = 0 : index, umax = 1 : index, smin = 0 : index, smax = 1 : index } : index
+  %1 = test.with_bounds { umin = 0 : index, umax = -1 : index, smin = -1 : index, smax = 0 : index } : index
+  %2 = arith.addi %0, %1 : index
+  return %2 : index
+}
+
+// CHECK-LABEL: func @test_addi
+//       CHECK:  %[[A:.*]] = test.with_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} : index
+//       CHECK:  %[[B:.*]] = test.with_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} : index
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i32
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i32
+//       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i32
+//       CHECK:  %[[RES_CASTED:.*]] = arith.index_castui %[[RES]] : i32 to index
+//       CHECK:  return %[[RES_CASTED]] : index
+func.func @test_addi() -> index {
+  %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
+  %1 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : index
+  %2 = arith.addi %0, %1 : index
+  return %2 : index
+}
+
+
+// CHECK-LABEL: func @test_addi_i64
+//       CHECK:  %[[A:.*]] = test.with_bounds {smax = 5 : i64, smin = 4 : i64, umax = 5 : i64, umin = 4 : i64} : i64
+//       CHECK:  %[[B:.*]] = test.with_bounds {smax = 7 : i64, smin = 6 : i64, umax = 7 : i64, umin = 6 : i64} : i64
+//       CHECK:  %[[A_CASTED:.*]] = arith.trunci %[[A]] : i64 to i32
+//       CHECK:  %[[B_CASTED:.*]] = arith.trunci %[[B]] : i64 to i32
+//       CHECK:  %[[RES:.*]] = arith.addi %[[A_CASTED]], %[[B_CASTED]] : i32
+//       CHECK:  %[[RES_CASTED:.*]] = arith.extui %[[RES]] : i32 to i64
+//       CHECK:  return %[[RES_CASTED]] : i64
+func.func @test_addi_i64() -> i64 {
+  %0 = test.with_bounds { umin = 4 : i64, umax = 5 : i64, smin = 4 : i64, smax = 5 : i64 } : i64
+  %1 = test.with_bounds { umin = 6 : i64, umax = 7 : i64, smin = 6 : i64, smax = 7 : i64 } : i64
+  %2 = arith.addi %0, %1 : i64
+  return %2 : i64
+}
+
+// CHECK-LABEL: func @test_cmpi
+//       CHECK:  %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
+//       CHECK:  %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : index
+//       CHECK:  %[[A_CASTED:.*]] = arith.index_castui %[[A]] : index to i32
+//       CHECK:  %[[B_CASTED:.*]] = arith.index_castui %[[B]] : index to i32
+//       CHECK:  %[[RES:.*]] = arith.cmpi slt, %[[A_CASTED]], %[[B_CASTED]] : i32
+//       CHECK:  return %[[RES]] : i1
+func.func @test_cmpi() -> i1 {
+  %0 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index
+  %1 = test.with_bounds { umin = 0 : index, umax = 10 : index, smin = 0 : index, smax = 10 : index } : index
+  %2 = arith.cmpi slt, %0, %1 : index
+  return %2 : i1
+}
+
+// Do not narrow when the result may not fit into the target bitwidth.
+// CHECK-LABEL: func @test_addi_i64_overflow
+//   CHECK-NOT:  arith.trunci
+//       CHECK:  %[[RES:.*]] = arith.addi %{{.*}}, %{{.*}} : i64
+//       CHECK:  return %[[RES]] : i64
+func.func @test_addi_i64_overflow() -> i64 {
+  %0 = test.with_bounds { umin = 4294967296 : i64, umax = 4294967296 : i64, smin = 4294967296 : i64, smax = 4294967296 : i64 } : i64
+  %1 = test.with_bounds { umin = 0 : i64, umax = 1 : i64, smin = 0 : i64, smax = 1 : i64 } : i64
+  %2 = arith.addi %0, %1 : i64
+  return %2 : i64
+}
+
+// Do not narrow cmpi when an operand may not fit into the target bitwidth.
+// CHECK-LABEL: func @test_cmpi_i64_overflow
+//   CHECK-NOT:  arith.trunci
+//       CHECK:  %[[RES:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} : i64
+//       CHECK:  return %[[RES]] : i1
+func.func @test_cmpi_i64_overflow() -> i1 {
+  %0 = test.with_bounds { umin = 0 : i64, umax = 4294967296 : i64, smin = 0 : i64, smax = 4294967296 : i64 } : i64
+  %1 = test.with_bounds { umin = 0 : i64, umax = 10 : i64, smin = 0 : i64, smax = 10 : i64 } : i64
+  %2 = arith.cmpi slt, %0, %1 : i64
+  return %2 : i1
+}

>From 696b8a5fc822e9beeee4e7b4e18ac12a27f10c06 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 15 Oct 2024 18:59:07 +0200
Subject: [PATCH 2/2] typo

---
 mlir/include/mlir/Dialect/Arith/Transforms/Passes.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index beeefea497d9b2..ce8e819708b190 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -73,7 +73,7 @@ void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns,
 /// Create a pass which do optimizations based on integer range analysis.
 std::unique_ptr<Pass> createIntRangeOptimizationsPass();
 
-/// Add patterns for int range based norrowing.
+/// Add patterns for int range based narrowing.
 void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
                                        DataFlowSolver &solver,
                                        unsigned targetBitwidth);



More information about the Mlir-commits mailing list