[Mlir-commits] [mlir] [mlir][vector] Fix invalid `LoadOp` indices being created (PR #76292)
Rik Huijzer
llvmlistbot at llvm.org
Sat Dec 23 07:43:36 PST 2023
https://github.com/rikhuijzer created https://github.com/llvm/llvm-project/pull/76292
Second attempt at fixing https://github.com/llvm/llvm-project/issues/71326. The first attempt at https://github.com/llvm/llvm-project/pull/75519 was reverted because an integration test failed.
The cause of the issue was that a new `LoadOp` was created which looked something like:
```mlir
func.func main(%arg1 : index, %arg2 : index) {
%alloca_0 = memref.alloca() : memref<vector<1x32xi1>>
%1 = vector.type_cast %alloca_0 : memref<vector<1x32xi1>> to memref<1xvector<32xi1>>
%2 = memref.load %1[%arg1, %arg2] : memref<1xvector<32xi1>>
return
}
```
which crashed inside the `LoadOp::verify`. Note here that `%alloca_0` is the mask as can be seen from the `i1` element type and note it is 0 dimensional. Next, `%1` has one dimension, but `memref.load` tries to index it with two indices.
This issue occured in the following code (a simplified version of the bug report):
```mlir
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>
func.func @main(%subview: memref<1x1x1x1xi32>, %mask: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
%3 = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %c0_i32, %mask {permutation_map = #map1}
: memref<1x1x1x1xi32>, vector<1x1x1x1xi32>
return %3 : vector<1x1x1x1xi32>
}
```
After this patch, it is lowered to the following by `-convert-vector-to-scf`:
```mlir
func.func @main(%arg0: memref<1x1x1x1xi32>, %arg1: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
%c0_i32 = arith.constant 0 : i32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%alloca = memref.alloca() : memref<vector<1x1x1x1xi32>>
%alloca_0 = memref.alloca() : memref<vector<1x1xi1>>
memref.store %arg1, %alloca_0[] : memref<vector<1x1xi1>>
%0 = vector.type_cast %alloca : memref<vector<1x1x1x1xi32>> to memref<1xvector<1x1x1xi32>>
%1 = vector.type_cast %alloca_0 : memref<vector<1x1xi1>> to memref<1xvector<1xi1>>
scf.for %arg2 = %c0 to %c1 step %c1 {
%3 = vector.type_cast %0 : memref<1xvector<1x1x1xi32>> to memref<1x1xvector<1x1xi32>>
scf.for %arg3 = %c0 to %c1 step %c1 {
%4 = vector.type_cast %3 : memref<1x1xvector<1x1xi32>> to memref<1x1x1xvector<1xi32>>
scf.for %arg4 = %c0 to %c1 step %c1 {
%5 = memref.load %1[%arg2] : memref<1xvector<1xi1>>
%6 = vector.transfer_read %arg0[%arg2, %c0, %c0, %c0], %c0_i32, %5 {in_bounds = [true]} : memref<1x1x1x1xi32>, vector<1xi32>
memref.store %6, %4[%arg2, %arg3, %arg4] : memref<1x1x1xvector<1xi32>>
}
}
}
%2 = memref.load %alloca[] : memref<vector<1x1x1x1xi32>>
return %2 : vector<1x1x1x1xi32>
}
```
What was causing the problems is that one dimension of the data buffer `%alloca` (eltype `i32`) is unpacked (`vector.type_cast`) inside the outmost loop (loop with index variable `%arg2`) and the nested loop (loop with index variable `%arg3`), whereas the mask buffer `%alloca_0` (eltype `i1`) is not unpacked in these loops.
Before this patch, the load indices would be determined by looking up the load indices for the *data* buffer load op. However, as shown in the specific example, when a permutation map is specified then the load indices from the data buffer load op start to differ from the indices for the mask op. To fix this, this patch ensures that the load indices for the *mask* buffer are used instead.
>From 0ff5a0ec09f7c26824bd90e6c7656222ee2448ae Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Sat, 23 Dec 2023 16:32:27 +0100
Subject: [PATCH] [mlir][vector] Fix invalid `LoadOp` indices being created
---
.../Conversion/VectorToSCF/VectorToSCF.cpp | 48 +++++++++++++------
.../Conversion/VectorToSCF/vector-to-scf.mlir | 37 ++++++++++++++
2 files changed, 71 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 2ee314e9fedfe3..13d2513a88804c 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -866,6 +866,31 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
this->setHasBoundedRewriteRecursion();
}
+ static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
+ SmallVector<Value, 8> &loadIndices,
+ Value iv) {
+ assert(xferOp.getMask() && "Expected transfer op to have mask");
+
+ // Add load indices from the previous iteration.
+ // The mask buffer depends on the permutation map, which makes determining
+ // the indices quite complex, so this is why we need to "look back" to the
+ // previous iteration to find the right indices.
+ Value maskBuffer = getMaskBuffer(xferOp);
+ for (OpOperand &use : maskBuffer.getUses()) {
+ // If there is no previous load op, then the indices are empty.
+ if (auto loadOp = dyn_cast<memref::LoadOp>(use.getOwner())) {
+ Operation::operand_range prevIndices = loadOp.getIndices();
+ loadIndices.append(prevIndices.begin(), prevIndices.end());
+ break;
+ }
+ }
+
+ // In case of broadcast: Use same indices to load from memref
+ // as before.
+ if (!xferOp.isBroadcastDim(0))
+ loadIndices.push_back(iv);
+ }
+
LogicalResult matchAndRewrite(OpTy xferOp,
PatternRewriter &rewriter) const override {
if (!xferOp->hasAttr(kPassLabel))
@@ -873,9 +898,9 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
// Find and cast data buffer. How the buffer can be found depends on OpTy.
ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
- auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
+ Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
- auto castedDataType = unpackOneDim(dataBufferType);
+ FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
if (failed(castedDataType))
return failure();
@@ -885,8 +910,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
// If the xferOp has a mask: Find and cast mask buffer.
Value castedMaskBuffer;
if (xferOp.getMask()) {
- auto maskBuffer = getMaskBuffer(xferOp);
- auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+ Value maskBuffer = getMaskBuffer(xferOp);
if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
// Do not unpack a dimension of the mask, if:
// * To-be-unpacked transfer op dimension is a broadcast.
@@ -897,7 +921,8 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
} else {
// It's safe to assume the mask buffer can be unpacked if the data
// buffer was unpacked.
- auto castedMaskType = *unpackOneDim(maskBufferType);
+ auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+ MemRefType castedMaskType = *unpackOneDim(maskBufferType);
castedMaskBuffer =
locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
}
@@ -929,21 +954,16 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
// If old transfer op has a mask: Set mask on new transfer op.
// Special case: If the mask of the old transfer op is 1D and
- // the
- // unpacked dim is not a broadcast, no mask is
- // needed on the new transfer op.
+ // the unpacked dim is not a broadcast, no mask is needed on
+ // the new transfer op.
if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
xferOp.getMaskType().getRank() > 1)) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(newXfer); // Insert load before newXfer.
SmallVector<Value, 8> loadIndices;
- Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
- // In case of broadcast: Use same indices to load from memref
- // as before.
- if (!xferOp.isBroadcastDim(0))
- loadIndices.push_back(iv);
-
+ getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
+ loadIndices, iv);
auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
loadIndices);
rewriter.updateRootInPlace(newXfer, [&]() {
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index ad78f0c945b24d..8316b4005cc168 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -740,6 +740,43 @@ func.func @cannot_lower_transfer_read_with_leading_scalable(%arg0: memref<?x4xf3
// -----
+// Check that the `TransferOpConversion` generates valid indices for the LoadOp.
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>
+func.func @does_not_crash_on_unpack_one_dim(%subview: memref<1x1x1x1xi32>, %mask: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %3 = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %c0_i32, %mask {permutation_map = #map1}
+ : memref<1x1x1x1xi32>, vector<1x1x1x1xi32>
+ return %3 : vector<1x1x1x1xi32>
+}
+// CHECK-LABEL: func.func @does_not_crash_on_unpack_one_dim
+// CHECK: %[[ALLOCA_0:.*]] = memref.alloca() : memref<vector<1x1xi1>>
+// CHECK: %[[MASK:.*]] = vector.type_cast %[[ALLOCA_0]] : memref<vector<1x1xi1>> to memref<1xvector<1xi1>>
+// CHECK: memref.load %[[MASK]][%{{.*}}] : memref<1xvector<1xi1>>
+
+// -----
+
+// Check that the `TransferOpConversion` generates valid indices for the StoreOp.
+// This test is pulled from an integration test for ArmSVE.
+
+func.func @add_arrays_of_scalable_vectors(%a: memref<1x2x?xf32>, %b: memref<1x2x?xf32>) -> vector<1x2x[4]xf32> {
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 2 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %dim_a = memref.dim %a, %c2 : memref<1x2x?xf32>
+ %mask_a = vector.create_mask %c2, %c3, %dim_a : vector<1x2x[4]xi1>
+ %vector_a = vector.transfer_read %a[%c0, %c0, %c0], %cst, %mask_a {in_bounds = [true, true, true]} : memref<1x2x?xf32>, vector<1x2x[4]xf32>
+ return %vector_a : vector<1x2x[4]xf32>
+}
+// CHECK-LABEL: func.func @add_arrays_of_scalable_vectors
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: memref.load
+
+// -----
+
// FULL-UNROLL-LABEL: @cannot_fully_unroll_transfer_write_of_nd_scalable_vector
func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%vec: vector<[4]x[4]xf32>, %memref: memref<?x?xf32>) {
// FULL-UNROLL-NOT: vector.extract
More information about the Mlir-commits
mailing list