[Mlir-commits] [mlir] 7f3b0e5 - [mlir][arith] Add narrowing patterns to commute more vector ops
Jakub Kuderski
llvmlistbot at llvm.org
Mon May 1 11:34:26 PDT 2023
Author: Jakub Kuderski
Date: 2023-05-01T14:32:57-04:00
New Revision: 7f3b0e584513611bb1d804892eb269ae45d8e715
URL: https://github.com/llvm/llvm-project/commit/7f3b0e584513611bb1d804892eb269ae45d8e715
DIFF: https://github.com/llvm/llvm-project/commit/7f3b0e584513611bb1d804892eb269ae45d8e715.diff
LOG: [mlir][arith] Add narrowing patterns to commute more vector ops
This commutes the extension (`arith.extsi`, `arith.extui`) over the
following vector ops: `vector.broadcast`, `vector.shape_cast`,
`vector.transpose`, `vector.flat_transpose`.
I focused on these because I saw them being created by vector unroll
patterns (with the possible exception of `vector.flat_transpose`).
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D149534
Added:
Modified:
mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
mlir/test/Dialect/Arith/int-narrowing.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 97164621e45c9..0c7afd9255bcd 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -249,6 +249,26 @@ using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;
// Patterns to Commute Extension Ops
//===----------------------------------------------------------------------===//
+/// Sinks a sign/zero extension below a `vector.broadcast`:
+///   broadcast(ext(x)) --> ext(broadcast(x))
+/// The broadcast keeps its result shape but is rebuilt on the narrower
+/// element type of `x`; the extension is then recreated on its result.
+struct ExtensionOverBroadcast final : NarrowingPattern<vector::BroadcastOp> {
+  using NarrowingPattern::NarrowingPattern;
+
+  LogicalResult matchAndRewrite(vector::BroadcastOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *producer = op.getSource().getDefiningOp();
+    FailureOr<ExtensionOp> extOp = ExtensionOp::from(producer);
+    if (failed(extOp))
+      return failure();
+
+    // Same shape as the original result; only the element type narrows.
+    VectorType wideTy = op.getResultVectorType();
+    Type narrowElemTy = extOp->getInElementType();
+    VectorType narrowTy = wideTy.cloneWith(wideTy.getShape(), narrowElemTy);
+
+    Value narrowBroadcast = rewriter.create<vector::BroadcastOp>(
+        op.getLoc(), narrowTy, extOp->getIn());
+    extOp->recreateAndReplace(rewriter, op, narrowBroadcast);
+    return success();
+  }
+};
+
struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
using NarrowingPattern::NarrowingPattern;
@@ -421,6 +441,68 @@ struct ExtensionOverInsertStridedSlice final
}
};
+/// Sinks a sign/zero extension below a `vector.shape_cast`:
+///   shape_cast(ext(x)) --> ext(shape_cast(x))
+/// The cast is rebuilt with the same result shape but the narrower element
+/// type of `x`, and the extension is then recreated on the cast's result.
+struct ExtensionOverShapeCast final : NarrowingPattern<vector::ShapeCastOp> {
+  using NarrowingPattern::NarrowingPattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *producer = op.getSource().getDefiningOp();
+    FailureOr<ExtensionOp> extOp = ExtensionOp::from(producer);
+    if (failed(extOp))
+      return failure();
+
+    VectorType resultTy = op.getResultVectorType();
+    Type narrowElemTy = extOp->getInElementType();
+    VectorType narrowResultTy =
+        resultTy.cloneWith(resultTy.getShape(), narrowElemTy);
+
+    Value narrowCast = rewriter.create<vector::ShapeCastOp>(
+        op.getLoc(), narrowResultTy, extOp->getIn());
+    extOp->recreateAndReplace(rewriter, op, narrowCast);
+    return success();
+  }
+};
+
+/// Sinks a sign/zero extension below a `vector.transpose`:
+///   transpose(ext(x), perm) --> ext(transpose(x, perm))
+/// The permutation is reused as-is; only the element type narrows.
+struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
+  using NarrowingPattern::NarrowingPattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *producer = op.getVector().getDefiningOp();
+    FailureOr<ExtensionOp> extOp = ExtensionOp::from(producer);
+    if (failed(extOp))
+      return failure();
+
+    VectorType resultTy = op.getResultVectorType();
+    VectorType narrowResultTy =
+        resultTy.cloneWith(resultTy.getShape(), extOp->getInElementType());
+
+    Value narrowTranspose = rewriter.create<vector::TransposeOp>(
+        op.getLoc(), narrowResultTy, extOp->getIn(), op.getTransp());
+    extOp->recreateAndReplace(rewriter, op, narrowTranspose);
+    return success();
+  }
+};
+
+/// Sinks a sign/zero extension below a `vector.flat_transpose`:
+///   flat_transpose(ext(x)) --> ext(flat_transpose(x))
+/// The `rows`/`columns` attributes are copied from the original op; only the
+/// element type of the result narrows.
+struct ExtensionOverFlatTranspose final
+    : NarrowingPattern<vector::FlatTransposeOp> {
+  using NarrowingPattern::NarrowingPattern;
+
+  LogicalResult matchAndRewrite(vector::FlatTransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *producer = op.getMatrix().getDefiningOp();
+    FailureOr<ExtensionOp> extOp = ExtensionOp::from(producer);
+    if (failed(extOp))
+      return failure();
+
+    VectorType resultTy = op.getType();
+    VectorType narrowResultTy =
+        resultTy.cloneWith(resultTy.getShape(), extOp->getInElementType());
+
+    Value narrowTranspose = rewriter.create<vector::FlatTransposeOp>(
+        op.getLoc(), narrowResultTy, extOp->getIn(), op.getRowsAttr(),
+        op.getColumnsAttr());
+    extOp->recreateAndReplace(rewriter, op, narrowTranspose);
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// Pass Definitions
//===----------------------------------------------------------------------===//
@@ -449,9 +531,11 @@ void populateArithIntNarrowingPatterns(
RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) {
// Add commute patterns with a higher benefit. This is to expose more
// optimization opportunities to narrowing patterns.
- patterns.add<ExtensionOverExtract, ExtensionOverExtractElement,
- ExtensionOverExtractStridedSlice, ExtensionOverInsert,
- ExtensionOverInsertElement, ExtensionOverInsertStridedSlice>(
+ patterns.add<ExtensionOverBroadcast, ExtensionOverExtract,
+ ExtensionOverExtractElement, ExtensionOverExtractStridedSlice,
+ ExtensionOverInsert, ExtensionOverInsertElement,
+ ExtensionOverInsertStridedSlice, ExtensionOverShapeCast,
+ ExtensionOverTranspose, ExtensionOverFlatTranspose>(
patterns.getContext(), options, PatternBenefit(2));
patterns.add<SIToFPPattern, UIToFPPattern>(patterns.getContext(), options);
diff --git a/mlir/test/Dialect/Arith/int-narrowing.mlir b/mlir/test/Dialect/Arith/int-narrowing.mlir
index 6d5299c2f00da..675a52b5d53e6 100644
--- a/mlir/test/Dialect/Arith/int-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-narrowing.mlir
@@ -442,3 +442,91 @@ func.func @extui_over_insert_strided_slice_cst_2d(%a: vector<1x2xi8>) -> vector<
%e = vector.insert_strided_slice %d, %cst {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi32> into vector<2x3xi32>
return %e : vector<2x3xi32>
}
+
+// Scalar -> vector broadcast: the sign extension must end up after the
+// broadcast, which then operates on i16 elements.
+// CHECK-LABEL: func.func @extsi_over_broadcast_3xi16
+// CHECK-SAME:    (%[[IN:.+]]: i16)
+// CHECK-NEXT:    %[[BCAST:.+]] = vector.broadcast %[[IN]] : i16 to vector<3xi16>
+// CHECK-NEXT:    %[[EXT:.+]] = arith.extsi %[[BCAST]] : vector<3xi16> to vector<3xi32>
+// CHECK-NEXT:    return %[[EXT]] : vector<3xi32>
+func.func @extsi_over_broadcast_3xi16(%src: i16) -> vector<3xi32> {
+  %wide = arith.extsi %src : i16 to i32
+  %bcast = vector.broadcast %wide : i32 to vector<3xi32>
+  return %bcast : vector<3xi32>
+}
+
+// Vector -> higher-rank broadcast: the zero extension must end up after the
+// broadcast, which then operates on i16 elements.
+// CHECK-LABEL: func.func @extui_over_broadcast_2x3xi16
+// CHECK-SAME:    (%[[IN:.+]]: vector<3xi16>)
+// CHECK-NEXT:    %[[BCAST:.+]] = vector.broadcast %[[IN]] : vector<3xi16> to vector<2x3xi16>
+// CHECK-NEXT:    %[[EXT:.+]] = arith.extui %[[BCAST]] : vector<2x3xi16> to vector<2x3xi32>
+// CHECK-NEXT:    return %[[EXT]] : vector<2x3xi32>
+func.func @extui_over_broadcast_2x3xi16(%src: vector<3xi16>) -> vector<2x3xi32> {
+  %wide = arith.extui %src : vector<3xi16> to vector<3xi32>
+  %bcast = vector.broadcast %wide : vector<3xi32> to vector<2x3xi32>
+  return %bcast : vector<2x3xi32>
+}
+
+// 2-D reshape: the sign extension must be moved below the shape_cast.
+// CHECK-LABEL: func.func @extsi_over_shape_cast_2x3xi16
+// CHECK-SAME:    (%[[IN:.+]]: vector<2x3xi16>)
+// CHECK-NEXT:    %[[RESHAPED:.+]] = vector.shape_cast %[[IN]] : vector<2x3xi16> to vector<3x2xi16>
+// CHECK-NEXT:    %[[EXT:.+]] = arith.extsi %[[RESHAPED]] : vector<3x2xi16> to vector<3x2xi32>
+// CHECK-NEXT:    return %[[EXT]] : vector<3x2xi32>
+func.func @extsi_over_shape_cast_2x3xi16(%src: vector<2x3xi16>) -> vector<3x2xi32> {
+  %wide = arith.extsi %src : vector<2x3xi16> to vector<2x3xi32>
+  %reshaped = vector.shape_cast %wide : vector<2x3xi32> to vector<3x2xi32>
+  return %reshaped : vector<3x2xi32>
+}
+
+// 3-D reshape: the zero extension must be moved below the shape_cast.
+// CHECK-LABEL: func.func @extui_over_shape_cast_5x2x3xi16
+// CHECK-SAME:    (%[[IN:.+]]: vector<5x2x3xi16>)
+// CHECK-NEXT:    %[[RESHAPED:.+]] = vector.shape_cast %[[IN]] : vector<5x2x3xi16> to vector<2x3x5xi16>
+// CHECK-NEXT:    %[[EXT:.+]] = arith.extui %[[RESHAPED]] : vector<2x3x5xi16> to vector<2x3x5xi32>
+// CHECK-NEXT:    return %[[EXT]] : vector<2x3x5xi32>
+func.func @extui_over_shape_cast_5x2x3xi16(%src: vector<5x2x3xi16>) -> vector<2x3x5xi32> {
+  %wide = arith.extui %src : vector<5x2x3xi16> to vector<5x2x3xi32>
+  %reshaped = vector.shape_cast %wide : vector<5x2x3xi32> to vector<2x3x5xi32>
+  return %reshaped : vector<2x3x5xi32>
+}
+
+// 2-D transpose: the sign extension must be moved below the transpose and the
+// permutation must be preserved.
+// CHECK-LABEL: func.func @extsi_over_transpose_2x3xi16
+// CHECK-SAME:    (%[[IN:.+]]: vector<2x3xi16>)
+// CHECK-NEXT:    %[[T:.+]] = vector.transpose %[[IN]], [1, 0] : vector<2x3xi16> to vector<3x2xi16>
+// CHECK-NEXT:    %[[EXT:.+]] = arith.extsi %[[T]] : vector<3x2xi16> to vector<3x2xi32>
+// CHECK-NEXT:    return %[[EXT]] : vector<3x2xi32>
+func.func @extsi_over_transpose_2x3xi16(%src: vector<2x3xi16>) -> vector<3x2xi32> {
+  %wide = arith.extsi %src : vector<2x3xi16> to vector<2x3xi32>
+  %t = vector.transpose %wide, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
+  return %t : vector<3x2xi32>
+}
+
+// 3-D transpose with permutation [1, 2, 0] (5x2x3 -> 2x3x5): the zero
+// extension must be moved below the transpose.
+// CHECK-LABEL: func.func @extui_over_transpose_5x2x3xi16
+// CHECK-SAME:    (%[[IN:.+]]: vector<5x2x3xi16>)
+// CHECK-NEXT:    %[[T:.+]] = vector.transpose %[[IN]], [1, 2, 0] : vector<5x2x3xi16> to vector<2x3x5xi16>
+// CHECK-NEXT:    %[[EXT:.+]] = arith.extui %[[T]] : vector<2x3x5xi16> to vector<2x3x5xi32>
+// CHECK-NEXT:    return %[[EXT]] : vector<2x3x5xi32>
+func.func @extui_over_transpose_5x2x3xi16(%src: vector<5x2x3xi16>) -> vector<2x3x5xi32> {
+  %wide = arith.extui %src : vector<5x2x3xi16> to vector<5x2x3xi32>
+  %t = vector.transpose %wide, [1, 2, 0] : vector<5x2x3xi32> to vector<2x3x5xi32>
+  return %t : vector<2x3x5xi32>
+}
+
+// Square 4x4 flat_transpose: the sign extension must be moved below it and the
+// rows/columns attributes must be kept.
+// CHECK-LABEL: func.func @extsi_over_flat_transpose_16xi16
+// CHECK-SAME:    (%[[IN:.+]]: vector<16xi16>)
+// CHECK-NEXT:    %[[T:.+]] = vector.flat_transpose %[[IN]] {columns = 4 : i32, rows = 4 : i32} : vector<16xi16> -> vector<16xi16>
+// CHECK-NEXT:    %[[EXT:.+]] = arith.extsi %[[T]] : vector<16xi16> to vector<16xi32>
+// CHECK-NEXT:    return %[[EXT]] : vector<16xi32>
+func.func @extsi_over_flat_transpose_16xi16(%src: vector<16xi16>) -> vector<16xi32> {
+  %wide = arith.extsi %src : vector<16xi16> to vector<16xi32>
+  %t = vector.flat_transpose %wide {columns = 4 : i32, rows = 4 : i32} : vector<16xi32> -> vector<16xi32>
+  return %t : vector<16xi32>
+}
+
+// Non-square 2x8 flat_transpose: the zero extension must be moved below it and
+// the rows/columns attributes must be kept.
+// CHECK-LABEL: func.func @extui_over_flat_transpose_16xi16
+// CHECK-SAME:    (%[[IN:.+]]: vector<16xi16>)
+// CHECK-NEXT:    %[[T:.+]] = vector.flat_transpose %[[IN]] {columns = 8 : i32, rows = 2 : i32} : vector<16xi16> -> vector<16xi16>
+// CHECK-NEXT:    %[[EXT:.+]] = arith.extui %[[T]] : vector<16xi16> to vector<16xi32>
+// CHECK-NEXT:    return %[[EXT]] : vector<16xi32>
+func.func @extui_over_flat_transpose_16xi16(%src: vector<16xi16>) -> vector<16xi32> {
+  %wide = arith.extui %src : vector<16xi16> to vector<16xi32>
+  %t = vector.flat_transpose %wide {columns = 8 : i32, rows = 2 : i32} : vector<16xi32> -> vector<16xi32>
+  return %t : vector<16xi32>
+}
More information about the Mlir-commits
mailing list