[Mlir-commits] [mlir] 6b21948 - [mlir][vector] Fix invalid `LoadOp` indices being created (#76292)

llvmlistbot at llvm.org
Wed Jan 3 04:46:56 PST 2024


Author: Rik Huijzer
Date: 2024-01-03T13:46:52+01:00
New Revision: 6b21948f26d69bad8c282db375906a8e0712d5f8

URL: https://github.com/llvm/llvm-project/commit/6b21948f26d69bad8c282db375906a8e0712d5f8
DIFF: https://github.com/llvm/llvm-project/commit/6b21948f26d69bad8c282db375906a8e0712d5f8.diff

LOG: [mlir][vector] Fix invalid `LoadOp` indices being created (#76292)

Fixes https://github.com/llvm/llvm-project/issues/71326.

This is the second PR. The first PR at
https://github.com/llvm/llvm-project/pull/75519 was reverted because an
integration test failed; that test has since been simplified and added
to the core MLIR tests. Compared to the first PR, this one uses a more
reliable approach: it determines the mask indices by looking up the
_mask_ buffer load indices from the previous iteration, whereas `main`
looks up the indices of the _data_ buffer. The mask and data indices can
differ when a `permutation_map` is used.

The cause of the issue was that a new `LoadOp` was created that looked
something like:
```mlir
func.func @main(%arg1 : index, %arg2 : index) {
  %alloca_0 = memref.alloca() : memref<vector<1x32xi1>>
  %1 = vector.type_cast %alloca_0 : memref<vector<1x32xi1>> to memref<1xvector<32xi1>>
  %2 = memref.load %1[%arg1, %arg2] : memref<1xvector<32xi1>>
  return
}
```
which crashed inside `LoadOp::verify`. Note that `%alloca_0` is the mask
buffer, as the `i1` element type shows, and that it is a 0-dimensional
memref of vectors. After the `vector.type_cast`, `%1` has rank 1, yet
the `memref.load` indexes it with two indices, which is invalid.
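For contrast, a well-formed load of that rank-1 memref would take exactly
one index (an illustrative sketch, not code from the patch):
```mlir
// One index for a rank-1 memref; the result is the vector element type.
%2 = memref.load %1[%arg1] : memref<1xvector<32xi1>>
```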

This issue occurred in the following code (a simplified version of the
bug report):
```mlir
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>
func.func @main(%subview: memref<1x1x1x1xi32>, %mask: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
  %c0 = arith.constant 0 : index
  %c0_i32 = arith.constant 0 : i32
  %3 = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %c0_i32, %mask {permutation_map = #map1}
          : memref<1x1x1x1xi32>, vector<1x1x1x1xi32>
  return %3 : vector<1x1x1x1xi32>
}
```
After this patch, it is lowered to the following by
`-convert-vector-to-scf`:
```mlir
func.func @main(%arg0: memref<1x1x1x1xi32>, %arg1: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %alloca = memref.alloca() : memref<vector<1x1x1x1xi32>>
  %alloca_0 = memref.alloca() : memref<vector<1x1xi1>>
  memref.store %arg1, %alloca_0[] : memref<vector<1x1xi1>>
  %0 = vector.type_cast %alloca : memref<vector<1x1x1x1xi32>> to memref<1xvector<1x1x1xi32>>
  %1 = vector.type_cast %alloca_0 : memref<vector<1x1xi1>> to memref<1xvector<1xi1>>
  scf.for %arg2 = %c0 to %c1 step %c1 {
    %3 = vector.type_cast %0 : memref<1xvector<1x1x1xi32>> to memref<1x1xvector<1x1xi32>>
    scf.for %arg3 = %c0 to %c1 step %c1 {
      %4 = vector.type_cast %3 : memref<1x1xvector<1x1xi32>> to memref<1x1x1xvector<1xi32>>
      scf.for %arg4 = %c0 to %c1 step %c1 {
        %5 = memref.load %1[%arg2] : memref<1xvector<1xi1>>
        %6 = vector.transfer_read %arg0[%arg2, %c0, %c0, %c0], %c0_i32, %5 {in_bounds = [true]} : memref<1x1x1x1xi32>, vector<1xi32>
        memref.store %6, %4[%arg2, %arg3, %arg4] : memref<1x1x1xvector<1xi32>>
      }
    }
  }
  %2 = memref.load %alloca[] : memref<vector<1x1x1x1xi32>>
  return %2 : vector<1x1x1x1xi32>
}
```
What caused the problem is that one dimension of the data buffer
`%alloca` (element type `i32`) is unpacked via `vector.type_cast` inside
the outermost loop (index variable `%arg2`) and again inside the nested
loop (index variable `%arg3`), whereas the mask buffer `%alloca_0`
(element type `i1`) is not unpacked in these loops.
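To make the rank mismatch concrete, here are the casts from the lowering
above, grouped by buffer; each `vector.type_cast` moves one vector
dimension into the memref rank:
```mlir
// Data buffer: one dimension is unpacked per loop level.
%0 = vector.type_cast %alloca : memref<vector<1x1x1x1xi32>> to memref<1xvector<1x1x1xi32>>  // rank 1
%3 = vector.type_cast %0 : memref<1xvector<1x1x1xi32>> to memref<1x1xvector<1x1xi32>>       // rank 2
%4 = vector.type_cast %3 : memref<1x1xvector<1x1xi32>> to memref<1x1x1xvector<1xi32>>       // rank 3
// Mask buffer: unpacked only once, so it stays rank 1 inside the loops.
%1 = vector.type_cast %alloca_0 : memref<vector<1x1xi1>> to memref<1xvector<1xi1>>          // rank 1
```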

Before this patch, the load indices were determined by looking them up
from the *data* buffer load op. However, as the example above shows,
once a permutation map is specified, the indices of the data buffer load
op diverge from the indices needed for the mask op. This patch fixes
that by using the load indices of the *mask* buffer instead.
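The divergence is visible in the innermost loop of the lowering above:
the mask buffer view is rank 1 and takes a single index, while the data
buffer view is rank 3 at that point and takes three:
```mlir
// Mask buffer is rank 1: one index suffices.
%5 = memref.load %1[%arg2] : memref<1xvector<1xi1>>
// Data buffer view is rank 3 here: three indices are required.
memref.store %6, %4[%arg2, %arg3, %arg4] : memref<1x1x1xvector<1xi32>>
```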

---------

Co-authored-by: Mehdi Amini <joker.eph at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
    mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 2ee314e9fedfe3..a1aff1ab36a52b 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -866,6 +866,31 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     this->setHasBoundedRewriteRecursion();
   }
 
+  static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
+                                       SmallVectorImpl<Value> &loadIndices,
+                                       Value iv) {
+    assert(xferOp.getMask() && "Expected transfer op to have mask");
+
+    // Add load indices from the previous iteration.
+    // The mask buffer depends on the permutation map, which makes determining
+    // the indices quite complex, so this is why we need to "look back" to the
+    // previous iteration to find the right indices.
+    Value maskBuffer = getMaskBuffer(xferOp);
+    for (Operation *user : maskBuffer.getUsers()) {
+      // If there is no previous load op, then the indices are empty.
+      if (auto loadOp = dyn_cast<memref::LoadOp>(user)) {
+        Operation::operand_range prevIndices = loadOp.getIndices();
+        loadIndices.append(prevIndices.begin(), prevIndices.end());
+        break;
+      }
+    }
+
+    // In case of broadcast: Use same indices to load from memref
+    // as before.
+    if (!xferOp.isBroadcastDim(0))
+      loadIndices.push_back(iv);
+  }
+
   LogicalResult matchAndRewrite(OpTy xferOp,
                                 PatternRewriter &rewriter) const override {
     if (!xferOp->hasAttr(kPassLabel))
@@ -873,9 +898,9 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
 
     // Find and cast data buffer. How the buffer can be found depends on OpTy.
     ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
-    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
+    Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
     auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
-    auto castedDataType = unpackOneDim(dataBufferType);
+    FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
     if (failed(castedDataType))
       return failure();
 
@@ -885,8 +910,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     // If the xferOp has a mask: Find and cast mask buffer.
     Value castedMaskBuffer;
     if (xferOp.getMask()) {
-      auto maskBuffer = getMaskBuffer(xferOp);
-      auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+      Value maskBuffer = getMaskBuffer(xferOp);
       if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
         // Do not unpack a dimension of the mask, if:
         // * To-be-unpacked transfer op dimension is a broadcast.
@@ -897,7 +921,8 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
       } else {
         // It's safe to assume the mask buffer can be unpacked if the data
         // buffer was unpacked.
-        auto castedMaskType = *unpackOneDim(maskBufferType);
+        auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
+        MemRefType castedMaskType = *unpackOneDim(maskBufferType);
         castedMaskBuffer =
             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
       }
@@ -929,21 +954,16 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
 
                 // If old transfer op has a mask: Set mask on new transfer op.
                 // Special case: If the mask of the old transfer op is 1D and
-                // the
-                //               unpacked dim is not a broadcast, no mask is
-                //               needed on the new transfer op.
+                // the unpacked dim is not a broadcast, no mask is needed on
+                // the new transfer op.
                 if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
                                          xferOp.getMaskType().getRank() > 1)) {
                   OpBuilder::InsertionGuard guard(b);
                   b.setInsertionPoint(newXfer); // Insert load before newXfer.
 
                   SmallVector<Value, 8> loadIndices;
-                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
-                  // In case of broadcast: Use same indices to load from memref
-                  // as before.
-                  if (!xferOp.isBroadcastDim(0))
-                    loadIndices.push_back(iv);
-
+                  getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
+                                           loadIndices, iv);
                   auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                        loadIndices);
                   rewriter.updateRootInPlace(newXfer, [&]() {

diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index ad78f0c945b24d..8316b4005cc168 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -740,6 +740,43 @@ func.func @cannot_lower_transfer_read_with_leading_scalable(%arg0: memref<?x4xf3
 
 //  -----
 
+// Check that the `TransferOpConversion` generates valid indices for the LoadOp.
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>
+func.func @does_not_crash_on_unpack_one_dim(%subview:  memref<1x1x1x1xi32>, %mask: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %3 = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %c0_i32, %mask {permutation_map = #map1}
+          : memref<1x1x1x1xi32>, vector<1x1x1x1xi32>
+  return %3 : vector<1x1x1x1xi32>
+}
+// CHECK-LABEL: func.func @does_not_crash_on_unpack_one_dim
+// CHECK: %[[ALLOCA_0:.*]] = memref.alloca() : memref<vector<1x1xi1>>
+// CHECK: %[[MASK:.*]] = vector.type_cast %[[ALLOCA_0]] : memref<vector<1x1xi1>> to memref<1xvector<1xi1>>
+// CHECK: memref.load %[[MASK]][%{{.*}}] : memref<1xvector<1xi1>>
+
+//  -----
+
+// Check that the `TransferOpConversion` generates valid indices for the StoreOp.
+// This test is pulled from an integration test for ArmSVE.
+
+func.func @add_arrays_of_scalable_vectors(%a: memref<1x2x?xf32>, %b: memref<1x2x?xf32>) -> vector<1x2x[4]xf32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 2 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %dim_a = memref.dim %a, %c2 : memref<1x2x?xf32>
+  %mask_a = vector.create_mask %c2, %c3, %dim_a : vector<1x2x[4]xi1>
+  %vector_a = vector.transfer_read %a[%c0, %c0, %c0], %cst, %mask_a {in_bounds = [true, true, true]} : memref<1x2x?xf32>, vector<1x2x[4]xf32>
+  return %vector_a : vector<1x2x[4]xf32>
+}
+// CHECK-LABEL: func.func @add_arrays_of_scalable_vectors
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: memref.load
+
+//  -----
+
 // FULL-UNROLL-LABEL: @cannot_fully_unroll_transfer_write_of_nd_scalable_vector
 func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%vec: vector<[4]x[4]xf32>, %memref: memref<?x?xf32>) {
   // FULL-UNROLL-NOT: vector.extract
