[Mlir-commits] [mlir] 0170498 - [MLIR][Vector] Implement TransferOpReduceRank as MaskableOpRewritePattern (#92426)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jun 12 01:01:03 PDT 2024
Author: Hugo Trachino
Date: 2024-06-12T09:00:59+01:00
New Revision: 0170498a7ddf84a404527ec3f0e82a4a3d10869a
URL: https://github.com/llvm/llvm-project/commit/0170498a7ddf84a404527ec3f0e82a4a3d10869a
DIFF: https://github.com/llvm/llvm-project/commit/0170498a7ddf84a404527ec3f0e82a4a3d10869a.diff
LOG: [MLIR][Vector] Implement TransferOpReduceRank as MaskableOpRewritePattern (#92426)
Implements `TransferOpReduceRank` as a `MaskableOpRewritePattern`.
This allows the pattern to exit gracefully when run on a
`vector::transfer_read` located inside a `vector::MaskOp`. Previously, the
pattern generated multiple ops inside the MaskOp, which triggered
`error: 'vector.mask' op expects only one operation to mask`.
Split off from https://github.com/llvm/llvm-project/pull/90835
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index c59012266ceb3..6bfb2eb689418 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -322,14 +322,20 @@ struct TransferWriteNonPermutationLowering
/// %v = vector.transfer_read ...
/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
/// vector.broadcast %v
-struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern::OpRewritePattern;
+struct TransferOpReduceRank
+ : public MaskableOpRewritePattern<vector::TransferReadOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
- LogicalResult matchAndRewrite(vector::TransferReadOp op,
- PatternRewriter &rewriter) const override {
+ FailureOr<mlir::Value>
+ matchAndRewriteMaskableOp(vector::TransferReadOp op,
+ MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
if (op.getTransferRank() == 0)
return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
+ // TODO: support masked case.
+ if (maskOp)
+ return rewriter.notifyMatchFailure(op, "Masked case not supported");
AffineMap map = op.getPermutationMap();
unsigned numLeadingBroadcast = 0;
@@ -369,9 +375,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
op.getLoc(), originalVecType.getElementType(), op.getSource(),
op.getIndices());
}
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
- newRead);
- return success();
+ return rewriter
+ .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
+ .getVector();
}
SmallVector<int64_t> newShape(
@@ -393,9 +399,9 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
op.getLoc(), newReadType, op.getSource(), op.getIndices(),
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
newInBoundsAttr);
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
- newRead);
- return success();
+ return rewriter
+ .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
+ .getVector();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 349dc1ab31d4c..0cd134717b1a0 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -187,3 +187,49 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+
+// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
+// CHECK: func.func @transfer_read_reduce_rank_scalable(
+// CHECK-SAME: %[[ARG_0:.*]]: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[TFR:.*]] = vector.transfer_read %arg0[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
+// CHECK: %[[BC:.*]] = vector.broadcast %[[TFR]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
+// CHECK: return %[[BC]] : vector<8x[4]x2x3xf32>
+func.func @transfer_read_reduce_rank_scalable(%mem: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
+ %c0 = arith.constant 0 : index
+ %cst_0 = arith.constant 0.000000e+00 : f32
+ %1 = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0
+ {in_bounds = [true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>}
+ : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32>
+ return %1 : vector<8x[4]x2x3xf32>
+}
+
+// Masked case not supported.
+// CHECK-LABEL: func.func @masked_transfer_read_reduce_rank(
+// CHECK-SAME: %[[ARG_0:.*]]: memref<?x?x?x?xf32>,
+// CHECK-SAME: %[[DIM:.*]]: index) -> vector<8x[4]x2x3xf32> {
+// CHECK-NOT: vector.broadcast
+// CHECK: %[[MASK:.*]] = vector.mask %0 { vector.transfer_read %arg0{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
+func.func @masked_transfer_read_reduce_rank(%mem: memref<?x?x?x?xf32>, %dim: index) -> vector<8x[4]x2x3xf32> {
+ %c0 = arith.constant 0 : index
+ %cst_0 = arith.constant 0.000000e+00 : f32
+ %mask = vector.create_mask %dim, %dim: vector<[4]x3xi1>
+ %res = vector.mask %mask { vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0
+ {in_bounds = [true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>}
+ : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
+ return %res : vector<8x[4]x2x3xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.transfer_permutation_patterns
+ } : !transform.any_op
+ transform.yield
+ }
+}
More information about the Mlir-commits
mailing list