[Mlir-commits] [mlir] 33017e5 - [mlir][arith] Add narrowing pattern to commute extension over insertion
Jakub Kuderski
llvmlistbot at llvm.org
Fri Apr 28 13:19:12 PDT 2023
Author: Jakub Kuderski
Date: 2023-04-28T16:17:44-04:00
New Revision: 33017e5a3fa2c3194522565cd0e106a931b072b3
URL: https://github.com/llvm/llvm-project/commit/33017e5a3fa2c3194522565cd0e106a931b072b3
DIFF: https://github.com/llvm/llvm-project/commit/33017e5a3fa2c3194522565cd0e106a931b072b3.diff
LOG: [mlir][arith] Add narrowing pattern to commute extension over insertion
This enables more optimization opportunities by moving
zero/sign-extension closer to the use.
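For example (adapted from the new tests), extensions of both the
inserted value and the destination vector are hoisted past the
insertion:

  %c = arith.extsi %a : vector<3xi16> to vector<3xi32>
  %d = arith.extsi %b : i16 to i32
  %e = vector.insert %d, %c [1] : i32 into vector<3xi32>

becomes

  %e = vector.insert %b, %a [1] : i16 into vector<3xi16>
  %f = arith.extsi %e : vector<3xi16> to vector<3xi32>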
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D149282
Added:
Modified:
mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
mlir/test/Dialect/Arith/int-narrowing.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 3401a9c05b632..639b19b0a5d8a 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -70,7 +70,7 @@ struct NarrowingPattern : OpRewritePattern<SourceOp> {
if (!isa<IntegerType>(elemTy))
return failure();
- auto newElemTy = IntegerType::get(origTy.getContext(), bitsRequired);
+ auto newElemTy = IntegerType::get(origTy.getContext(), *bestBitwidth);
if (newElemTy == elemTy)
return failure();
@@ -100,11 +100,58 @@ FailureOr<unsigned> calculateBitsRequired(Type type) {
enum class ExtensionKind { Sign, Zero };
+ExtensionKind getExtensionKind(Operation *op) {
+ assert(op);
+ assert((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) && "Not an extension op");
+ return isa<arith::ExtSIOp>(op) ? ExtensionKind::Sign : ExtensionKind::Zero;
+}
+
+/// Returns the integer bitwidth required to represent `value`.
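+/// For example (derived from the cases below): APInt(8, 5) needs 3 bits under
+/// ExtensionKind::Zero but 4 under ExtensionKind::Sign (one extra sign bit),
+/// APInt(8, -5) needs 4 bits (an i4 holds -8..7), and zero always needs 1.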
+unsigned calculateBitsRequired(const APInt &value,
+ ExtensionKind lookThroughExtension) {
+ // For unsigned values, we only need the active bits. As a special case, zero
+ // requires one bit.
+ if (lookThroughExtension == ExtensionKind::Zero)
+ return std::max(value.getActiveBits(), 1u);
+
+ // If a signed value is nonnegative, we need one extra bit for the sign.
+ if (value.isNonNegative())
+ return value.getActiveBits() + 1;
+
+ // For the signed min, we need all the bits.
+ if (value.isMinSignedValue())
+ return value.getBitWidth();
+
+ // For negative values, we need all the non-sign bits and one extra bit for
+ // the sign.
+ return value.getBitWidth() - value.getNumSignBits() + 1;
+}
+
/// Returns the integer bitwidth required to represent `value`.
/// Looks through either sign- or zero-extension as specified by
/// `lookThroughExtension`.
FailureOr<unsigned> calculateBitsRequired(Value value,
ExtensionKind lookThroughExtension) {
+ // Handle constants.
+ if (TypedAttr attr; matchPattern(value, m_Constant(&attr))) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr))
+ return calculateBitsRequired(intAttr.getValue(), lookThroughExtension);
+
+ if (auto elemsAttr = dyn_cast<DenseElementsAttr>(attr)) {
+ if (elemsAttr.getElementType().isIntOrIndex()) {
+ if (elemsAttr.isSplat())
+ return calculateBitsRequired(elemsAttr.getSplatValue<APInt>(),
+ lookThroughExtension);
+
+ unsigned maxBits = 1;
+ for (const APInt &elemValue : elemsAttr.getValues<APInt>())
+ maxBits = std::max(
+ maxBits, calculateBitsRequired(elemValue, lookThroughExtension));
+ return maxBits;
+ }
+ }
+ }
+
if (lookThroughExtension == ExtensionKind::Sign) {
if (auto sext = value.getDefiningOp<arith::ExtSIOp>())
return calculateBitsRequired(sext.getIn().getType());
@@ -150,8 +197,8 @@ using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;
// Patterns to Commute Extension Ops
//===----------------------------------------------------------------------===//
-struct ExtensionOverExtract final : OpRewritePattern<vector::ExtractOp> {
- using OpRewritePattern::OpRewritePattern;
+struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
+ using NarrowingPattern::NarrowingPattern;
LogicalResult matchAndRewrite(vector::ExtractOp op,
PatternRewriter &rewriter) const override {
@@ -172,8 +219,8 @@ struct ExtensionOverExtract final : OpRewritePattern<vector::ExtractOp> {
};
struct ExtensionOverExtractElement final
- : OpRewritePattern<vector::ExtractElementOp> {
- using OpRewritePattern::OpRewritePattern;
+ : NarrowingPattern<vector::ExtractElementOp> {
+ using NarrowingPattern::NarrowingPattern;
LogicalResult matchAndRewrite(vector::ExtractElementOp op,
PatternRewriter &rewriter) const override {
@@ -194,8 +241,8 @@ struct ExtensionOverExtractElement final
};
struct ExtensionOverExtractStridedSlice final
- : OpRewritePattern<vector::ExtractStridedSliceOp> {
- using OpRewritePattern::OpRewritePattern;
+ : NarrowingPattern<vector::ExtractStridedSliceOp> {
+ using NarrowingPattern::NarrowingPattern;
LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
@@ -220,6 +267,80 @@ struct ExtensionOverExtractStridedSlice final
}
};
+struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
+ using NarrowingPattern::NarrowingPattern;
+
+ LogicalResult matchAndRewrite(vector::InsertOp op,
+ PatternRewriter &rewriter) const override {
+ Operation *def = op.getSource().getDefiningOp();
+ if (!def)
+ return failure();
+
+ return TypeSwitch<Operation *, LogicalResult>(def)
+ .Case<arith::ExtSIOp, arith::ExtUIOp>([&](auto extOp) {
+ // Rewrite the insertion in terms of narrower operands
+ // and later extend the result to the original bitwidth.
+ FailureOr<vector::InsertOp> newInsert =
+ createNarrowInsert(op, rewriter, extOp);
+ if (failed(newInsert))
+ return failure();
+ rewriter.replaceOpWithNewOp<decltype(extOp)>(op, op.getType(),
+ *newInsert);
+ return success();
+ })
+ .Default(failure());
+ }
+
+ FailureOr<vector::InsertOp> createNarrowInsert(vector::InsertOp op,
+ PatternRewriter &rewriter,
+ Operation *insValue) const {
+    assert((isa<arith::ExtSIOp, arith::ExtUIOp>(insValue)) &&
+           "Not an extension op");
+
+    // Calculate the operand and result bitwidths. We can only apply narrowing
+    // when the inserted source value and the destination vector require fewer
+    // bits than the result. Because the source and destination may have
+    // different bitwidth requirements, we have to find a common narrow
+    // bitwidth that is greater than or equal to both operand bitwidth
+    // requirements and still narrower than the result.
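+    // For example, if the destination elements fit in 16 bits and the
+    // inserted value fits in 8, the insertion can be performed on i16
+    // (max(16, 8) = 16), provided the original result needs more than 16.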
+ FailureOr<unsigned> origBitsRequired = calculateBitsRequired(op.getType());
+ if (failed(origBitsRequired))
+ return failure();
+
+ ExtensionKind kind = getExtensionKind(insValue);
+ FailureOr<unsigned> destBitsRequired =
+ calculateBitsRequired(op.getDest(), kind);
+ if (failed(destBitsRequired) || *destBitsRequired >= *origBitsRequired)
+ return failure();
+
+ FailureOr<unsigned> insertedBitsRequired =
+ calculateBitsRequired(insValue->getOperands().front(), kind);
+ if (failed(insertedBitsRequired) ||
+ *insertedBitsRequired >= *origBitsRequired)
+ return failure();
+
+ // Find a narrower element type that satisfies the bitwidth requirements of
+ // both the source and the destination values.
+ unsigned newInsertionBits =
+ std::max(*destBitsRequired, *insertedBitsRequired);
+ FailureOr<Type> newVecTy = getNarrowType(newInsertionBits, op.getType());
+ if (failed(newVecTy) || *newVecTy == op.getType())
+ return failure();
+
+ FailureOr<Type> newInsertedValueTy =
+ getNarrowType(newInsertionBits, insValue->getResultTypes().front());
+ if (failed(newInsertedValueTy))
+ return failure();
+
+ Location loc = op.getLoc();
+ Value narrowValue = rewriter.createOrFold<arith::TruncIOp>(
+ loc, *newInsertedValueTy, insValue->getResult(0));
+ Value narrowDest =
+ rewriter.createOrFold<arith::TruncIOp>(loc, *newVecTy, op.getDest());
+ return rewriter.create<vector::InsertOp>(loc, narrowValue, narrowDest,
+ op.getPosition());
+ }
+};
+
//===----------------------------------------------------------------------===//
// Pass Definitions
//===----------------------------------------------------------------------===//
@@ -249,8 +370,8 @@ void populateArithIntNarrowingPatterns(
// Add commute patterns with a higher benefit. This is to expose more
// optimization opportunities to narrowing patterns.
patterns.add<ExtensionOverExtract, ExtensionOverExtractElement,
- ExtensionOverExtractStridedSlice>(patterns.getContext(),
- PatternBenefit(2));
+ ExtensionOverExtractStridedSlice, ExtensionOverInsert>(
+ patterns.getContext(), options, PatternBenefit(2));
patterns.add<SIToFPPattern, UIToFPPattern>(patterns.getContext(), options);
}
diff --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir
index f1290e552fd77..d98e03d93b030 100644
--- a/mlir/test/Dialect/Arith/int-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-narrowing.mlir
@@ -235,3 +235,96 @@ func.func @extui_over_extract_strided_slice_2d(%a: vector<2x3xi16>) -> vector<1x
{offsets = [1, 1], sizes = [1, 2], strides = [1, 1]} : vector<2x3xi32> to vector<1x2xi32>
return %c : vector<1x2xi32>
}
+
+// CHECK-LABEL: func.func @extsi_over_insert_3xi16
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16)
+// CHECK-NEXT: %[[INS:.+]] = vector.insert %[[ARG1]], %[[ARG0]] [1] : i16 into vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<3xi32>
+func.func @extsi_over_insert_3xi16(%a: vector<3xi16>, %b: i16) -> vector<3xi32> {
+ %c = arith.extsi %a : vector<3xi16> to vector<3xi32>
+ %d = arith.extsi %b : i16 to i32
+ %e = vector.insert %d, %c [1] : i32 into vector<3xi32>
+ return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_insert_3xi16
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16)
+// CHECK-NEXT: %[[INS:.+]] = vector.insert %[[ARG1]], %[[ARG0]] [1] : i16 into vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<3xi32>
+func.func @extui_over_insert_3xi16(%a: vector<3xi16>, %b: i16) -> vector<3xi32> {
+ %c = arith.extui %a : vector<3xi16> to vector<3xi32>
+ %d = arith.extui %b : i16 to i32
+ %e = vector.insert %d, %c [1] : i32 into vector<3xi32>
+ return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extsi_over_insert_3xi16_cst_0
+// CHECK-SAME: (%[[ARG:.+]]: i16)
+// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<0> : vector<3xi16>
+// CHECK-NEXT: %[[INS:.+]] = vector.insert %[[ARG]], %[[CST]] [1] : i16 into vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<3xi32>
+func.func @extsi_over_insert_3xi16_cst_0(%a: i16) -> vector<3xi32> {
+ %cst = arith.constant dense<0> : vector<3xi32>
+ %d = arith.extsi %a : i16 to i32
+ %e = vector.insert %d, %cst [1] : i32 into vector<3xi32>
+ return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extsi_over_insert_3xi8_cst
+// CHECK-SAME: (%[[ARG:.+]]: i8)
+// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[-1, 127, -128]> : vector<3xi8>
+// CHECK-NEXT: %[[INS:.+]] = vector.insert %[[ARG]], %[[CST]] [1] : i8 into vector<3xi8>
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi8> to vector<3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<3xi32>
+func.func @extsi_over_insert_3xi8_cst(%a: i8) -> vector<3xi32> {
+ %cst = arith.constant dense<[-1, 127, -128]> : vector<3xi32>
+ %d = arith.extsi %a : i8 to i32
+ %e = vector.insert %d, %cst [1] : i32 into vector<3xi32>
+ return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_insert_3xi8_cst
+// CHECK-SAME: (%[[ARG:.+]]: i8)
+// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[1, 127, -1]> : vector<3xi8>
+// CHECK-NEXT: %[[INS:.+]] = vector.insert %[[ARG]], %[[CST]] [1] : i8 into vector<3xi8>
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi8> to vector<3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<3xi32>
+func.func @extui_over_insert_3xi8_cst(%a: i8) -> vector<3xi32> {
+ %cst = arith.constant dense<[1, 127, 255]> : vector<3xi32>
+ %d = arith.extui %a : i8 to i32
+ %e = vector.insert %d, %cst [1] : i32 into vector<3xi32>
+ return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extsi_over_insert_3xi16_cst_i16
+// CHECK-SAME: (%[[ARG:.+]]: i8)
+// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[-1, 128, 0]> : vector<3xi16>
+// CHECK-NEXT: %[[SRCE:.+]] = arith.extsi %[[ARG]] : i8 to i32
+// CHECK-NEXT: %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16
+// CHECK-NEXT: %[[INS:.+]] = vector.insert %[[SRCT]], %[[CST]] [1] : i16 into vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<3xi32>
+func.func @extsi_over_insert_3xi16_cst_i16(%a: i8) -> vector<3xi32> {
+ %cst = arith.constant dense<[-1, 128, 0]> : vector<3xi32>
+ %d = arith.extsi %a : i8 to i32
+ %e = vector.insert %d, %cst [1] : i32 into vector<3xi32>
+ return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_insert_3xi16_cst_i16
+// CHECK-SAME: (%[[ARG:.+]]: i8)
+// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[1, 256, 0]> : vector<3xi16>
+// CHECK-NEXT: %[[SRCE:.+]] = arith.extui %[[ARG]] : i8 to i32
+// CHECK-NEXT: %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16
+// CHECK-NEXT: %[[INS:.+]] = vector.insert %[[SRCT]], %[[CST]] [1] : i16 into vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<3xi32>
+func.func @extui_over_insert_3xi16_cst_i16(%a: i8) -> vector<3xi32> {
+ %cst = arith.constant dense<[1, 256, 0]> : vector<3xi32>
+ %d = arith.extui %a : i8 to i32
+ %e = vector.insert %d, %cst [1] : i32 into vector<3xi32>
+ return %e : vector<3xi32>
+}