[Mlir-commits] [mlir] [mlir][vector] don't emit non-rank 1 masked load and store (PR #67656)
Cullen Rhodes
llvmlistbot at llvm.org
Thu Sep 28 02:49:38 PDT 2023
https://github.com/c-rhodes created https://github.com/llvm/llvm-project/pull/67656
The following patterns:
- TransferReadToVectorLoadLowering
- TransferWriteToVectorStoreLowering
attempt to generate invalid vector.maskedload and vector.maskedstore ops for vector types whose rank is not 1. These ops only operate on 1-D vectors, so this patch adds a check that makes the patterns fail to match (via notifyMatchFailure) in that case instead of emitting invalid IR.
From df191f4531f1bb27a722bf57c85ea69183ba8e6b Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Thu, 28 Sep 2023 09:35:17 +0000
Subject: [PATCH] [mlir][vector] don't emit non-rank 1 masked load and store
The following patterns:
- TransferReadToVectorLoadLowering
- TransferWriteToVectorStoreLowering
attempt to generate invalid vector.maskedload and vector.maskedstore ops
for vector types whose rank is not 1. These ops only operate on 1-D
vectors, so this patch adds a check that makes the patterns fail to
match in that case instead of emitting invalid IR.
---
.../Vector/Transforms/LowerVectorTransfer.cpp | 13 ++++++++++
.../vector-transfer-to-vector-load-store.mlir | 25 +++++++++++++++++++
2 files changed, 38 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index bc3e47a71b43097..4e67e3ae306f276 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -455,6 +455,11 @@ struct TransferReadToVectorLoadLowering
// Create vector load op.
Operation *loadOp;
if (read.getMask()) {
+ if (read.getVectorType().getRank() != 1)
+ // vector.maskedload operates on 1-D vectors.
+ return rewriter.notifyMatchFailure(
+ read, "vector type is not rank 1, can't create masked load");
+
Value fill = rewriter.create<vector::SplatOp>(
read.getLoc(), unbroadcastedVectorType, read.getPadding());
loadOp = rewriter.create<vector::MaskedLoadOp>(
@@ -598,6 +603,14 @@ struct TransferWriteToVectorStoreLowering
diag << "out of bounds dim: " << write;
});
if (write.getMask()) {
+ if (write.getVectorType().getRank() != 1)
+ // vector.maskedstore operates on 1-D vectors.
+ return rewriter.notifyMatchFailure(
+ write.getLoc(), [=](Diagnostic &diag) {
+ diag << "vector type is not rank 1, can't create masked store: "
+ << write;
+ });
+
rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
write, write.getSource(), write.getIndices(), write.getMask(),
write.getVector());
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index ce2108871e984f2..bc3a239ebb2cc79 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -367,3 +367,28 @@ transform.sequence failures(propagate) {
transform.apply_patterns.vector.transfer_permutation_patterns
} : !transform.op<"func.func">
}
+
+// -----
+
+/// Verify vector.maskedload and vector.maskedstore ops that operate on 1-D
+/// vectors aren't generated for rank > 1 transfer ops.
+
+// CHECK-LABEL: @transfer_2D_masked
+// CHECK-NOT: vector.maskedload
+// CHECK-NOT: vector.maskedstore
+// CHECK: vector.transfer_read
+// CHECK: vector.transfer_write
+func.func @transfer_2D_masked(%mem : memref<?x?xf32>, %mask : vector<2x4xi1>) -> vector<2x4xf32> {
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0.0 : f32
+ %res = vector.transfer_read %mem[%c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x4xf32>
+ vector.transfer_write %res, %mem[%c0, %c0], %mask {in_bounds = [true, true]} : vector<2x4xf32>, memref<?x?xf32>
+ return %res : vector<2x4xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%func_op: !transform.op<"func.func">):
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.vector.lower_transfer max_transfer_rank = 2
+ } : !transform.op<"func.func">
+}
More information about the Mlir-commits
mailing list