[Mlir-commits] [mlir] [MLIR][Vector] Implement transferXXPermutationLowering as MaskableOpRewritePattern (PR #91987)
Hugo Trachino
llvmlistbot at llvm.org
Mon May 20 07:16:48 PDT 2024
https://github.com/nujaa updated https://github.com/llvm/llvm-project/pull/91987
>From 2f7fc4bcc667aeea40357dfe1ac5d12a2162f61b Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Fri, 17 May 2024 18:49:31 +0800
Subject: [PATCH 1/8] [mlir][vector] Support MaskableOpRewritePattern of Op
without a result.
---
mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 030be328e97fd..bf9694556f901 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -157,7 +157,10 @@ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
if (failed(newOp))
return failure();
- rewriter.replaceOp(rootOp, *newOp);
+ if (rootOp->getNumResults() == 0 || *newOp == Value())
+ rewriter.eraseOp(rootOp);
+ else
+ rewriter.replaceOp(rootOp, *newOp);
return success();
}
>From b6a41bc31d9779f3b0744c3f168dec171207d0f6 Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Tue, 14 May 2024 22:21:23 +0800
Subject: [PATCH 2/8] [MLIR][Vector] Implement transferXXPermutationLowering as
MaskableOpRewritePattern
---
.../Vector/Transforms/LowerVectorTransfer.cpp | 64 +++++++++------
.../vector-transfer-permutation-lowering.mlir | 82 +++++++++++++++++++
2 files changed, 122 insertions(+), 24 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index b30b43d70bf0f..7f5703b635068 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -90,14 +90,19 @@ namespace {
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
struct TransferReadPermutationLowering
- : public OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public MaskableOpRewritePattern<vector::TransferReadOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
- LogicalResult matchAndRewrite(vector::TransferReadOp op,
- PatternRewriter &rewriter) const override {
+ FailureOr<mlir::Value>
+ matchAndRewriteMaskableOp(vector::TransferReadOp op,
+ MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
if (op.getTransferRank() == 0)
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+ // TODO: Support transfer_read inside MaskOp case.
+ if (maskOp)
+ return rewriter.notifyMatchFailure(op, "Masked case not supported");
SmallVector<unsigned> permutation;
AffineMap map = op.getPermutationMap();
@@ -142,9 +147,9 @@ struct TransferReadPermutationLowering
// Transpose result of transfer_read.
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
- rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
- transposePerm);
- return success();
+ return rewriter
+ .create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm)
+ .getResult();
}
};
@@ -165,14 +170,19 @@ struct TransferReadPermutationLowering
/// %v = vector.transfer_write %tmp ...
/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
struct TransferWritePermutationLowering
- : public OpRewritePattern<vector::TransferWriteOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public MaskableOpRewritePattern<vector::TransferWriteOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
- LogicalResult matchAndRewrite(vector::TransferWriteOp op,
- PatternRewriter &rewriter) const override {
+ FailureOr<mlir::Value>
+ matchAndRewriteMaskableOp(vector::TransferWriteOp op,
+ MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
if (op.getTransferRank() == 0)
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+ // TODO: Support transfer_write inside MaskOp case.
+ if (maskOp)
+ return rewriter.notifyMatchFailure(op, "Masked case not supported");
SmallVector<unsigned> permutation;
AffineMap map = op.getPermutationMap();
@@ -207,11 +217,11 @@ struct TransferWritePermutationLowering
op.getLoc(), op.getVector(), indices);
auto newMap = AffineMap::getMinorIdentityMap(
map.getNumDims(), map.getNumResults(), rewriter.getContext());
- rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
- op.getMask(), newInBoundsAttr);
-
- return success();
+ return rewriter
+ .create<vector::TransferWriteOp>(
+ op.getLoc(), newVec, op.getSource(), op.getIndices(),
+ AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr)
+ .getResult();
}
};
@@ -231,14 +241,19 @@ struct TransferWritePermutationLowering
/// vector<1x8x16xf32>
/// ```
struct TransferWriteNonPermutationLowering
- : public OpRewritePattern<vector::TransferWriteOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public MaskableOpRewritePattern<vector::TransferWriteOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
- LogicalResult matchAndRewrite(vector::TransferWriteOp op,
- PatternRewriter &rewriter) const override {
+ FailureOr<mlir::Value>
+ matchAndRewriteMaskableOp(vector::TransferWriteOp op,
+ MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
if (op.getTransferRank() == 0)
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+ // TODO: Support transfer_write inside MaskOp case.
+ if (maskOp)
+ return rewriter.notifyMatchFailure(op, "Masked case not supported");
SmallVector<unsigned> permutation;
AffineMap map = op.getPermutationMap();
@@ -285,10 +300,11 @@ struct TransferWriteNonPermutationLowering
newInBoundsValues.push_back(op.isDimInBounds(i));
}
ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
- rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
- newMask, newInBoundsAttr);
- return success();
+ return rewriter
+ .create<vector::TransferWriteOp>(
+ op.getLoc(), newVec, op.getSource(), op.getIndices(),
+ AffineMapAttr::get(newMap), newMask, newInBoundsAttr)
+ .getResult();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index e48af3cd7aace..d63d47fe4481d 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -46,6 +46,55 @@ func.func @permutation_with_mask_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %
return
}
+// transfer_write in MaskOp case not supported.
+// CHECK-LABEL: func @masked_permutation_xfer_write_fixed_width
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[ARG_1:.*]]: vector<16xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index,
+// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>
+// CHECK-NOT: vector.transpose
+// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_1]], %[[ARG_0]][%[[IDX]], %[[IDX]]] {{.*}} vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
+// CHECK: return %[[RES]]
+func.func @masked_permutation_xfer_write_fixed_width(%t: tensor<?x?xf32>, %val: vector<16xf32>, %idx: index, %mask: vector<16xi1>) -> tensor<?x?xf32> {
+ %r = vector.mask %mask { vector.transfer_write %val, %t[%idx, %idx] {permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
+ return %r : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @masked_permutation_xfer_write_scalable(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>)
+// CHECK-SAME: -> tensor<?x?x?x?xf32> {
+// CHECK-NOT: vector.transpose
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[R:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_0]], %[[ARG_1]][%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = #[[MAP:.*]]} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
+// CHECK: return %[[R]] : tensor<?x?x?x?xf32>
+func.func @masked_permutation_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %t: tensor<?x?x?x?xf32>, %mask: vector<4x[8]xi1>) -> tensor<?x?x?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %r = vector.mask %mask { vector.transfer_write %arg0, %t[%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
+
+ return %r : tensor<?x?x?x?xf32>
+}
+
+// transfer_write in MaskOp case not supported.
+// CHECK-LABEL: func @masked_non_permutation_xfer_write_fixed_width
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:.*]]: vector<14x8x16xf32>
+// CHECK-SAME: %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
+// CHECK-NOT: vector.broadcast
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[masked1:.*]] = vector.mask %0 { vector.transfer_write %[[ARG1]], %[[ARG0]]{{.*}}permutation_map = #[[MAP:.*]]} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
+func.func @masked_non_permutation_xfer_write_fixed_width(
+ %arg0 : tensor<?x?x?x?xf32>,
+ %v1 : vector<14x8x16xf32>, %dim : index) -> tensor<?x?x?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %mask = vector.create_mask %dim, %dim, %dim : vector<14x8x16xi1>
+ %0 = vector.mask %mask { vector.transfer_write %v1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
+
+ return %0 : tensor<?x?x?x?xf32>
+}
+
///----------------------------------------------------------------------------------------
/// vector.transfer_read
///----------------------------------------------------------------------------------------
@@ -101,6 +150,39 @@ func.func @permutation_with_mask_xfer_read_scalable(%mem: memref<?x?xf32>, %dim_
return %1 : vector<8x[4]x2xf32>
}
+// transfer_read in MaskOp case not supported.
+// CHECK-LABEL: func @masked_permutation_xfer_read_fixed_width
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x1xf32>,
+// CHECK-SAME: %[[ARG_1:.*]]: vector<4x1xi1>
+// CHECK-NOT: vector.transpose
+// CHECK: vector.mask %[[ARG_1]] { vector.transfer_read %[[ARG_0]]{{.*}}: tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
+func.func @masked_permutation_xfer_read_fixed_width(%arg0: tensor<?x1xf32>, %mask : vector<4x1xi1>) {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %3 = vector.mask %mask { vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = affine_map<(d0, d1) -> (d1, 0, d0)>} : tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
+ call @test.some_use(%3) : (vector<1x4x4xf32>) -> ()
+ return
+}
+func.func private @test.some_use(vector<1x4x4xf32>)
+
+// CHECK-LABEL: func.func @masked_permutation_xfer_read_scalable(
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
+// CHECK-NOT: vector.transpose
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C0]], %[[C0]]], %cst {in_bounds = [true, true, true], permutation_map = #[[MAP:.*]]} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
+// CHECK: return %[[T_READ]] : vector<8x[4]x2xf32>
+func.func @masked_permutation_xfer_read_scalable(%t: tensor<?x?xf32>, %mask : vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
+
+ %c0 = arith.constant 0 : index
+ %cst_0 = arith.constant 0.000000e+00 : f32
+
+ %1 = vector.mask %mask { vector.transfer_read %t[%c0, %c0], %cst_0
+ {in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1) -> (0, d1, d0)>}
+ : tensor<?x?xf32>, vector<8x[4]x2xf32> } :vector<2x[4]xi1> -> vector<8x[4]x2xf32>
+ return %1 : vector<8x[4]x2xf32>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
>From 1788c0c250bbdf1a5b2ec8dece380af09c17a6f7 Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Tue, 14 May 2024 22:25:02 +0800
Subject: [PATCH 3/8] Fixup: test less for negative tests.
---
.../Vector/vector-transfer-permutation-lowering.mlir | 6 ------
1 file changed, 6 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index d63d47fe4481d..a7acffbbbf397 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -54,7 +54,6 @@ func.func @permutation_with_mask_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %
// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>
// CHECK-NOT: vector.transpose
// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_1]], %[[ARG_0]][%[[IDX]], %[[IDX]]] {{.*}} vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
-// CHECK: return %[[RES]]
func.func @masked_permutation_xfer_write_fixed_width(%t: tensor<?x?xf32>, %val: vector<16xf32>, %idx: index, %mask: vector<16xi1>) -> tensor<?x?xf32> {
%r = vector.mask %mask { vector.transfer_write %val, %t[%idx, %idx] {permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
return %r : tensor<?x?xf32>
@@ -66,9 +65,7 @@ func.func @masked_permutation_xfer_write_fixed_width(%t: tensor<?x?xf32>, %val:
// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>)
// CHECK-SAME: -> tensor<?x?x?x?xf32> {
// CHECK-NOT: vector.transpose
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[R:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_0]], %[[ARG_1]][%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = #[[MAP:.*]]} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
-// CHECK: return %[[R]] : tensor<?x?x?x?xf32>
func.func @masked_permutation_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %t: tensor<?x?x?x?xf32>, %mask: vector<4x[8]xi1>) -> tensor<?x?x?x?xf32> {
%c0 = arith.constant 0 : index
%r = vector.mask %mask { vector.transfer_write %arg0, %t[%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
@@ -83,7 +80,6 @@ func.func @masked_permutation_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %t:
// CHECK-SAME: %[[ARG1:.*]]: vector<14x8x16xf32>
// CHECK-SAME: %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
// CHECK-NOT: vector.broadcast
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[masked1:.*]] = vector.mask %0 { vector.transfer_write %[[ARG1]], %[[ARG0]]{{.*}}permutation_map = #[[MAP:.*]]} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
func.func @masked_non_permutation_xfer_write_fixed_width(
%arg0 : tensor<?x?x?x?xf32>,
@@ -169,9 +165,7 @@ func.func private @test.some_use(vector<1x4x4xf32>)
// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[MASK:.*]]: vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
// CHECK-NOT: vector.transpose
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C0]], %[[C0]]], %cst {in_bounds = [true, true, true], permutation_map = #[[MAP:.*]]} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
-// CHECK: return %[[T_READ]] : vector<8x[4]x2xf32>
func.func @masked_permutation_xfer_read_scalable(%t: tensor<?x?xf32>, %mask : vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
%c0 = arith.constant 0 : index
>From c50ef197ca2e9a740bf274086f92087b83a49db5 Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Wed, 15 May 2024 00:24:13 +0800
Subject: [PATCH 4/8] Fixup MaskableOpRewritePattern when transfer_write has no
result
---
.../Vector/Transforms/LowerVectorTransfer.cpp | 26 ++++++++++++-------
.../vector-transfer-permutation-lowering.mlir | 8 +++---
2 files changed, 20 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 7f5703b635068..81f7591a7d86f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -217,11 +217,14 @@ struct TransferWritePermutationLowering
op.getLoc(), op.getVector(), indices);
auto newMap = AffineMap::getMinorIdentityMap(
map.getNumDims(), map.getNumResults(), rewriter.getContext());
- return rewriter
- .create<vector::TransferWriteOp>(
- op.getLoc(), newVec, op.getSource(), op.getIndices(),
- AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr)
- .getResult();
+ auto newWrite = rewriter.create<vector::TransferWriteOp>(
+ op.getLoc(), newVec, op.getSource(), op.getIndices(),
+ AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
+ if (newWrite.hasPureTensorSemantics())
+ return newWrite.getResult();
+ // In memref case, MaskableOpRewritePattern cannot replaceOp with result.
+ rewriter.eraseOp(op);
+ return failure();
}
};
@@ -300,11 +303,14 @@ struct TransferWriteNonPermutationLowering
newInBoundsValues.push_back(op.isDimInBounds(i));
}
ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
- return rewriter
- .create<vector::TransferWriteOp>(
- op.getLoc(), newVec, op.getSource(), op.getIndices(),
- AffineMapAttr::get(newMap), newMask, newInBoundsAttr)
- .getResult();
+ auto newWrite = rewriter.create<vector::TransferWriteOp>(
+ op.getLoc(), newVec, op.getSource(), op.getIndices(),
+ AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
+ if (newWrite.hasPureTensorSemantics())
+ return newWrite.getResult();
+ // In memref case, MaskableOpRewritePattern cannot replaceOp with result.
+ rewriter.eraseOp(op);
+ return failure();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index a7acffbbbf397..349dc1ab31d4c 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -53,7 +53,7 @@ func.func @permutation_with_mask_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %
// CHECK-SAME: %[[IDX:.*]]: index,
// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>
// CHECK-NOT: vector.transpose
-// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_1]], %[[ARG_0]][%[[IDX]], %[[IDX]]] {{.*}} vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
+// CHECK: %[[RES:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_1]], %[[ARG_0]]{{.*}} vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
func.func @masked_permutation_xfer_write_fixed_width(%t: tensor<?x?xf32>, %val: vector<16xf32>, %idx: index, %mask: vector<16xi1>) -> tensor<?x?xf32> {
%r = vector.mask %mask { vector.transfer_write %val, %t[%idx, %idx] {permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
return %r : tensor<?x?xf32>
@@ -65,7 +65,7 @@ func.func @masked_permutation_xfer_write_fixed_width(%t: tensor<?x?xf32>, %val:
// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>)
// CHECK-SAME: -> tensor<?x?x?x?xf32> {
// CHECK-NOT: vector.transpose
-// CHECK: %[[R:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_0]], %[[ARG_1]][%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = #[[MAP:.*]]} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
+// CHECK: %[[R:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[ARG_0]], %[[ARG_1]]{{.*}} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
func.func @masked_permutation_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %t: tensor<?x?x?x?xf32>, %mask: vector<4x[8]xi1>) -> tensor<?x?x?x?xf32> {
%c0 = arith.constant 0 : index
%r = vector.mask %mask { vector.transfer_write %arg0, %t[%c0, %c0, %c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
@@ -80,7 +80,7 @@ func.func @masked_permutation_xfer_write_scalable(%arg0: vector<4x[8]xi16>, %t:
// CHECK-SAME: %[[ARG1:.*]]: vector<14x8x16xf32>
// CHECK-SAME: %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
// CHECK-NOT: vector.broadcast
-// CHECK: %[[masked1:.*]] = vector.mask %0 { vector.transfer_write %[[ARG1]], %[[ARG0]]{{.*}}permutation_map = #[[MAP:.*]]} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
+// CHECK: %[[masked1:.*]] = vector.mask %0 { vector.transfer_write %[[ARG1]], %[[ARG0]]{{.*}} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
func.func @masked_non_permutation_xfer_write_fixed_width(
%arg0 : tensor<?x?x?x?xf32>,
%v1 : vector<14x8x16xf32>, %dim : index) -> tensor<?x?x?x?xf32> {
@@ -165,7 +165,7 @@ func.func private @test.some_use(vector<1x4x4xf32>)
// CHECK-SAME: %[[ARG_0:.*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[MASK:.*]]: vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
// CHECK-NOT: vector.transpose
-// CHECK: %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C0]], %[[C0]]], %cst {in_bounds = [true, true, true], permutation_map = #[[MAP:.*]]} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
+// CHECK: %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]]{{.*}} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
func.func @masked_permutation_xfer_read_scalable(%t: tensor<?x?xf32>, %mask : vector<2x[4]xi1>) -> vector<8x[4]x2xf32> {
%c0 = arith.constant 0 : index
>From 127b95098ab8443fc8aa024f237d9bc4be160901 Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Fri, 17 May 2024 19:13:53 +0800
Subject: [PATCH 5/8] FixUp introduce #92526
---
mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 81f7591a7d86f..751b78b586488 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -309,8 +309,7 @@ struct TransferWriteNonPermutationLowering
if (newWrite.hasPureTensorSemantics())
return newWrite.getResult();
// In memref case, MaskableOpRewritePattern cannot replaceOp with result.
- rewriter.eraseOp(op);
- return failure();
+ return Value();
}
};
>From c11da46c2c8ba7169d1062a35e092182df66379c Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Fri, 17 May 2024 19:27:30 +0800
Subject: [PATCH 6/8] FixUp introduce #92526
---
mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 751b78b586488..300b56fbc1776 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -223,8 +223,7 @@ struct TransferWritePermutationLowering
if (newWrite.hasPureTensorSemantics())
return newWrite.getResult();
// In memref case, MaskableOpRewritePattern cannot replaceOp with result.
- rewriter.eraseOp(op);
- return failure();
+ return Value();
}
};
>From 9e1536d44ff91d379ac11e58c6202b2b26c13372 Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Mon, 20 May 2024 21:53:19 +0800
Subject: [PATCH 7/8] Fixup : fix comments.
---
mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 300b56fbc1776..c59012266ceb3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -222,7 +222,8 @@ struct TransferWritePermutationLowering
AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
if (newWrite.hasPureTensorSemantics())
return newWrite.getResult();
- // In memref case, MaskableOpRewritePattern cannot replaceOp with result.
+ // In the memref case there's no return value. Use empty value to signal
+ // success.
return Value();
}
};
@@ -307,7 +308,8 @@ struct TransferWriteNonPermutationLowering
AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
if (newWrite.hasPureTensorSemantics())
return newWrite.getResult();
- // In memref case, MaskableOpRewritePattern cannot replaceOp with result.
+ // In the memref case there's no return value. Use empty value to signal
+ // success.
return Value();
}
};
>From 7c3cf8616209ace68b19969b5d6af7d3d6ca46e1 Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Mon, 20 May 2024 22:16:28 +0800
Subject: [PATCH 8/8] fixup : MaskableOpRewritePattern and Value()
---
mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h | 8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index bf9694556f901..9c83acc76e77a 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -157,10 +157,14 @@ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
if (failed(newOp))
return failure();
- if (rootOp->getNumResults() == 0 || *newOp == Value())
+ // Rewriting succeeded but there are no values to replace.
+ if (rootOp->getNumResults() == 0) {
rewriter.eraseOp(rootOp);
- else
+ } else {
+ assert(*newOp != Value() &&
+ "Cannot replace an op's use with an empty value.");
rewriter.replaceOp(rootOp, *newOp);
+ }
return success();
}
More information about the Mlir-commits
mailing list