[Mlir-commits] [mlir] [mlir][vector] don't emit non-rank 1 masked load and store (PR #67656)

Cullen Rhodes llvmlistbot at llvm.org
Thu Sep 28 02:49:38 PDT 2023


https://github.com/c-rhodes created https://github.com/llvm/llvm-project/pull/67656

The following patterns

  - TransferReadToVectorLoadLowering
  - TransferWriteToVectorStoreLowering

attempt to generate invalid vector.maskedload and vector.maskedstore ops for vector types that are not rank 1. These ops only operate on 1-D vectors, so this patch adds a rank check that makes both patterns report a match failure instead of emitting the invalid ops.

>From df191f4531f1bb27a722bf57c85ea69183ba8e6b Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Thu, 28 Sep 2023 09:35:17 +0000
Subject: [PATCH] [mlir][vector] don't emit non-rank 1 masked load and store

The following patterns

  - TransferReadToVectorLoadLowering
  - TransferWriteToVectorStoreLowering

attempt to generate invalid vector.maskedload and vector.maskedstore ops
for non rank-1 vector types. These ops operate on 1-D vectors. This
patch adds a check to prevent this.
---
 .../Vector/Transforms/LowerVectorTransfer.cpp | 13 ++++++++++
 .../vector-transfer-to-vector-load-store.mlir | 25 +++++++++++++++++++
 2 files changed, 38 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index bc3e47a71b43097..4e67e3ae306f276 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -455,6 +455,11 @@ struct TransferReadToVectorLoadLowering
     // Create vector load op.
     Operation *loadOp;
     if (read.getMask()) {
+      if (read.getVectorType().getRank() != 1)
+        // vector.maskedload operates on 1-D vectors.
+        return rewriter.notifyMatchFailure(
+            read, "vector type is not rank 1, can't create masked load");
+
       Value fill = rewriter.create<vector::SplatOp>(
           read.getLoc(), unbroadcastedVectorType, read.getPadding());
       loadOp = rewriter.create<vector::MaskedLoadOp>(
@@ -598,6 +603,14 @@ struct TransferWriteToVectorStoreLowering
         diag << "out of bounds dim: " << write;
       });
     if (write.getMask()) {
+      if (write.getVectorType().getRank() != 1)
+        // vector.maskedstore operates on 1-D vectors.
+        return rewriter.notifyMatchFailure(
+            write.getLoc(), [=](Diagnostic &diag) {
+              diag << "vector type is not rank 1, can't create masked store: "
+                   << write;
+            });
+
       rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
           write, write.getSource(), write.getIndices(), write.getMask(),
           write.getVector());
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index ce2108871e984f2..bc3a239ebb2cc79 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -367,3 +367,28 @@ transform.sequence failures(propagate) {
     transform.apply_patterns.vector.transfer_permutation_patterns
   } : !transform.op<"func.func">
 }
+
+// -----
+
+/// Verify that vector.maskedload and vector.maskedstore ops, which only operate
+/// on 1-D vectors, aren't generated for masked transfer ops of rank > 1.
+
+// CHECK-LABEL: @transfer_2D_masked
+// CHECK-NOT: vector.maskedload
+// CHECK-NOT: vector.maskedstore
+// CHECK: vector.transfer_read
+// CHECK: vector.transfer_write
+func.func @transfer_2D_masked(%mem : memref<?x?xf32>, %mask : vector<2x4xi1>) -> vector<2x4xf32> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %res = vector.transfer_read %mem[%c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x4xf32>
+  vector.transfer_write %res, %mem[%c0, %c0], %mask {in_bounds = [true, true]} : vector<2x4xf32>, memref<?x?xf32>
+  return %res : vector<2x4xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%func_op: !transform.op<"func.func">):
+  transform.apply_patterns to %func_op {
+    transform.apply_patterns.vector.lower_transfer max_transfer_rank = 2
+  } : !transform.op<"func.func">
+}



More information about the Mlir-commits mailing list