[Mlir-commits] [mlir] e2f7563 - [mlir][arith] Add narrowing patterns for `addi` and `muli`

Jakub Kuderski llvmlistbot at llvm.org
Tue May 2 07:11:53 PDT 2023


Author: Jakub Kuderski
Date: 2023-05-02T10:10:10-04:00
New Revision: e2f7563d7c30c3aca3ac8b937a4967accb59c209

URL: https://github.com/llvm/llvm-project/commit/e2f7563d7c30c3aca3ac8b937a4967accb59c209
DIFF: https://github.com/llvm/llvm-project/commit/e2f7563d7c30c3aca3ac8b937a4967accb59c209.diff

LOG: [mlir][arith] Add narrowing patterns for `addi` and `muli`

These two ops are handled in a very similar way -- the only difference
is in the number of result bits produced.

I checked these transformations with Alive2:
1.  addi + sext: https://alive2.llvm.org/ce/z/3NSs9T
2.  addi + zext: https://alive2.llvm.org/ce/z/t7XHOT
3.  muli + sext: https://alive2.llvm.org/ce/z/-7sfW9
4.  muli + zext: https://alive2.llvm.org/ce/z/h4yntF

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D149530

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
    mlir/test/Dialect/Arith/int-narrowing.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 0c7afd9255bcd..01507e360c722 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -216,6 +216,93 @@ FailureOr<unsigned> calculateBitsRequired(Value value,
   return calculateBitsRequired(value.getType());
 }
 
+/// Base pattern for narrowing arith binary ops with an extension op as the lhs
+/// and either an extension of the same kind or a constant as the rhs. Example:
+/// ```
+///   %lhs = arith.extsi %a : i8 to i32
+///   %rhs = arith.extsi %b : i8 to i32
+///   %r = arith.addi %lhs, %rhs : i32
+/// ==>
+///   %lhs16 = arith.trunci %lhs : i32 to i16
+///   %rhs16 = arith.trunci %rhs : i32 to i16
+///   %add = arith.addi %lhs16, %rhs16 : i16
+///   %r = arith.extsi %add : i16 to i32
+/// ```
+template <typename BinaryOp>
+struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> {
+  using NarrowingPattern<BinaryOp>::NarrowingPattern;
+
+  /// Returns the number of bits required to represent the full result, assuming
+  /// that both operands are `operandBits`-wide. Derived classes must implement
+  /// this, taking into account `BinaryOp` semantics.
+  virtual unsigned getResultBitsProduced(unsigned operandBits) const = 0;
+
+  LogicalResult matchAndRewrite(BinaryOp op,
+                                PatternRewriter &rewriter) const final {
+    Type origTy = op.getType();
+    FailureOr<unsigned> resultBits = calculateBitsRequired(origTy);
+    if (failed(resultBits))
+      return failure();
+
+    // The lhs must be an extension; the rhs, the same extension or a constant.
+    FailureOr<ExtensionOp> ext = ExtensionOp::from(op.getLhs().getDefiningOp());
+    if (failed(ext))
+      return failure();
+
+    FailureOr<unsigned> lhsBitsRequired =
+        calculateBitsRequired(ext->getIn(), ext->getKind());
+    FailureOr<unsigned> rhsBitsRequired =
+        calculateBitsRequired(op.getRhs(), ext->getKind());
+    if (failed(lhsBitsRequired) || failed(rhsBitsRequired) ||
+        *lhsBitsRequired >= *resultBits || *rhsBitsRequired >= *resultBits)
+      return failure();
+
+    // Bits needed for the full result, given the wider of the two operands.
+    unsigned commonBitsRequired =
+        getResultBitsProduced(std::max(*lhsBitsRequired, *rhsBitsRequired));
+    FailureOr<Type> narrowTy = this->getNarrowType(commonBitsRequired, origTy);
+    if (failed(narrowTy))
+      return failure();
+    // Check failure explicitly: `>=` on a failed FailureOr yields false.
+    FailureOr<unsigned> narrowBits = calculateBitsRequired(*narrowTy);
+    if (failed(narrowBits) || *narrowBits >= *resultBits)
+      return failure();
+
+    Location loc = op.getLoc();
+    Value newLhs =
+        rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getLhs());
+    Value newRhs =
+        rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getRhs());
+    Value newResult = rewriter.create<BinaryOp>(loc, newLhs, newRhs);
+    ext->recreateAndReplace(rewriter, op, newResult);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// AddIOp Pattern
+//===----------------------------------------------------------------------===//
+
+struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> {
+  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
+  /// n-bit + n-bit sums always fit in n + 1 bits (signed or unsigned).
+  unsigned getResultBitsProduced(unsigned width) const override {
+    return 1 + width;
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// MulIOp Pattern
+//===----------------------------------------------------------------------===//
+
+struct MulIPattern final : BinaryOpNarrowingPattern<arith::MulIOp> {
+  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
+  /// n-bit * n-bit products always fit in 2n bits (signed or unsigned).
+  unsigned getResultBitsProduced(unsigned width) const override {
+    return width * 2;
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // *IToFPOp Patterns
 //===----------------------------------------------------------------------===//
@@ -538,7 +625,8 @@ void populateArithIntNarrowingPatterns(
                ExtensionOverTranspose, ExtensionOverFlatTranspose>(
       patterns.getContext(), options, PatternBenefit(2));
 
-  patterns.add<SIToFPPattern, UIToFPPattern>(patterns.getContext(), options);
+  patterns.add<AddIPattern, MulIPattern, SIToFPPattern, UIToFPPattern>(
+      patterns.getContext(), options);
 }
 
 } // namespace mlir::arith

diff  --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir
index 675a52b5d53e6..966e34c779a49 100644
--- a/mlir/test/Dialect/Arith/int-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-narrowing.mlir
@@ -1,6 +1,188 @@
-// RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,32" \
+// RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,24,32" \
 // RUN:          --verify-diagnostics %s | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// arith.addi
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @addi_extsi_i8
+// CHECK-SAME:    (%[[X:.+]]: i8, %[[Y:.+]]: i8)
+// CHECK-NEXT:    %[[X_EXT:.+]] = arith.extsi %[[X]] : i8 to i32
+// CHECK-NEXT:    %[[Y_EXT:.+]] = arith.extsi %[[Y]] : i8 to i32
+// CHECK-NEXT:    %[[X_TR:.+]] = arith.trunci %[[X_EXT]] : i32 to i16
+// CHECK-NEXT:    %[[Y_TR:.+]] = arith.trunci %[[Y_EXT]] : i32 to i16
+// CHECK-NEXT:    %[[SUM:.+]] = arith.addi %[[X_TR]], %[[Y_TR]] : i16
+// CHECK-NEXT:    %[[OUT:.+]] = arith.extsi %[[SUM]] : i16 to i32
+// CHECK-NEXT:    return %[[OUT]] : i32
+func.func @addi_extsi_i8(%x: i8, %y: i8) -> i32 {
+  %x_ext = arith.extsi %x : i8 to i32
+  %y_ext = arith.extsi %y : i8 to i32
+  %sum = arith.addi %x_ext, %y_ext : i32
+  return %sum : i32
+}
+
+// CHECK-LABEL: func.func @addi_extui_i8
+// CHECK-SAME:    (%[[X:.+]]: i8, %[[Y:.+]]: i8)
+// CHECK-NEXT:    %[[X_EXT:.+]] = arith.extui %[[X]] : i8 to i32
+// CHECK-NEXT:    %[[Y_EXT:.+]] = arith.extui %[[Y]] : i8 to i32
+// CHECK-NEXT:    %[[X_TR:.+]] = arith.trunci %[[X_EXT]] : i32 to i16
+// CHECK-NEXT:    %[[Y_TR:.+]] = arith.trunci %[[Y_EXT]] : i32 to i16
+// CHECK-NEXT:    %[[SUM:.+]] = arith.addi %[[X_TR]], %[[Y_TR]] : i16
+// CHECK-NEXT:    %[[OUT:.+]] = arith.extui %[[SUM]] : i16 to i32
+// CHECK-NEXT:    return %[[OUT]] : i32
+func.func @addi_extui_i8(%x: i8, %y: i8) -> i32 {
+  %x_ext = arith.extui %x : i8 to i32
+  %y_ext = arith.extui %y : i8 to i32
+  %sum = arith.addi %x_ext, %y_ext : i32
+  return %sum : i32
+}
+
+// arith.addi needs one extra bit: i16 + i16 takes 17 bits, which fits in i24.
+//
+// CHECK-LABEL: func.func @addi_extsi_i24
+// CHECK-SAME:    (%[[X:.+]]: i16, %[[Y:.+]]: i16)
+// CHECK-NEXT:    %[[X_EXT:.+]] = arith.extsi %[[X]] : i16 to i32
+// CHECK-NEXT:    %[[Y_EXT:.+]] = arith.extsi %[[Y]] : i16 to i32
+// CHECK-NEXT:    %[[X_TR:.+]] = arith.trunci %[[X_EXT]] : i32 to i24
+// CHECK-NEXT:    %[[Y_TR:.+]] = arith.trunci %[[Y_EXT]] : i32 to i24
+// CHECK-NEXT:    %[[SUM:.+]] = arith.addi %[[X_TR]], %[[Y_TR]] : i24
+// CHECK-NEXT:    %[[OUT:.+]] = arith.extsi %[[SUM]] : i24 to i32
+// CHECK-NEXT:    return %[[OUT]] : i32
+func.func @addi_extsi_i24(%x: i16, %y: i16) -> i32 {
+  %x_ext = arith.extsi %x : i16 to i32
+  %y_ext = arith.extsi %y : i16 to i32
+  %sum = arith.addi %x_ext, %y_ext : i32
+  return %sum : i32
+}
+
+// Not optimized: the lhs is sign-extended but the rhs is zero-extended.
+//
+// CHECK-LABEL: func.func @addi_mixed_ext_i8
+// CHECK-SAME:    (%[[X:.+]]: i8, %[[Y:.+]]: i8)
+// CHECK-NEXT:    %[[X_EXT:.+]] = arith.extsi %[[X]] : i8 to i32
+// CHECK-NEXT:    %[[Y_EXT:.+]] = arith.extui %[[Y]] : i8 to i32
+// CHECK-NEXT:    %[[SUM:.+]] = arith.addi %[[X_EXT]], %[[Y_EXT]] : i32
+// CHECK-NEXT:    return %[[SUM]] : i32
+func.func @addi_mixed_ext_i8(%x: i8, %y: i8) -> i32 {
+  %x_ext = arith.extsi %x : i8 to i32
+  %y_ext = arith.extui %y : i8 to i32
+  %sum = arith.addi %x_ext, %y_ext : i32
+  return %sum : i32
+}
+
+// Not optimized: i8 + i8 needs 9 bits, and the narrowest supported type that
+// holds 9 bits is i16, which is no narrower than the original type.
+//
+// CHECK-LABEL: func.func @addi_extsi_i16
+// CHECK-SAME:    (%[[X:.+]]: i8, %[[Y:.+]]: i8)
+// CHECK-NEXT:    %[[X_EXT:.+]] = arith.extsi %[[X]] : i8 to i16
+// CHECK-NEXT:    %[[Y_EXT:.+]] = arith.extsi %[[Y]] : i8 to i16
+// CHECK-NEXT:    %[[SUM:.+]] = arith.addi %[[X_EXT]], %[[Y_EXT]] : i16
+// CHECK-NEXT:    return %[[SUM]] : i16
+func.func @addi_extsi_i16(%x: i8, %y: i8) -> i16 {
+  %x_ext = arith.extsi %x : i8 to i16
+  %y_ext = arith.extsi %y : i8 to i16
+  %sum = arith.addi %x_ext, %y_ext : i16
+  return %sum : i16
+}
+
+// CHECK-LABEL: func.func @addi_extsi_3xi8_cst
+// CHECK-SAME:    (%[[X:.+]]: vector<3xi8>)
+// CHECK-NEXT:    %[[C:.+]] = arith.constant dense<[-1, 127, 42]> : vector<3xi16>
+// CHECK-NEXT:    %[[X_EXT:.+]] = arith.extsi %[[X]] : vector<3xi8> to vector<3xi32>
+// CHECK-NEXT:    %[[X_TR:.+]] = arith.trunci %[[X_EXT]] : vector<3xi32> to vector<3xi16>
+// CHECK-NEXT:    %[[SUM:.+]] = arith.addi %[[X_TR]], %[[C]] : vector<3xi16>
+// CHECK-NEXT:    %[[OUT:.+]] = arith.extsi %[[SUM]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT:    return %[[OUT]] : vector<3xi32>
+func.func @addi_extsi_3xi8_cst(%x: vector<3xi8>) -> vector<3xi32> {
+  %c = arith.constant dense<[-1, 127, 42]> : vector<3xi32>
+  %x_ext = arith.extsi %x : vector<3xi8> to vector<3xi32>
+  %sum = arith.addi %x_ext, %c : vector<3xi32>
+  return %sum : vector<3xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// arith.muli
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @muli_extsi_i8
+// CHECK-SAME:    (%[[X:.+]]: i8, %[[Y:.+]]: i8)
+// CHECK-NEXT:    %[[X_EXT:.+]] = arith.extsi %[[X]] : i8 to i32
+// CHECK-NEXT:    %[[Y_EXT:.+]] = arith.extsi %[[Y]] : i8 to i32
+// CHECK-NEXT:    %[[X_TR:.+]] = arith.trunci %[[X_EXT]] : i32 to i16
+// CHECK-NEXT:    %[[Y_TR:.+]] = arith.trunci %[[Y_EXT]] : i32 to i16
+// CHECK-NEXT:    %[[PROD:.+]] = arith.muli %[[X_TR]], %[[Y_TR]] : i16
+// CHECK-NEXT:    %[[OUT:.+]] = arith.extsi %[[PROD]] : i16 to i32
+// CHECK-NEXT:    return %[[OUT]] : i32
+func.func @muli_extsi_i8(%x: i8, %y: i8) -> i32 {
+  %x_ext = arith.extsi %x : i8 to i32
+  %y_ext = arith.extsi %y : i8 to i32
+  %prod = arith.muli %x_ext, %y_ext : i32
+  return %prod : i32
+}
+
+// CHECK-LABEL: func.func @muli_extui_i8
+// CHECK-SAME:    (%[[X:.+]]: i8, %[[Y:.+]]: i8)
+// CHECK-NEXT:    %[[X_EXT:.+]] = arith.extui %[[X]] : i8 to i32
+// CHECK-NEXT:    %[[Y_EXT:.+]] = arith.extui %[[Y]] : i8 to i32
+// CHECK-NEXT:    %[[X_TR:.+]] = arith.trunci %[[X_EXT]] : i32 to i16
+// CHECK-NEXT:    %[[Y_TR:.+]] = arith.trunci %[[Y_EXT]] : i32 to i16
+// CHECK-NEXT:    %[[PROD:.+]] = arith.muli %[[X_TR]], %[[Y_TR]] : i16
+// CHECK-NEXT:    %[[OUT:.+]] = arith.extui %[[PROD]] : i16 to i32
+// CHECK-NEXT:    return %[[OUT]] : i32
+func.func @muli_extui_i8(%x: i8, %y: i8) -> i32 {
+  %x_ext = arith.extui %x : i8 to i32
+  %y_ext = arith.extui %y : i8 to i32
+  %prod = arith.muli %x_ext, %y_ext : i32
+  return %prod : i32
+}
+
+// Not optimized: arith.muli needs 2n bits for n-bit operands, so i16 * i16
+// takes all 32 bits of the original i32 result.
+//
+// CHECK-LABEL: func.func @muli_extsi_i32
+// CHECK-SAME:    (%[[X:.+]]: i16, %[[Y:.+]]: i16)
+// CHECK-NEXT:    %[[X_EXT:.+]] = arith.extsi %[[X]] : i16 to i32
+// CHECK-NEXT:    %[[Y_EXT:.+]] = arith.extsi %[[Y]] : i16 to i32
+// CHECK-NEXT:    %[[PROD:.+]] = arith.muli %[[X_EXT]], %[[Y_EXT]] : i32
+// CHECK-NEXT:    return %[[PROD]] : i32
+func.func @muli_extsi_i32(%x: i16, %y: i16) -> i32 {
+  %x_ext = arith.extsi %x : i16 to i32
+  %y_ext = arith.extsi %y : i16 to i32
+  %prod = arith.muli %x_ext, %y_ext : i32
+  return %prod : i32
+}
+
+// Not optimized: the lhs is sign-extended but the rhs is zero-extended.
+//
+// CHECK-LABEL: func.func @muli_mixed_ext_i8
+// CHECK-SAME:    (%[[X:.+]]: i8, %[[Y:.+]]: i8)
+// CHECK-NEXT:    %[[X_EXT:.+]] = arith.extsi %[[X]] : i8 to i32
+// CHECK-NEXT:    %[[Y_EXT:.+]] = arith.extui %[[Y]] : i8 to i32
+// CHECK-NEXT:    %[[PROD:.+]] = arith.muli %[[X_EXT]], %[[Y_EXT]] : i32
+// CHECK-NEXT:    return %[[PROD]] : i32
+func.func @muli_mixed_ext_i8(%x: i8, %y: i8) -> i32 {
+  %x_ext = arith.extsi %x : i8 to i32
+  %y_ext = arith.extui %y : i8 to i32
+  %prod = arith.muli %x_ext, %y_ext : i32
+  return %prod : i32
+}
+
+// CHECK-LABEL: func.func @muli_extsi_3xi8_cst
+// CHECK-SAME:    (%[[X:.+]]: vector<3xi8>)
+// CHECK-NEXT:    %[[C:.+]] = arith.constant dense<[-1, 127, 42]> : vector<3xi16>
+// CHECK-NEXT:    %[[X_EXT:.+]] = arith.extsi %[[X]] : vector<3xi8> to vector<3xi32>
+// CHECK-NEXT:    %[[X_TR:.+]] = arith.trunci %[[X_EXT]] : vector<3xi32> to vector<3xi16>
+// CHECK-NEXT:    %[[PROD:.+]] = arith.muli %[[X_TR]], %[[C]] : vector<3xi16>
+// CHECK-NEXT:    %[[OUT:.+]] = arith.extsi %[[PROD]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT:    return %[[OUT]] : vector<3xi32>
+func.func @muli_extsi_3xi8_cst(%x: vector<3xi8>) -> vector<3xi32> {
+  %c = arith.constant dense<[-1, 127, 42]> : vector<3xi32>
+  %x_ext = arith.extsi %x : vector<3xi8> to vector<3xi32>
+  %prod = arith.muli %x_ext, %c : vector<3xi32>
+  return %prod : vector<3xi32>
+}
+
 //===----------------------------------------------------------------------===//
 // arith.*itofp
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list