[Mlir-commits] [mlir] [mlir][vector] Fix invalid `LoadOp` indices being created (PR #75519)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 14 11:30:40 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-memref
Author: Rik Huijzer (rikhuijzer)
<details>
<summary>Changes</summary>
Fixes https://github.com/llvm/llvm-project/issues/71326.
The issue was caused by the creation of a new `LoadOp` that looked something like:
```mlir
func.func @main(%arg1 : index, %arg2 : index) {
%alloca_0 = memref.alloca() : memref<vector<1x32xi1>>
%1 = vector.type_cast %alloca_0 : memref<vector<1x32xi1>> to memref<1xvector<32xi1>>
%2 = memref.load %1[%arg1, %arg2] : memref<1xvector<32xi1>>
return
}
```
which crashed inside `LoadOp::verify`. Note that `%alloca_0` is 0-dimensional and `%1` has one dimension, yet `memref.load` tries to index `%1` with two indices.
This is now fixed by relying on the fact that `unpackOneDim` always unpacks exactly one dimension (see
https://github.com/llvm/llvm-project/blob/1bce61e6b01b38e04260be4f422bbae59c34c766/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp#L897-L903),
so the new `LoadOp` should index only that unpacked dimension.
---
Full diff: https://github.com/llvm/llvm-project/pull/75519.diff
4 Files Affected:
- (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+17-10)
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+4-2)
- (modified) mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir (+17)
- (modified) mlir/test/Dialect/MemRef/invalid.mlir (+9)
``````````diff
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 2ee314e9fedfe3..2026d0cd216a9e 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -369,7 +369,7 @@ struct Strategy<TransferReadOp> {
/// Retrieve the indices of the current StoreOp that stores into the buffer.
static void getBufferIndices(TransferReadOp xferOp,
SmallVector<Value, 8> &indices) {
- auto storeOp = getStoreOp(xferOp);
+ memref::StoreOp storeOp = getStoreOp(xferOp);
auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
indices.append(prevIndices.begin(), prevIndices.end());
}
@@ -591,8 +591,8 @@ struct PrepareTransferReadConversion
if (checkPrepareXferOp(xferOp, options).failed())
return failure();
- auto buffers = allocBuffers(rewriter, xferOp);
- auto *newXfer = rewriter.clone(*xferOp.getOperation());
+ BufferAllocs buffers = allocBuffers(rewriter, xferOp);
+ Operation *newXfer = rewriter.clone(*xferOp.getOperation());
newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
if (xferOp.getMask()) {
dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
@@ -885,8 +885,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
// If the xferOp has a mask: Find and cast mask buffer.
Value castedMaskBuffer;
if (xferOp.getMask()) {
- auto maskBuffer = getMaskBuffer(xferOp);
- auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+ Value maskBuffer = getMaskBuffer(xferOp);
if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
// Do not unpack a dimension of the mask, if:
// * To-be-unpacked transfer op dimension is a broadcast.
@@ -897,7 +896,8 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
} else {
// It's safe to assume the mask buffer can be unpacked if the data
// buffer was unpacked.
- auto castedMaskType = *unpackOneDim(maskBufferType);
+ auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+ MemRefType castedMaskType = *unpackOneDim(maskBufferType);
castedMaskBuffer =
locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
}
@@ -938,11 +938,18 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
b.setInsertionPoint(newXfer); // Insert load before newXfer.
SmallVector<Value, 8> loadIndices;
- Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
- // In case of broadcast: Use same indices to load from memref
- // as before.
- if (!xferOp.isBroadcastDim(0))
+ if (auto memrefType =
+ castedMaskBuffer.getType().dyn_cast<MemRefType>()) {
+ // If castedMaskBuffer is a memref, then one dim was
+ // unpacked; see above.
loadIndices.push_back(iv);
+ } else {
+ Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
+ // In case of broadcast: Use same indices to load from
+ // memref as before.
+ if (!xferOp.isBroadcastDim(0))
+ loadIndices.push_back(iv);
+ }
auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
loadIndices);
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 93327a28234ea9..48ec7040f271b1 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1615,8 +1615,10 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//===----------------------------------------------------------------------===//
LogicalResult LoadOp::verify() {
- if (getNumOperands() != 1 + getMemRefType().getRank())
- return emitOpError("incorrect number of indices for load");
+ if (getNumOperands() - 1 != getMemRefType().getRank()) {
+ return emitOpError("incorrect number of indices for load, expected ")
+ << getMemRefType().getRank() << " but got " << getNumOperands() - 1;
+ }
return success();
}
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index ad78f0c945b24d..953fcee0c372fa 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -740,6 +740,23 @@ func.func @cannot_lower_transfer_read_with_leading_scalable(%arg0: memref<?x4xf3
// -----
+// Check that the `unpackOneDim` case in the `TransferOpConversion` generates valid indices for the LoadOp.
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>
+func.func @does_not_crash_on_unpack_one_dim(%subview: memref<1x1x1x1xi32>, %mask: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %3 = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %c0_i32, %mask {permutation_map = #map1}
+ : memref<1x1x1x1xi32>, vector<1x1x1x1xi32>
+ return %3 : vector<1x1x1x1xi32>
+}
+// CHECK-LABEL: func.func @does_not_crash_on_unpack_one_dim
+// CHECK: %[[ALLOCA_0:.*]] = memref.alloca() : memref<vector<1x1xi1>>
+// CHECK: %[[MASK:.*]] = vector.type_cast %[[ALLOCA_0]] : memref<vector<1x1xi1>> to memref<1xvector<1xi1>>
+// CHECK: memref.load %[[MASK]][%{{.*}}] : memref<1xvector<1xi1>>
+
+// -----
+
// FULL-UNROLL-LABEL: @cannot_fully_unroll_transfer_write_of_nd_scalable_vector
func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%vec: vector<[4]x[4]xf32>, %memref: memref<?x?xf32>) {
// FULL-UNROLL-NOT: vector.extract
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 55b759cbb3ce7c..f9b870f77266e1 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -896,6 +896,15 @@ func.func @bad_alloc_wrong_symbol_count() {
// -----
+func.func @load_invalid_memref_indexes() {
+ %0 = memref.alloca() : memref<10xi32>
+ %c0 = arith.constant 0 : index
+ // expected-error at +1 {{incorrect number of indices for load, expected 1 but got 2}}
+ %1 = memref.load %0[%c0, %c0] : memref<10xi32>
+}
+
+// -----
+
func.func @test_store_zero_results() {
^bb0:
%0 = memref.alloc() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1>
``````````
</details>
https://github.com/llvm/llvm-project/pull/75519
More information about the Mlir-commits
mailing list