[Mlir-commits] [mlir] 46740dd - [mlir][arith] Add narrowing patterns for subi, divsi, and divui
Jakub Kuderski
llvmlistbot at llvm.org
Tue May 2 07:45:08 PDT 2023
Author: Jakub Kuderski
Date: 2023-05-02T10:44:29-04:00
New Revision: 46740dd02babfc47edd9f8fdb03479ad61223246
URL: https://github.com/llvm/llvm-project/commit/46740dd02babfc47edd9f8fdb03479ad61223246
DIFF: https://github.com/llvm/llvm-project/commit/46740dd02babfc47edd9f8fdb03479ad61223246.diff
LOG: [mlir][arith] Add narrowing patterns for subi, divsi, and divui
Each of these ops is compatible with only one extension kind and
produces an extra result bit.
I checked these transformations in Alive2:
1. subi + extsi: https://alive2.llvm.org/ce/z/ipmZZA
2. divsi + extsi: https://alive2.llvm.org/ce/z/fAcqUv
3. divui + extui: https://alive2.llvm.org/ce/z/QZJpFp
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D149531
Added:
Modified:
mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
mlir/test/Dialect/Arith/int-narrowing.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 01507e360c72..cb6e437067be 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -237,6 +237,10 @@ struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> {
/// this, taking into account `BinaryOp` semantics.
virtual unsigned getResultBitsProduced(unsigned operandBits) const = 0;
+ /// Customization point for patterns that should only apply with
+ /// zero/sign-extension ops as arguments.
+ virtual bool isSupported(ExtensionOp) const { return true; }
+
LogicalResult matchAndRewrite(BinaryOp op,
PatternRewriter &rewriter) const final {
Type origTy = op.getType();
@@ -247,7 +251,7 @@ struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> {
// For the optimization to apply, we expect the lhs to be an extension op,
// and for the rhs to either be the same extension op or a constant.
FailureOr<ExtensionOp> ext = ExtensionOp::from(op.getLhs().getDefiningOp());
- if (failed(ext))
+ if (failed(ext) || !isSupported(*ext))
return failure();
FailureOr<unsigned> lhsBitsRequired =
@@ -286,6 +290,27 @@ struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> {
struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> {
using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
+ // Addition may require one extra bit for the result.
+ // Example: `UINT8_MAX + 1 == 255 + 1 == 256`.
+ unsigned getResultBitsProduced(unsigned operandBits) const override {
+ return operandBits + 1;
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// SubIOp Pattern
+//===----------------------------------------------------------------------===//
+
+struct SubIPattern final : BinaryOpNarrowingPattern<arith::SubIOp> {
+ using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
+
+ // This optimization only applies to signed arguments.
+ bool isSupported(ExtensionOp ext) const override {
+ return ext.getKind() == ExtensionKind::Sign;
+ }
+
+ // Subtraction may require one extra bit for the result.
+ // Example: `INT8_MAX - (-1) == 127 - (-1) == 128`.
unsigned getResultBitsProduced(unsigned operandBits) const override {
return operandBits + 1;
}
@@ -298,11 +323,50 @@ struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> {
struct MulIPattern final : BinaryOpNarrowingPattern<arith::MulIOp> {
using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
+ // Multiplication may require up to double the operand bits.
+ // Example: `UINT8_MAX * UINT8_MAX == 255 * 255 == 65025`.
unsigned getResultBitsProduced(unsigned operandBits) const override {
return 2 * operandBits;
}
};
+//===----------------------------------------------------------------------===//
+// DivSIOp Pattern
+//===----------------------------------------------------------------------===//
+
+struct DivSIPattern final : BinaryOpNarrowingPattern<arith::DivSIOp> {
+ using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
+
+ // This optimization only applies to signed arguments.
+ bool isSupported(ExtensionOp ext) const override {
+ return ext.getKind() == ExtensionKind::Sign;
+ }
+
+ // Unlike multiplication, signed division requires only one more result bit.
+ // Example: `INT8_MIN / (-1) == -128 / (-1) == 128`.
+ unsigned getResultBitsProduced(unsigned operandBits) const override {
+ return operandBits + 1;
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// DivUIOp Pattern
+//===----------------------------------------------------------------------===//
+
+struct DivUIPattern final : BinaryOpNarrowingPattern<arith::DivUIOp> {
+ using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
+
+ // This optimization only applies to unsigned arguments.
+ bool isSupported(ExtensionOp ext) const override {
+ return ext.getKind() == ExtensionKind::Zero;
+ }
+
+ // Unsigned division does not require any extra result bits.
+ unsigned getResultBitsProduced(unsigned operandBits) const override {
+ return operandBits;
+ }
+};
+
//===----------------------------------------------------------------------===//
// *IToFPOp Patterns
//===----------------------------------------------------------------------===//
@@ -625,7 +689,8 @@ void populateArithIntNarrowingPatterns(
ExtensionOverTranspose, ExtensionOverFlatTranspose>(
patterns.getContext(), options, PatternBenefit(2));
- patterns.add<AddIPattern, MulIPattern, SIToFPPattern, UIToFPPattern>(
+ patterns.add<AddIPattern, SubIPattern, MulIPattern, DivSIPattern,
+ DivUIPattern, SIToFPPattern, UIToFPPattern>(
patterns.getContext(), options);
}
diff --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir
index 966e34c779a4..4b155ad86923 100644
--- a/mlir/test/Dialect/Arith/int-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-narrowing.mlir
@@ -101,6 +101,75 @@ func.func @addi_extsi_3xi8_cst(%lhs: vector<3xi8>) -> vector<3xi32> {
return %r : vector<3xi32>
}
+//===----------------------------------------------------------------------===//
+// arith.subi
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @subi_extsi_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT: %[[SUB:.+]] = arith.subi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[SUB]] : i16 to i32
+// CHECK-NEXT: return %[[RET]] : i32
+func.func @subi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
+ %a = arith.extsi %lhs : i8 to i32
+ %b = arith.extsi %rhs : i8 to i32
+ %r = arith.subi %a, %b : i32
+ return %r : i32
+}
+
+// This pattern should only apply to `arith.subi` ops with sign-extended
+// arguments.
+//
+// CHECK-LABEL: func.func @subi_extui_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[SUB:.+]] = arith.subi %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT: return %[[SUB]] : i32
+func.func @subi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+ %a = arith.extui %lhs : i8 to i32
+ %b = arith.extui %rhs : i8 to i32
+ %r = arith.subi %a, %b : i32
+ return %r : i32
+}
+
+// This case should not get optimized because of mixed extensions.
+//
+// CHECK-LABEL: func.func @subi_mixed_ext_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[ADD:.+]] = arith.subi %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT: return %[[ADD]] : i32
+func.func @subi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
+ %a = arith.extsi %lhs : i8 to i32
+ %b = arith.extui %rhs : i8 to i32
+ %r = arith.subi %a, %b : i32
+ return %r : i32
+}
+
+// arith.subi produces one more bit of result than the operand bitwidth.
+//
+// CHECK-LABEL: func.func @subi_extsi_i24
+// CHECK-SAME: (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i16 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i16 to i32
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i24
+// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i24
+// CHECK-NEXT: %[[ADD:.+]] = arith.subi %[[LHS]], %[[RHS]] : i24
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : i24 to i32
+// CHECK-NEXT: return %[[RET]] : i32
+func.func @subi_extsi_i24(%lhs: i16, %rhs: i16) -> i32 {
+ %a = arith.extsi %lhs : i16 to i32
+ %b = arith.extsi %rhs : i16 to i32
+ %r = arith.subi %a, %b : i32
+ return %r : i32
+}
+
//===----------------------------------------------------------------------===//
// arith.muli
//===----------------------------------------------------------------------===//
@@ -183,6 +252,92 @@ func.func @muli_extsi_3xi8_cst(%lhs: vector<3xi8>) -> vector<3xi32> {
return %r : vector<3xi32>
}
+//===----------------------------------------------------------------------===//
+// arith.divsi
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @divsi_extsi_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT: %[[SUB:.+]] = arith.divsi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[SUB]] : i16 to i32
+// CHECK-NEXT: return %[[RET]] : i32
+func.func @divsi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
+ %a = arith.extsi %lhs : i8 to i32
+ %b = arith.extsi %rhs : i8 to i32
+ %r = arith.divsi %a, %b : i32
+ return %r : i32
+}
+
+// This pattern should only apply to `arith.divsi` ops with sign-extended
+// arguments.
+//
+// CHECK-LABEL: func.func @divsi_extui_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[SUB:.+]] = arith.divsi %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT: return %[[SUB]] : i32
+func.func @divsi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+ %a = arith.extui %lhs : i8 to i32
+ %b = arith.extui %rhs : i8 to i32
+ %r = arith.divsi %a, %b : i32
+ return %r : i32
+}
+
+// arith.divsi produces one more bit of result than the operand bitwidth.
+//
+// CHECK-LABEL: func.func @divsi_extsi_i24
+// CHECK-SAME: (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i16 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i16 to i32
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i24
+// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i24
+// CHECK-NEXT: %[[ADD:.+]] = arith.divsi %[[LHS]], %[[RHS]] : i24
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : i24 to i32
+// CHECK-NEXT: return %[[RET]] : i32
+func.func @divsi_extsi_i24(%lhs: i16, %rhs: i16) -> i32 {
+ %a = arith.extsi %lhs : i16 to i32
+ %b = arith.extsi %rhs : i16 to i32
+ %r = arith.divsi %a, %b : i32
+ return %r : i32
+}
+
+//===----------------------------------------------------------------------===//
+// arith.divui
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @divui_extui_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[SUB:.+]] = arith.divui %[[ARG0]], %[[ARG1]] : i8
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[SUB]] : i8 to i32
+// CHECK-NEXT: return %[[RET]] : i32
+func.func @divui_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+ %a = arith.extui %lhs : i8 to i32
+ %b = arith.extui %rhs : i8 to i32
+ %r = arith.divui %a, %b : i32
+ return %r : i32
+}
+
+// This pattern should only apply to `arith.divui` ops with zero-extended
+// arguments.
+//
+// CHECK-LABEL: func.func @divui_extsi_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[SUB:.+]] = arith.divui %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT: return %[[SUB]] : i32
+func.func @divui_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
+ %a = arith.extsi %lhs : i8 to i32
+ %b = arith.extsi %rhs : i8 to i32
+ %r = arith.divui %a, %b : i32
+ return %r : i32
+}
+
//===----------------------------------------------------------------------===//
// arith.*itofp
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list