[clang] [llvm] [mlir] [compiler-rt] [clang-tools-extra] [flang] [mlir][vector] Fix invalid `LoadOp` indices being created (PR #76292)

Rik Huijzer via cfe-commits cfe-commits at lists.llvm.org
Mon Jan 1 22:43:20 PST 2024


https://github.com/rikhuijzer updated https://github.com/llvm/llvm-project/pull/76292

>From 0ff5a0ec09f7c26824bd90e6c7656222ee2448ae Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Sat, 23 Dec 2023 16:32:27 +0100
Subject: [PATCH] [mlir][vector] Fix invalid `LoadOp` indices being created

---
 .../Conversion/VectorToSCF/VectorToSCF.cpp    | 48 +++++++++++++------
 .../Conversion/VectorToSCF/vector-to-scf.mlir | 37 ++++++++++++++
 2 files changed, 71 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 2ee314e9fedfe3..13d2513a88804c 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -866,6 +866,31 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     this->setHasBoundedRewriteRecursion();
   }
 
+  /// Appends to `loadIndices` the indices for loading the mask used by the
+  /// unpacked transfer op. Note: `castedMaskBuffer` is currently unused.
+  static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
+                                       SmallVectorImpl<Value> &loadIndices,
+                                       Value iv) {
+    assert(xferOp.getMask() && "Expected transfer op to have mask");
+
+    // The mask buffer layout depends on the permutation map, so rather than
+    // recomputing the leading indices, reuse those of the first memref.load
+    // user of the mask buffer (if there is none, no indices are reused).
+    Value maskBuffer = getMaskBuffer(xferOp);
+    for (Operation *user : maskBuffer.getUsers()) {
+      if (auto loadOp = dyn_cast<memref::LoadOp>(user)) {
+        Operation::operand_range prevIndices = loadOp.getIndices();
+        loadIndices.append(prevIndices.begin(), prevIndices.end());
+        break;
+      }
+    }
+
+    // For a broadcast dim, load with the same indices as before; otherwise
+    // also index the to-be-unpacked dimension with `iv`.
+    if (!xferOp.isBroadcastDim(0))
+      loadIndices.push_back(iv);
+  }
+
   LogicalResult matchAndRewrite(OpTy xferOp,
                                 PatternRewriter &rewriter) const override {
     if (!xferOp->hasAttr(kPassLabel))
@@ -873,9 +898,9 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
 
     // Find and cast data buffer. How the buffer can be found depends on OpTy.
     ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
-    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
+    Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
     auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
-    auto castedDataType = unpackOneDim(dataBufferType);
+    FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
     if (failed(castedDataType))
       return failure();
 
@@ -885,8 +910,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     // If the xferOp has a mask: Find and cast mask buffer.
     Value castedMaskBuffer;
     if (xferOp.getMask()) {
-      auto maskBuffer = getMaskBuffer(xferOp);
-      auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+      Value maskBuffer = getMaskBuffer(xferOp);
       if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
         // Do not unpack a dimension of the mask, if:
         // * To-be-unpacked transfer op dimension is a broadcast.
@@ -897,7 +921,8 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
       } else {
         // It's safe to assume the mask buffer can be unpacked if the data
         // buffer was unpacked.
-        auto castedMaskType = *unpackOneDim(maskBufferType);
+        auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+        MemRefType castedMaskType = *unpackOneDim(maskBufferType);
         castedMaskBuffer =
             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
       }
@@ -929,21 +954,16 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
 
                 // If old transfer op has a mask: Set mask on new transfer op.
                 // Special case: If the mask of the old transfer op is 1D and
-                // the
-                //               unpacked dim is not a broadcast, no mask is
-                //               needed on the new transfer op.
+                // the unpacked dim is not a broadcast, no mask is needed on
+                // the new transfer op.
                 if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
                                          xferOp.getMaskType().getRank() > 1)) {
                   OpBuilder::InsertionGuard guard(b);
                   b.setInsertionPoint(newXfer); // Insert load before newXfer.
 
                   SmallVector<Value, 8> loadIndices;
-                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
-                  // In case of broadcast: Use same indices to load from memref
-                  // as before.
-                  if (!xferOp.isBroadcastDim(0))
-                    loadIndices.push_back(iv);
-
+                  getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
+                                           loadIndices, iv);
                   auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                        loadIndices);
                   rewriter.updateRootInPlace(newXfer, [&]() {
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index ad78f0c945b24d..8316b4005cc168 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -740,6 +740,43 @@ func.func @cannot_lower_transfer_read_with_leading_scalable(%arg0: memref<?x4xf3
 
 //  -----
 
+// Check that the `TransferOpConversion` generates valid indices for the LoadOp.
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>
+func.func @does_not_crash_on_unpack_one_dim(%subview: memref<1x1x1x1xi32>, %mask: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0 : i32
+  %read = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %pad, %mask
+      {permutation_map = #map1} : memref<1x1x1x1xi32>, vector<1x1x1x1xi32>
+  return %read : vector<1x1x1x1xi32>
+}
+// CHECK-LABEL: func.func @does_not_crash_on_unpack_one_dim
+// CHECK: %[[ALLOCA_0:.*]] = memref.alloca() : memref<vector<1x1xi1>>
+// CHECK: %[[MASK:.*]] = vector.type_cast %[[ALLOCA_0]] : memref<vector<1x1xi1>> to memref<1xvector<1xi1>>
+// CHECK: memref.load %[[MASK]][%{{.*}}] : memref<1xvector<1xi1>>
+
+//  -----
+
+// Check that the `TransferOpConversion` generates valid indices for the mask LoadOp.
+// This test is derived from an ArmSVE integration test.
+
+func.func @add_arrays_of_scalable_vectors(%a: memref<1x2x?xf32>, %b: memref<1x2x?xf32>) -> vector<1x2x[4]xf32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c2_0 = arith.constant 2 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %dim_a = memref.dim %a, %c2 : memref<1x2x?xf32>
+  %mask_a = vector.create_mask %c2, %c2_0, %dim_a : vector<1x2x[4]xi1>
+  %vector_a = vector.transfer_read %a[%c0, %c0, %c0], %cst, %mask_a {in_bounds = [true, true, true]} : memref<1x2x?xf32>, vector<1x2x[4]xf32>
+  return %vector_a : vector<1x2x[4]xf32>
+}
+// CHECK-LABEL: func.func @add_arrays_of_scalable_vectors
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: memref.load
+
+//  -----
+
 // FULL-UNROLL-LABEL: @cannot_fully_unroll_transfer_write_of_nd_scalable_vector
 func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%vec: vector<[4]x[4]xf32>, %memref: memref<?x?xf32>) {
   // FULL-UNROLL-NOT: vector.extract



More information about the cfe-commits mailing list