[Mlir-commits] [mlir] 46740dd - [mlir][arith] Add narrowing patterns for subi, divsi, and divui

Jakub Kuderski llvmlistbot at llvm.org
Tue May 2 07:45:08 PDT 2023


Author: Jakub Kuderski
Date: 2023-05-02T10:44:29-04:00
New Revision: 46740dd02babfc47edd9f8fdb03479ad61223246

URL: https://github.com/llvm/llvm-project/commit/46740dd02babfc47edd9f8fdb03479ad61223246
DIFF: https://github.com/llvm/llvm-project/commit/46740dd02babfc47edd9f8fdb03479ad61223246.diff

LOG: [mlir][arith] Add narrowing patterns for subi, divsi, and divui

Each of these ops is compatible with only one extension kind and
produces an extra result bit.

I checked these transformations in Alive2:
1. subi + extsi: https://alive2.llvm.org/ce/z/ipmZZA
2. divsi + extsi: https://alive2.llvm.org/ce/z/fAcqUv
3. divui + extui: https://alive2.llvm.org/ce/z/QZJpFp

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D149531

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
    mlir/test/Dialect/Arith/int-narrowing.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 01507e360c72..cb6e437067be 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -237,6 +237,10 @@ struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> {
   /// this, taking into account `BinaryOp` semantics.
   virtual unsigned getResultBitsProduced(unsigned operandBits) const = 0;
 
+  /// Customization point for patterns that should only apply with
+  /// zero/sign-extension ops as arguments.
+  virtual bool isSupported(ExtensionOp) const { return true; }
+
   LogicalResult matchAndRewrite(BinaryOp op,
                                 PatternRewriter &rewriter) const final {
     Type origTy = op.getType();
@@ -247,7 +251,7 @@ struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> {
     // For the optimization to apply, we expect the lhs to be an extension op,
     // and for the rhs to either be the same extension op or a constant.
     FailureOr<ExtensionOp> ext = ExtensionOp::from(op.getLhs().getDefiningOp());
-    if (failed(ext))
+    if (failed(ext) || !isSupported(*ext))
       return failure();
 
     FailureOr<unsigned> lhsBitsRequired =
@@ -286,6 +290,27 @@ struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> {
 struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> {
   using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
 
+  // Addition may require one extra bit for the result.
+  // Example: `UINT8_MAX + 1 == 255 + 1 == 256`.
+  unsigned getResultBitsProduced(unsigned operandBits) const override {
+    return operandBits + 1;
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// SubIOp Pattern
+//===----------------------------------------------------------------------===//
+
+struct SubIPattern final : BinaryOpNarrowingPattern<arith::SubIOp> {
+  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
+
+  // This optimization only applies to signed arguments.
+  bool isSupported(ExtensionOp ext) const override {
+    return ext.getKind() == ExtensionKind::Sign;
+  }
+
+  // Subtraction may require one extra bit for the result.
+  // Example: `INT8_MAX - (-1) == 127 - (-1) == 128`.
   unsigned getResultBitsProduced(unsigned operandBits) const override {
     return operandBits + 1;
   }
@@ -298,11 +323,50 @@ struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> {
 struct MulIPattern final : BinaryOpNarrowingPattern<arith::MulIOp> {
   using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
 
+  // Multiplication may require up to double the operand bits.
+  // Example: `UINT8_MAX * UINT8_MAX == 255 * 255 == 65025`.
   unsigned getResultBitsProduced(unsigned operandBits) const override {
     return 2 * operandBits;
   }
 };
 
+//===----------------------------------------------------------------------===//
+// DivSIOp Pattern
+//===----------------------------------------------------------------------===//
+
+struct DivSIPattern final : BinaryOpNarrowingPattern<arith::DivSIOp> {
+  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
+
+  // This optimization only applies to signed arguments.
+  bool isSupported(ExtensionOp ext) const override {
+    return ext.getKind() == ExtensionKind::Sign;
+  }
+
+  // Unlike multiplication, signed division requires only one more result bit.
+  // Example: `INT8_MIN / (-1) == -128 / (-1) == 128`.
+  unsigned getResultBitsProduced(unsigned operandBits) const override {
+    return operandBits + 1;
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// DivUIOp Pattern
+//===----------------------------------------------------------------------===//
+
+struct DivUIPattern final : BinaryOpNarrowingPattern<arith::DivUIOp> {
+  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
+
+  // This optimization only applies to unsigned arguments.
+  bool isSupported(ExtensionOp ext) const override {
+    return ext.getKind() == ExtensionKind::Zero;
+  }
+
+  // Unsigned division does not require any extra result bits.
+  unsigned getResultBitsProduced(unsigned operandBits) const override {
+    return operandBits;
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // *IToFPOp Patterns
 //===----------------------------------------------------------------------===//
@@ -625,7 +689,8 @@ void populateArithIntNarrowingPatterns(
                ExtensionOverTranspose, ExtensionOverFlatTranspose>(
       patterns.getContext(), options, PatternBenefit(2));
 
-  patterns.add<AddIPattern, MulIPattern, SIToFPPattern, UIToFPPattern>(
+  patterns.add<AddIPattern, SubIPattern, MulIPattern, DivSIPattern,
+               DivUIPattern, SIToFPPattern, UIToFPPattern>(
       patterns.getContext(), options);
 }
 

diff  --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir
index 966e34c779a4..4b155ad86923 100644
--- a/mlir/test/Dialect/Arith/int-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-narrowing.mlir
@@ -101,6 +101,75 @@ func.func @addi_extsi_3xi8_cst(%lhs: vector<3xi8>) -> vector<3xi32> {
   return %r : vector<3xi32>
 }
 
+//===----------------------------------------------------------------------===//
+// arith.subi
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @subi_extsi_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT:    %[[SUB:.+]]  = arith.subi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[SUB]] : i16 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
+func.func @subi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extsi %lhs : i8 to i32
+  %b = arith.extsi %rhs : i8 to i32
+  %r = arith.subi %a, %b : i32
+  return %r : i32
+}
+
+// This pattern should only apply to `arith.subi` ops with sign-extended
+// arguments.
+//
+// CHECK-LABEL: func.func @subi_extui_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[SUB:.+]]  = arith.subi %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT:    return %[[SUB]] : i32
+func.func @subi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extui %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.subi %a, %b : i32
+  return %r : i32
+}
+
+// This case should not get optimized because of mixed extensions.
+//
+// CHECK-LABEL: func.func @subi_mixed_ext_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[ADD:.+]]  = arith.subi %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT:    return %[[ADD]] : i32
+func.func @subi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extsi %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.subi %a, %b : i32
+  return %r : i32
+}
+
+// arith.subi produces one more bit of result than the operand bitwidth.
+//
+// CHECK-LABEL: func.func @subi_extsi_i24
+// CHECK-SAME:    (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i16 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i16 to i32
+// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i24
+// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i24
+// CHECK-NEXT:    %[[ADD:.+]]  = arith.subi %[[LHS]], %[[RHS]] : i24
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[ADD]] : i24 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
+func.func @subi_extsi_i24(%lhs: i16, %rhs: i16) -> i32 {
+  %a = arith.extsi %lhs : i16 to i32
+  %b = arith.extsi %rhs : i16 to i32
+  %r = arith.subi %a, %b : i32
+  return %r : i32
+}
+
 //===----------------------------------------------------------------------===//
 // arith.muli
 //===----------------------------------------------------------------------===//
@@ -183,6 +252,92 @@ func.func @muli_extsi_3xi8_cst(%lhs: vector<3xi8>) -> vector<3xi32> {
   return %r : vector<3xi32>
 }
 
+//===----------------------------------------------------------------------===//
+// arith.divsi
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @divsi_extsi_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT:    %[[SUB:.+]]  = arith.divsi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[SUB]] : i16 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
+func.func @divsi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extsi %lhs : i8 to i32
+  %b = arith.extsi %rhs : i8 to i32
+  %r = arith.divsi %a, %b : i32
+  return %r : i32
+}
+
+// This pattern should only apply to `arith.divsi` ops with sign-extended
+// arguments.
+//
+// CHECK-LABEL: func.func @divsi_extui_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[SUB:.+]]  = arith.divsi %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT:    return %[[SUB]] : i32
+func.func @divsi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extui %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.divsi %a, %b : i32
+  return %r : i32
+}
+
+// arith.divsi produces one more bit of result than the operand bitwidth.
+//
+// CHECK-LABEL: func.func @divsi_extsi_i24
+// CHECK-SAME:    (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i16 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i16 to i32
+// CHECK-NEXT:    %[[LHS:.+]]  = arith.trunci %[[EXT0]] : i32 to i24
+// CHECK-NEXT:    %[[RHS:.+]]  = arith.trunci %[[EXT1]] : i32 to i24
+// CHECK-NEXT:    %[[ADD:.+]]  = arith.divsi %[[LHS]], %[[RHS]] : i24
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extsi %[[ADD]] : i24 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
+func.func @divsi_extsi_i24(%lhs: i16, %rhs: i16) -> i32 {
+  %a = arith.extsi %lhs : i16 to i32
+  %b = arith.extsi %rhs : i16 to i32
+  %r = arith.divsi %a, %b : i32
+  return %r : i32
+}
+
+//===----------------------------------------------------------------------===//
+// arith.divui
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @divui_extui_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[SUB:.+]]  = arith.divui %[[ARG0]], %[[ARG1]] : i8
+// CHECK-NEXT:    %[[RET:.+]]  = arith.extui %[[SUB]] : i8 to i32
+// CHECK-NEXT:    return %[[RET]] : i32
+func.func @divui_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extui %lhs : i8 to i32
+  %b = arith.extui %rhs : i8 to i32
+  %r = arith.divui %a, %b : i32
+  return %r : i32
+}
+
+// This pattern should only apply to `arith.divui` ops with zero-extended
+// arguments.
+//
+// CHECK-LABEL: func.func @divui_extsi_i8
+// CHECK-SAME:    (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT:    %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT:    %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
+// CHECK-NEXT:    %[[SUB:.+]]  = arith.divui %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT:    return %[[SUB]] : i32
+func.func @divui_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
+  %a = arith.extsi %lhs : i8 to i32
+  %b = arith.extsi %rhs : i8 to i32
+  %r = arith.divui %a, %b : i32
+  return %r : i32
+}
+
 //===----------------------------------------------------------------------===//
 // arith.*itofp
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list