[Mlir-commits] [mlir] e2f7563 - [mlir][arith] Add narrowing patterns for `addi` and `muli`
Jakub Kuderski
llvmlistbot at llvm.org
Tue May 2 07:11:53 PDT 2023
Author: Jakub Kuderski
Date: 2023-05-02T10:10:10-04:00
New Revision: e2f7563d7c30c3aca3ac8b937a4967accb59c209
URL: https://github.com/llvm/llvm-project/commit/e2f7563d7c30c3aca3ac8b937a4967accb59c209
DIFF: https://github.com/llvm/llvm-project/commit/e2f7563d7c30c3aca3ac8b937a4967accb59c209.diff
LOG: [mlir][arith] Add narrowing patterns for `addi` and `muli`
These two ops are handled in a very similar way -- the only difference
is in the number of result bits produced.
I checked these transformations with Alive2:
1. addi + sext: https://alive2.llvm.org/ce/z/3NSs9T
2. addi + zext: https://alive2.llvm.org/ce/z/t7XHOT
3. muli + sext: https://alive2.llvm.org/ce/z/-7sfW9
4. muli + zext: https://alive2.llvm.org/ce/z/h4yntF
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D149530
Added:
Modified:
mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
mlir/test/Dialect/Arith/int-narrowing.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 0c7afd9255bcd..01507e360c722 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -216,6 +216,93 @@ FailureOr<unsigned> calculateBitsRequired(Value value,
return calculateBitsRequired(value.getType());
}
+/// Base pattern for arith binary ops.
+/// Example:
+/// ```
+/// %lhs = arith.extsi %a : i8 to i32
+/// %rhs = arith.extsi %b : i8 to i32
+/// %r = arith.addi %lhs, %rhs : i32
+/// ==>
+/// %lhs = arith.extsi %a : i8 to i16
+/// %rhs = arith.extsi %b : i8 to i16
+/// %add = arith.addi %lhs, %rhs : i16
+/// %r = arith.extsi %add : i16 to i32
+/// ```
+template <typename BinaryOp>
+struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> {
+ using NarrowingPattern<BinaryOp>::NarrowingPattern;
+
+ /// Returns the number of bits required to represent the full result, assuming
+ /// that both operands are `operandBits`-wide. Derived classes must implement
+ /// this, taking into account `BinaryOp` semantics.
+ virtual unsigned getResultBitsProduced(unsigned operandBits) const = 0;
+
+ LogicalResult matchAndRewrite(BinaryOp op,
+ PatternRewriter &rewriter) const final {
+ Type origTy = op.getType();
+ FailureOr<unsigned> resultBits = calculateBitsRequired(origTy);
+ if (failed(resultBits))
+ return failure();
+
+ // For the optimization to apply, we expect the lhs to be an extension op,
+ // and for the rhs to either be the same extension op or a constant.
+ FailureOr<ExtensionOp> ext = ExtensionOp::from(op.getLhs().getDefiningOp());
+ if (failed(ext))
+ return failure();
+
+ FailureOr<unsigned> lhsBitsRequired =
+ calculateBitsRequired(ext->getIn(), ext->getKind());
+ if (failed(lhsBitsRequired) || *lhsBitsRequired >= *resultBits)
+ return failure();
+
+ FailureOr<unsigned> rhsBitsRequired =
+ calculateBitsRequired(op.getRhs(), ext->getKind());
+ if (failed(rhsBitsRequired) || *rhsBitsRequired >= *resultBits)
+ return failure();
+
+ // Negotiate a common bit requirements for both lhs and rhs, accounting for
+ // the result requiring more bits than the operands.
+ unsigned commonBitsRequired =
+ getResultBitsProduced(std::max(*lhsBitsRequired, *rhsBitsRequired));
+ FailureOr<Type> narrowTy = this->getNarrowType(commonBitsRequired, origTy);
+ if (failed(narrowTy) || calculateBitsRequired(*narrowTy) >= *resultBits)
+ return failure();
+
+ Location loc = op.getLoc();
+ Value newLhs =
+ rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getLhs());
+ Value newRhs =
+ rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getRhs());
+ Value newAdd = rewriter.create<BinaryOp>(loc, newLhs, newRhs);
+ ext->recreateAndReplace(rewriter, op, newAdd);
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// AddIOp Pattern
+//===----------------------------------------------------------------------===//
+
+struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> {
+ using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
+
+ unsigned getResultBitsProduced(unsigned operandBits) const override {
+ return operandBits + 1;
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// MulIOp Pattern
+//===----------------------------------------------------------------------===//
+
+struct MulIPattern final : BinaryOpNarrowingPattern<arith::MulIOp> {
+ using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
+
+ unsigned getResultBitsProduced(unsigned operandBits) const override {
+ return 2 * operandBits;
+ }
+};
+
//===----------------------------------------------------------------------===//
// *IToFPOp Patterns
//===----------------------------------------------------------------------===//
@@ -538,7 +625,8 @@ void populateArithIntNarrowingPatterns(
ExtensionOverTranspose, ExtensionOverFlatTranspose>(
patterns.getContext(), options, PatternBenefit(2));
- patterns.add<SIToFPPattern, UIToFPPattern>(patterns.getContext(), options);
+ patterns.add<AddIPattern, MulIPattern, SIToFPPattern, UIToFPPattern>(
+ patterns.getContext(), options);
}
} // namespace mlir::arith
diff --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir
index 675a52b5d53e6..966e34c779a49 100644
--- a/mlir/test/Dialect/Arith/int-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-narrowing.mlir
@@ -1,6 +1,188 @@
-// RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,32" \
+// RUN: mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,24,32" \
// RUN: --verify-diagnostics %s | FileCheck %s
+//===----------------------------------------------------------------------===//
+// arith.addi
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @addi_extsi_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : i16 to i32
+// CHECK-NEXT: return %[[RET]] : i32
+func.func @addi_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
+ %a = arith.extsi %lhs : i8 to i32
+ %b = arith.extsi %rhs : i8 to i32
+ %r = arith.addi %a, %b : i32
+ return %r : i32
+}
+
+// CHECK-LABEL: func.func @addi_extui_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[ADD]] : i16 to i32
+// CHECK-NEXT: return %[[RET]] : i32
+func.func @addi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+ %a = arith.extui %lhs : i8 to i32
+ %b = arith.extui %rhs : i8 to i32
+ %r = arith.addi %a, %b : i32
+ return %r : i32
+}
+
+// arith.addi produces one more bit of result than the operand bitwidth.
+//
+// CHECK-LABEL: func.func @addi_extsi_i24
+// CHECK-SAME: (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i16 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i16 to i32
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i24
+// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i24
+// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[LHS]], %[[RHS]] : i24
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : i24 to i32
+// CHECK-NEXT: return %[[RET]] : i32
+func.func @addi_extsi_i24(%lhs: i16, %rhs: i16) -> i32 {
+ %a = arith.extsi %lhs : i16 to i32
+ %b = arith.extsi %rhs : i16 to i32
+ %r = arith.addi %a, %b : i32
+ return %r : i32
+}
+
+// This case should not get optimized because of mixed extensions.
+//
+// CHECK-LABEL: func.func @addi_mixed_ext_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT: return %[[ADD]] : i32
+func.func @addi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
+ %a = arith.extsi %lhs : i8 to i32
+ %b = arith.extui %rhs : i8 to i32
+ %r = arith.addi %a, %b : i32
+ return %r : i32
+}
+
+// This case should not get optimized because we cannot reduce the bitwidth
+// below i16, given the pass options set.
+//
+// CHECK-LABEL: func.func @addi_extsi_i16
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i16
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i16
+// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[EXT0]], %[[EXT1]] : i16
+// CHECK-NEXT: return %[[ADD]] : i16
+func.func @addi_extsi_i16(%lhs: i8, %rhs: i8) -> i16 {
+ %a = arith.extsi %lhs : i8 to i16
+ %b = arith.extsi %rhs : i8 to i16
+ %r = arith.addi %a, %b : i16
+ return %r : i16
+}
+
+// CHECK-LABEL: func.func @addi_extsi_3xi8_cst
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi8>)
+// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[-1, 127, 42]> : vector<3xi16>
+// CHECK-NEXT: %[[EXT:.+]] = arith.extsi %[[ARG0]] : vector<3xi8> to vector<3xi32>
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT]] : vector<3xi32> to vector<3xi16>
+// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[LHS]], %[[CST]] : vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<3xi32>
+func.func @addi_extsi_3xi8_cst(%lhs: vector<3xi8>) -> vector<3xi32> {
+ %cst = arith.constant dense<[-1, 127, 42]> : vector<3xi32>
+ %a = arith.extsi %lhs : vector<3xi8> to vector<3xi32>
+ %r = arith.addi %a, %cst : vector<3xi32>
+ return %r : vector<3xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// arith.muli
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @muli_extsi_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[MUL]] : i16 to i32
+// CHECK-NEXT: return %[[RET]] : i32
+func.func @muli_extsi_i8(%lhs: i8, %rhs: i8) -> i32 {
+ %a = arith.extsi %lhs : i8 to i32
+ %b = arith.extsi %rhs : i8 to i32
+ %r = arith.muli %a, %b : i32
+ return %r : i32
+}
+
+// CHECK-LABEL: func.func @muli_extui_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[MUL]] : i16 to i32
+// CHECK-NEXT: return %[[RET]] : i32
+func.func @muli_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
+ %a = arith.extui %lhs : i8 to i32
+ %b = arith.extui %rhs : i8 to i32
+ %r = arith.muli %a, %b : i32
+ return %r : i32
+}
+
+// We do not expect this case to be optimized because given n-bit operands,
+// arith.muli produces 2n bits of result.
+//
+// CHECK-LABEL: func.func @muli_extsi_i32
+// CHECK-SAME: (%[[ARG0:.+]]: i16, %[[ARG1:.+]]: i16)
+// CHECK-NEXT: %[[LHS:.+]] = arith.extsi %[[ARG0]] : i16 to i32
+// CHECK-NEXT: %[[RHS:.+]] = arith.extsi %[[ARG1]] : i16 to i32
+// CHECK-NEXT: %[[RET:.+]] = arith.muli %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT: return %[[RET]] : i32
+func.func @muli_extsi_i32(%lhs: i16, %rhs: i16) -> i32 {
+ %a = arith.extsi %lhs : i16 to i32
+ %b = arith.extsi %rhs : i16 to i32
+ %r = arith.muli %a, %b : i32
+ return %r : i32
+}
+
+// This case should not get optimized because of mixed extensions.
+//
+// CHECK-LABEL: func.func @muli_mixed_ext_i8
+// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
+// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
+// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
+// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[EXT0]], %[[EXT1]] : i32
+// CHECK-NEXT: return %[[MUL]] : i32
+func.func @muli_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
+ %a = arith.extsi %lhs : i8 to i32
+ %b = arith.extui %rhs : i8 to i32
+ %r = arith.muli %a, %b : i32
+ return %r : i32
+}
+
+// CHECK-LABEL: func.func @muli_extsi_3xi8_cst
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi8>)
+// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[-1, 127, 42]> : vector<3xi16>
+// CHECK-NEXT: %[[EXT:.+]] = arith.extsi %[[ARG0]] : vector<3xi8> to vector<3xi32>
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT]] : vector<3xi32> to vector<3xi16>
+// CHECK-NEXT: %[[MUL:.+]] = arith.muli %[[LHS]], %[[CST]] : vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[MUL]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<3xi32>
+func.func @muli_extsi_3xi8_cst(%lhs: vector<3xi8>) -> vector<3xi32> {
+ %cst = arith.constant dense<[-1, 127, 42]> : vector<3xi32>
+ %a = arith.extsi %lhs : vector<3xi8> to vector<3xi32>
+ %r = arith.muli %a, %cst : vector<3xi32>
+ return %r : vector<3xi32>
+}
+
//===----------------------------------------------------------------------===//
// arith.*itofp
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list