[Mlir-commits] [mlir] [MLIR][Vector] Update Transfer{Read|Write}DropUnitDimsPattern patterns (PR #112394)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Oct 18 10:41:07 PDT 2024
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/112394
>From 0de9b8c19c716c92c16fdeacd4ef7fac44087d55 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 15 Oct 2024 17:04:56 +0100
Subject: [PATCH] [MLIR][Vector] Update Transfer{Read|Write}DropUnitDimsPattern
patterns
Updates `TransferWriteDropUnitDimsPattern` and
`TransferReadDropUnitDimsPattern` to inherit from
`MaskableOpRewritePattern` so that xfer_read/xfer_write ops wrapped in a
`vector.mask` op are also supported, e.g.:
```mlir
%v = vector.mask %mask {
vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
} : vector<3x2xi1> -> vector<3x2xi8>
```
---
.../Transforms/VectorTransferOpTransforms.cpp | 59 +++++++++++++------
...ctor-transfer-drop-unit-dims-patterns.mlir | 48 ++++++++++++++-
2 files changed, 89 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index e05c801121ffc4..6cdfa645d78f9b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -354,11 +354,13 @@ namespace {
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
class TransferReadDropUnitDimsPattern
- : public OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
- LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
- PatternRewriter &rewriter) const override {
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
+ vector::MaskingOpInterface maskingOp,
+ PatternRewriter &rewriter) const override {
auto loc = transferReadOp.getLoc();
Value vector = transferReadOp.getVector();
VectorType vectorType = cast<VectorType>(vector.getType());
@@ -406,15 +408,23 @@ class TransferReadDropUnitDimsPattern
SmallVector<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
- auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
+ Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
transferReadOp.getPadding(), maskOp,
rewriter.getBoolArrayAttr(inBounds));
+
+ if (maskingOp) {
+ auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
+ loc, reducedVectorType.cloneWith({}, rewriter.getI1Type()),
+ maskingOp.getMask());
+ newTransferReadOp = mlir::vector::maskOperation(
+ rewriter, newTransferReadOp, shapeCastMask);
+ }
+
auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
- loc, vectorType, newTransferReadOp);
- rewriter.replaceOp(transferReadOp, shapeCast);
+ loc, vectorType, newTransferReadOp->getResults()[0]);
- return success();
+ return shapeCast;
}
};
@@ -422,11 +432,13 @@ class TransferReadDropUnitDimsPattern
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
/// vector shapes are also reduced accordingly.
class TransferWriteDropUnitDimsPattern
- : public OpRewritePattern<vector::TransferWriteOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
- LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
- PatternRewriter &rewriter) const override {
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
+ vector::MaskingOpInterface maskingOp,
+ PatternRewriter &rewriter) const override {
auto loc = transferWriteOp.getLoc();
Value vector = transferWriteOp.getVector();
VectorType vectorType = cast<VectorType>(vector.getType());
@@ -474,13 +486,26 @@ class TransferWriteDropUnitDimsPattern
SmallVector<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
- auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
+ auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
loc, reducedVectorType, vector);
- rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
- identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
+ Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
+ loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
+ maskOp, rewriter.getBoolArrayAttr(inBounds));
+
+ if (maskingOp) {
+ auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
+ loc, reducedVectorType.cloneWith({}, rewriter.getI1Type()),
+ maskingOp.getMask());
+ newXferWrite =
+ mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
+ }
- return success();
+ if (transferWriteOp.hasPureTensorSemantics())
+ return newXferWrite->getResults()[0];
+
+ // With Memref semantics, there's no return value. Use empty value to signal
+ // success.
+ return Value();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index e9d12b044e2c7e..9b61c8ea76f962 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -1,5 +1,9 @@
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+//-----------------------------------------------------------------------------
+// [Patterns: TransferWriteDropUnitDimsPattern, TransferReadDropUnitDimsPattern]
+//-----------------------------------------------------------------------------
+
func.func @transfer_read_rank_reducing(
%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>) -> vector<3x2xi8> {
%c0 = arith.constant 0 : index
@@ -14,7 +18,29 @@ func.func @transfer_read_rank_reducing(
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK: vector.transfer_read %[[SUBVIEW]]
-func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) {
+func.func @transfer_read_rank_reducing_masked(
+ %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+ %mask: vector<3x2xi1>) -> vector<3x2xi8> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.mask %mask {
+ vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
+ } : vector<3x2xi1> -> vector<3x2xi8>
+ return %v : vector<3x2xi8>
+}
+// CHECK-LABEL: func @transfer_read_rank_reducing_masked
+// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
+// CHECK-SAME: %[[MASK:.+]]: vector<3x2xi1>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
+// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
+// CHECK: vector.mask %[[MASK]]
+// CHECK-SAME: vector.transfer_read %[[SUBVIEW]]
+
+func.func @transfer_write_rank_reducing(
+ %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+ %vec : vector<3x2xi8>) {
+
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
@@ -26,6 +52,26 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6,
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]
+func.func @transfer_write_rank_reducing_masked(
+ %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+ %vec : vector<3x2xi8>,
+ %mask: vector<3x2xi1>) {
+ %c0 = arith.constant 0 : index
+ vector.mask %mask {
+ vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+ vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
+ } : vector<3x2xi1>
+ return
+}
+// CHECK-LABEL: func @transfer_write_rank_reducing_masked
+// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
+// CHECK-SAME: %[[VEC:.+]]: vector<3x2xi8>
+// CHECK-SAME: %[[MASK:.+]]: vector<3x2xi1>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
+// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
+// CHECK: vector.mask %[[MASK]]
+// CHECK-SAME: vector.transfer_write %{{.*}}, %[[SUBVIEW]]
+
func.func @transfer_read_and_vector_rank_reducing(
%arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> {
%c0 = arith.constant 0 : index
More information about the Mlir-commits
mailing list