[Mlir-commits] [mlir] [mlir][vector] Fix invalid `LoadOp` indices being created (PR #75519)

Rik Huijzer llvmlistbot at llvm.org
Sat Dec 16 10:06:46 PST 2023


https://github.com/rikhuijzer updated https://github.com/llvm/llvm-project/pull/75519

>From ddeb00d10b632864a8c605ebcd0669033018a616 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Thu, 14 Dec 2023 20:16:18 +0100
Subject: [PATCH 1/2] [mlir][vector] Fix invalid `LoadOp` indices being created

---
 .../Conversion/VectorToSCF/VectorToSCF.cpp    | 27 ++++++++++++-------
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  6 +++--
 .../Conversion/VectorToSCF/vector-to-scf.mlir | 17 ++++++++++++
 mlir/test/Dialect/MemRef/invalid.mlir         |  9 +++++++
 4 files changed, 47 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 2ee314e9fedfe3..2026d0cd216a9e 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -369,7 +369,7 @@ struct Strategy<TransferReadOp> {
   /// Retrieve the indices of the current StoreOp that stores into the buffer.
   static void getBufferIndices(TransferReadOp xferOp,
                                SmallVector<Value, 8> &indices) {
-    auto storeOp = getStoreOp(xferOp);
+    memref::StoreOp storeOp = getStoreOp(xferOp);
     auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
     indices.append(prevIndices.begin(), prevIndices.end());
   }
@@ -591,8 +591,8 @@ struct PrepareTransferReadConversion
     if (checkPrepareXferOp(xferOp, options).failed())
       return failure();
 
-    auto buffers = allocBuffers(rewriter, xferOp);
-    auto *newXfer = rewriter.clone(*xferOp.getOperation());
+    BufferAllocs buffers = allocBuffers(rewriter, xferOp);
+    Operation *newXfer = rewriter.clone(*xferOp.getOperation());
     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
     if (xferOp.getMask()) {
       dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
@@ -885,8 +885,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     // If the xferOp has a mask: Find and cast mask buffer.
     Value castedMaskBuffer;
     if (xferOp.getMask()) {
-      auto maskBuffer = getMaskBuffer(xferOp);
-      auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+      Value maskBuffer = getMaskBuffer(xferOp);
       if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
         // Do not unpack a dimension of the mask, if:
         // * To-be-unpacked transfer op dimension is a broadcast.
@@ -897,7 +896,8 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
       } else {
         // It's safe to assume the mask buffer can be unpacked if the data
         // buffer was unpacked.
-        auto castedMaskType = *unpackOneDim(maskBufferType);
+        auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+        MemRefType castedMaskType = *unpackOneDim(maskBufferType);
         castedMaskBuffer =
             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
       }
@@ -938,11 +938,18 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
                   b.setInsertionPoint(newXfer); // Insert load before newXfer.
 
                   SmallVector<Value, 8> loadIndices;
-                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
-                  // In case of broadcast: Use same indices to load from memref
-                  // as before.
-                  if (!xferOp.isBroadcastDim(0))
+                  if (auto memrefType =
+                          castedMaskBuffer.getType().dyn_cast<MemRefType>()) {
+                    // If castedMaskBuffer is a memref, then one dim was
+                    // unpacked; see above.
                     loadIndices.push_back(iv);
+                  } else {
+                    Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
+                    // In case of broadcast: Use same indices to load from
+                    // memref as before.
+                    if (!xferOp.isBroadcastDim(0))
+                      loadIndices.push_back(iv);
+                  }
 
                   auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                        loadIndices);
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 93327a28234ea9..48ec7040f271b1 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1615,8 +1615,10 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult LoadOp::verify() {
-  if (getNumOperands() != 1 + getMemRefType().getRank())
-    return emitOpError("incorrect number of indices for load");
+  if (getNumOperands() - 1 != getMemRefType().getRank()) {
+    return emitOpError("incorrect number of indices for load, expected ")
+           << getMemRefType().getRank() << " but got " << getNumOperands() - 1;
+  }
   return success();
 }
 
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index ad78f0c945b24d..953fcee0c372fa 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -740,6 +740,23 @@ func.func @cannot_lower_transfer_read_with_leading_scalable(%arg0: memref<?x4xf3
 
 //  -----
 
+// Check that the `unpackOneDim` case in the `TransferOpConversion` generates valid indices for the LoadOp.
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>
+func.func @does_not_crash_on_unpack_one_dim(%subview:  memref<1x1x1x1xi32>, %mask: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %3 = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %c0_i32, %mask {permutation_map = #map1}
+          : memref<1x1x1x1xi32>, vector<1x1x1x1xi32>
+  return %3 : vector<1x1x1x1xi32>
+}
+// CHECK-LABEL: func.func @does_not_crash_on_unpack_one_dim
+// CHECK: %[[ALLOCA_0:.*]] = memref.alloca() : memref<vector<1x1xi1>>
+// CHECK: %[[MASK:.*]] = vector.type_cast %[[ALLOCA_0]] : memref<vector<1x1xi1>> to memref<1xvector<1xi1>>
+// CHECK: memref.load %[[MASK]][%{{.*}}] : memref<1xvector<1xi1>>
+
+//  -----
+
 // FULL-UNROLL-LABEL: @cannot_fully_unroll_transfer_write_of_nd_scalable_vector
 func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%vec: vector<[4]x[4]xf32>, %memref: memref<?x?xf32>) {
   // FULL-UNROLL-NOT: vector.extract
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 55b759cbb3ce7c..f9b870f77266e1 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -896,6 +896,15 @@ func.func @bad_alloc_wrong_symbol_count() {
 
 // -----
 
+func.func @load_invalid_memref_indexes() {
+  %0 = memref.alloca() : memref<10xi32>
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{incorrect number of indices for load, expected 1 but got 2}}
+  %1 = memref.load %0[%c0, %c0] : memref<10xi32>
+}
+
+// -----
+
 func.func @test_store_zero_results() {
 ^bb0:
   %0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1>

>From 215908d3041e7cde9a912024e064e68e12236e51 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Sat, 16 Dec 2023 19:06:40 +0100
Subject: [PATCH 2/2] Use `getIndices().size()`

Co-authored-by: Benjamin Maxwell <macdue at dueutil.tech>
---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 48ec7040f271b1..46fa27bba34ccc 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1615,10 +1615,10 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult LoadOp::verify() {
-  if (getNumOperands() - 1 != getMemRefType().getRank()) {
-    return emitOpError("incorrect number of indices for load, expected ")
-           << getMemRefType().getRank() << " but got " << getNumOperands() - 1;
-  }
+ if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
+   return emitOpError("incorrect number of indices for load, expected ")
+          << getMemRefType().getRank() << " but got " << getIndices().size();
+ }
   return success();
 }
 



More information about the Mlir-commits mailing list