[Mlir-commits] [mlir] 3ff8708 - [mlir][arith] Add narrowing patterns for other insertion ops
Jakub Kuderski
llvmlistbot at llvm.org
Mon May 1 11:30:34 PDT 2023
Author: Jakub Kuderski
Date: 2023-05-01T14:29:02-04:00
New Revision: 3ff870881f5f0d3d08753efd558ac5f05d04a574
URL: https://github.com/llvm/llvm-project/commit/3ff870881f5f0d3d08753efd558ac5f05d04a574
DIFF: https://github.com/llvm/llvm-project/commit/3ff870881f5f0d3d08753efd558ac5f05d04a574.diff
LOG: [mlir][arith] Add narrowing patterns for other insertion ops
Allow extension ops to commute over `vector.insertelement` and
`vector.insert_strided_slice`.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D149509
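In short, when the inserted value and the destination vector are produced by the
same kind of extension (`arith.extsi` or `arith.extui`), these patterns sink the
extension past the insertion so the insert happens on the narrow element type and
the result is extended once. A minimal before/after sketch for
`vector.insertelement`, mirroring the `@extsi_over_insertelement_3xi16` test added
below:

  // Before: both operands are extended to i32, then inserted.
  %c = arith.extsi %a : vector<3xi16> to vector<3xi32>
  %d = arith.extsi %b : i16 to i32
  %e = vector.insertelement %d, %c[%pos : i32] : vector<3xi32>

  // After: insert on the narrow i16 type, then extend the result once.
  %ins = vector.insertelement %b, %a[%pos : i32] : vector<3xi16>
  %e   = arith.extsi %ins : vector<3xi16> to vector<3xi32>

The same commutation applies to `vector.insert_strided_slice`, with constant
destinations narrowed in place when they fit the smaller bitwidth (see the
`*_cst_*` tests).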
Added:
Modified:
mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
mlir/test/Dialect/Arith/int-narrowing.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index c515824a8e04d..97164621e45c9 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -306,27 +306,35 @@ struct ExtensionOverExtractStridedSlice final
}
};
-struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
- using NarrowingPattern::NarrowingPattern;
-
- LogicalResult matchAndRewrite(vector::InsertOp op,
- PatternRewriter &rewriter) const override {
+/// Base pattern for `vector.insert` narrowing patterns.
+template <typename InsertionOp>
+struct ExtensionOverInsertionPattern : NarrowingPattern<InsertionOp> {
+ using NarrowingPattern<InsertionOp>::NarrowingPattern;
+
+ /// Derived classes must provide a function to create the matching insertion
+ /// op based on the original op and new arguments.
+ virtual InsertionOp createInsertionOp(PatternRewriter &rewriter,
+ InsertionOp origInsert,
+ Value narrowValue,
+ Value narrowDest) const = 0;
+
+ LogicalResult matchAndRewrite(InsertionOp op,
+ PatternRewriter &rewriter) const final {
FailureOr<ExtensionOp> ext =
ExtensionOp::from(op.getSource().getDefiningOp());
if (failed(ext))
return failure();
- FailureOr<vector::InsertOp> newInsert =
- createNarrowInsert(op, rewriter, *ext);
+ FailureOr<InsertionOp> newInsert = createNarrowInsert(op, rewriter, *ext);
if (failed(newInsert))
return failure();
ext->recreateAndReplace(rewriter, op, *newInsert);
return success();
}
- FailureOr<vector::InsertOp> createNarrowInsert(vector::InsertOp op,
- PatternRewriter &rewriter,
- ExtensionOp insValue) const {
+ FailureOr<InsertionOp> createNarrowInsert(InsertionOp op,
+ PatternRewriter &rewriter,
+ ExtensionOp insValue) const {
// Calculate the operand and result bitwidths. We can only apply narrowing
// when the inserted source value and destination vector require fewer bits
// than the result. Because the source and destination may have different
@@ -337,6 +345,8 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
if (failed(origBitsRequired))
return failure();
+ // TODO: We could relax this check by disregarding bitwidth requirements of
+ // elements that we know will be replaced by the insertion.
FailureOr<unsigned> destBitsRequired =
calculateBitsRequired(op.getDest(), insValue.getKind());
if (failed(destBitsRequired) || *destBitsRequired >= *origBitsRequired)
@@ -352,12 +362,13 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
// both the source and the destination values.
unsigned newInsertionBits =
std::max(*destBitsRequired, *insertedBitsRequired);
- FailureOr<Type> newVecTy = getNarrowType(newInsertionBits, op.getType());
+ FailureOr<Type> newVecTy =
+ this->getNarrowType(newInsertionBits, op.getType());
if (failed(newVecTy) || *newVecTy == op.getType())
return failure();
FailureOr<Type> newInsertedValueTy =
- getNarrowType(newInsertionBits, insValue.getType());
+ this->getNarrowType(newInsertionBits, insValue.getType());
if (failed(newInsertedValueTy))
return failure();
@@ -366,8 +377,47 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
loc, *newInsertedValueTy, insValue.getResult());
Value narrowDest =
rewriter.createOrFold<arith::TruncIOp>(loc, *newVecTy, op.getDest());
- return rewriter.create<vector::InsertOp>(loc, narrowValue, narrowDest,
- op.getPosition());
+ return createInsertionOp(rewriter, op, narrowValue, narrowDest);
+ }
+};
+
+struct ExtensionOverInsert final
+ : ExtensionOverInsertionPattern<vector::InsertOp> {
+ using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
+
+ vector::InsertOp createInsertionOp(PatternRewriter &rewriter,
+ vector::InsertOp origInsert,
+ Value narrowValue,
+ Value narrowDest) const override {
+ return rewriter.create<vector::InsertOp>(
+ origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition());
+ }
+};
+
+struct ExtensionOverInsertElement final
+ : ExtensionOverInsertionPattern<vector::InsertElementOp> {
+ using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
+
+ vector::InsertElementOp createInsertionOp(PatternRewriter &rewriter,
+ vector::InsertElementOp origInsert,
+ Value narrowValue,
+ Value narrowDest) const override {
+ return rewriter.create<vector::InsertElementOp>(
+ origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition());
+ }
+};
+
+struct ExtensionOverInsertStridedSlice final
+ : ExtensionOverInsertionPattern<vector::InsertStridedSliceOp> {
+ using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
+
+ vector::InsertStridedSliceOp
+ createInsertionOp(PatternRewriter &rewriter,
+ vector::InsertStridedSliceOp origInsert, Value narrowValue,
+ Value narrowDest) const override {
+ return rewriter.create<vector::InsertStridedSliceOp>(
+ origInsert.getLoc(), narrowValue, narrowDest, origInsert.getOffsets(),
+ origInsert.getStrides());
}
};
@@ -400,7 +450,8 @@ void populateArithIntNarrowingPatterns(
// Add commute patterns with a higher benefit. This is to expose more
// optimization opportunities to narrowing patterns.
patterns.add<ExtensionOverExtract, ExtensionOverExtractElement,
- ExtensionOverExtractStridedSlice, ExtensionOverInsert>(
+ ExtensionOverExtractStridedSlice, ExtensionOverInsert,
+ ExtensionOverInsertElement, ExtensionOverInsertStridedSlice>(
patterns.getContext(), options, PatternBenefit(2));
patterns.add<SIToFPPattern, UIToFPPattern>(patterns.getContext(), options);
diff --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir
index d98e03d93b030..6d5299c2f00da 100644
--- a/mlir/test/Dialect/Arith/int-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-narrowing.mlir
@@ -328,3 +328,117 @@ func.func @extui_over_insert_3xi16_cst_i16(%a: i8) -> vector<3xi32> {
%e = vector.insert %d, %cst [1] : i32 into vector<3xi32>
return %e : vector<3xi32>
}
+
+// CHECK-LABEL: func.func @extsi_over_insertelement_3xi16
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16, %[[POS:.+]]: i32)
+// CHECK-NEXT: %[[INS:.+]] = vector.insertelement %[[ARG1]], %[[ARG0]][%[[POS]] : i32] : vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<3xi32>
+func.func @extsi_over_insertelement_3xi16(%a: vector<3xi16>, %b: i16, %pos: i32) -> vector<3xi32> {
+ %c = arith.extsi %a : vector<3xi16> to vector<3xi32>
+ %d = arith.extsi %b : i16 to i32
+ %e = vector.insertelement %d, %c[%pos : i32] : vector<3xi32>
+ return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_insertelement_3xi16
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16, %[[POS:.+]]: i32)
+// CHECK-NEXT: %[[INS:.+]] = vector.insertelement %[[ARG1]], %[[ARG0]][%[[POS]] : i32] : vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<3xi32>
+func.func @extui_over_insertelement_3xi16(%a: vector<3xi16>, %b: i16, %pos: i32) -> vector<3xi32> {
+ %c = arith.extui %a : vector<3xi16> to vector<3xi32>
+ %d = arith.extui %b : i16 to i32
+ %e = vector.insertelement %d, %c[%pos : i32] : vector<3xi32>
+ return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extsi_over_insertelement_3xi16_cst_i16
+// CHECK-SAME: (%[[ARG:.+]]: i8, %[[POS:.+]]: i32)
+// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[-1, 128, 0]> : vector<3xi16>
+// CHECK-NEXT: %[[SRCE:.+]] = arith.extsi %[[ARG]] : i8 to i32
+// CHECK-NEXT: %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16
+// CHECK-NEXT: %[[INS:.+]] = vector.insertelement %[[SRCT]], %[[CST]][%[[POS]] : i32] : vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<3xi32>
+func.func @extsi_over_insertelement_3xi16_cst_i16(%a: i8, %pos: i32) -> vector<3xi32> {
+ %cst = arith.constant dense<[-1, 128, 0]> : vector<3xi32>
+ %d = arith.extsi %a : i8 to i32
+ %e = vector.insertelement %d, %cst[%pos : i32] : vector<3xi32>
+ return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_insertelement_3xi16_cst_i16
+// CHECK-SAME: (%[[ARG:.+]]: i8, %[[POS:.+]]: i32)
+// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[1, 256, 0]> : vector<3xi16>
+// CHECK-NEXT: %[[SRCE:.+]] = arith.extui %[[ARG]] : i8 to i32
+// CHECK-NEXT: %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16
+// CHECK-NEXT: %[[INS:.+]] = vector.insertelement %[[SRCT]], %[[CST]][%[[POS]] : i32] : vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<3xi32>
+func.func @extui_over_insertelement_3xi16_cst_i16(%a: i8, %pos: i32) -> vector<3xi32> {
+ %cst = arith.constant dense<[1, 256, 0]> : vector<3xi32>
+ %d = arith.extui %a : i8 to i32
+ %e = vector.insertelement %d, %cst[%pos : i32] : vector<3xi32>
+ return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extsi_over_insert_strided_slice_1d
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: vector<2xi16>)
+// CHECK-NEXT: %[[INS:.+]] = vector.insert_strided_slice %[[ARG1]], %[[ARG0]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<2xi16> into vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<3xi32>
+func.func @extsi_over_insert_strided_slice_1d(%a: vector<3xi16>, %b: vector<2xi16>) -> vector<3xi32> {
+ %c = arith.extsi %a : vector<3xi16> to vector<3xi32>
+ %d = arith.extsi %b : vector<2xi16> to vector<2xi32>
+ %e = vector.insert_strided_slice %d, %c {offsets = [1], strides = [1]} : vector<2xi32> into vector<3xi32>
+ return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_insert_strided_slice_1d
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: vector<2xi16>)
+// CHECK-NEXT: %[[INS:.+]] = vector.insert_strided_slice %[[ARG1]], %[[ARG0]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<2xi16> into vector<3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<3xi32>
+func.func @extui_over_insert_strided_slice_1d(%a: vector<3xi16>, %b: vector<2xi16>) -> vector<3xi32> {
+ %c = arith.extui %a : vector<3xi16> to vector<3xi32>
+ %d = arith.extui %b : vector<2xi16> to vector<2xi32>
+ %e = vector.insert_strided_slice %d, %c {offsets = [1], strides = [1]} : vector<2xi32> into vector<3xi32>
+ return %e : vector<3xi32>
+}
+
+// CHECK-LABEL: func.func @extsi_over_insert_strided_slice_cst_2d
+// CHECK-SAME: (%[[ARG:.+]]: vector<1x2xi8>)
+// CHECK-NEXT: %[[CST:.+]] = arith.constant
+// CHECK-SAME{LITERAL}: dense<[[-1, 128, 0], [-129, 42, 1337]]> : vector<2x3xi16>
+// CHECK-NEXT: %[[SRCE:.+]] = arith.extsi %[[ARG]] : vector<1x2xi8> to vector<1x2xi32>
+// CHECK-NEXT: %[[SRCT:.+]] = arith.trunci %[[SRCE]] : vector<1x2xi32> to vector<1x2xi16>
+// CHECK-NEXT: %[[INS:.+]] = vector.insert_strided_slice %[[SRCT]], %[[CST]]
+// CHECK-SAME: {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi16> into vector<2x3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<2x3xi16> to vector<2x3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<2x3xi32>
+func.func @extsi_over_insert_strided_slice_cst_2d(%a: vector<1x2xi8>) -> vector<2x3xi32> {
+ %cst = arith.constant dense<[[-1, 128, 0], [-129, 42, 1337]]> : vector<2x3xi32>
+ %d = arith.extsi %a : vector<1x2xi8> to vector<1x2xi32>
+ %e = vector.insert_strided_slice %d, %cst {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi32> into vector<2x3xi32>
+ return %e : vector<2x3xi32>
+}
+
+// CHECK-LABEL: func.func @extui_over_insert_strided_slice_cst_2d
+// CHECK-SAME: (%[[ARG:.+]]: vector<1x2xi8>)
+// CHECK-NEXT: %[[CST:.+]] = arith.constant
+// CHECK-SAME{LITERAL}: dense<[[1, 128, 0], [256, 42, 1337]]> : vector<2x3xi16>
+// CHECK-NEXT: %[[SRCE:.+]] = arith.extui %[[ARG]] : vector<1x2xi8> to vector<1x2xi32>
+// CHECK-NEXT: %[[SRCT:.+]] = arith.trunci %[[SRCE]] : vector<1x2xi32> to vector<1x2xi16>
+// CHECK-NEXT: %[[INS:.+]] = vector.insert_strided_slice %[[SRCT]], %[[CST]]
+// CHECK-SAME: {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi16> into vector<2x3xi16>
+// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<2x3xi16> to vector<2x3xi32>
+// CHECK-NEXT: return %[[RET]] : vector<2x3xi32>
+func.func @extui_over_insert_strided_slice_cst_2d(%a: vector<1x2xi8>) -> vector<2x3xi32> {
+ %cst = arith.constant dense<[[1, 128, 0], [256, 42, 1337]]> : vector<2x3xi32>
+ %d = arith.extui %a : vector<1x2xi8> to vector<1x2xi32>
+ %e = vector.insert_strided_slice %d, %cst {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi32> into vector<2x3xi32>
+ return %e : vector<2x3xi32>
+}
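Note: the new tests are driven by the existing RUN line at the top of
int-narrowing.mlir, which is not part of this diff. As a rough sketch of how they
are exercised (pass name and option spelling are an assumption based on that file,
not confirmed by this patch):

  // Assumed invocation; defer to the actual RUN line in int-narrowing.mlir.
  // mlir-opt --arith-int-narrowing="int-bitwidths-supported=1,8,16,32" %s | FileCheck %s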