[clang] [llvm] [mlir] [compiler-rt] [clang-tools-extra] [flang] [mlir][vector] Fix invalid `LoadOp` indices being created (PR #76292)
Rik Huijzer via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 1 22:43:20 PST 2024
https://github.com/rikhuijzer updated https://github.com/llvm/llvm-project/pull/76292
From 0ff5a0ec09f7c26824bd90e6c7656222ee2448ae Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Sat, 23 Dec 2023 16:32:27 +0100
Subject: [PATCH] [mlir][vector] Fix invalid `LoadOp` indices being created
---
.../Conversion/VectorToSCF/VectorToSCF.cpp | 48 +++++++++++++------
.../Conversion/VectorToSCF/vector-to-scf.mlir | 37 ++++++++++++++
2 files changed, 71 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 2ee314e9fedfe3..13d2513a88804c 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -866,6 +866,31 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     this->setHasBoundedRewriteRecursion();
   }

+  static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
+                                       SmallVector<Value, 8> &loadIndices,
+                                       Value iv) {
+    assert(xferOp.getMask() && "Expected transfer op to have mask");
+
+    // Add load indices from the previous iteration.
+    // The mask buffer depends on the permutation map, which makes determining
+    // the indices quite complex, so we "look back" to the previous iteration
+    // to find the right indices.
+    Value maskBuffer = getMaskBuffer(xferOp);
+    for (OpOperand &use : maskBuffer.getUses()) {
+      // If there is no previous load op, the indices remain empty.
+      if (auto loadOp = dyn_cast<memref::LoadOp>(use.getOwner())) {
+        Operation::operand_range prevIndices = loadOp.getIndices();
+        loadIndices.append(prevIndices.begin(), prevIndices.end());
+        break;
+      }
+    }
+
+    // In case of broadcast: Use the same indices to load from the memref
+    // as before.
+    if (!xferOp.isBroadcastDim(0))
+      loadIndices.push_back(iv);
+  }
+
   LogicalResult matchAndRewrite(OpTy xferOp,
                                 PatternRewriter &rewriter) const override {
     if (!xferOp->hasAttr(kPassLabel))
@@ -873,9 +898,9 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {

     // Find and cast data buffer. How the buffer can be found depends on OpTy.
     ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
-    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
+    Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
     auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
-    auto castedDataType = unpackOneDim(dataBufferType);
+    FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
     if (failed(castedDataType))
       return failure();

@@ -885,8 +910,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
     // If the xferOp has a mask: Find and cast mask buffer.
     Value castedMaskBuffer;
     if (xferOp.getMask()) {
-      auto maskBuffer = getMaskBuffer(xferOp);
-      auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+      Value maskBuffer = getMaskBuffer(xferOp);
       if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
         // Do not unpack a dimension of the mask, if:
         // * To-be-unpacked transfer op dimension is a broadcast.
@@ -897,7 +921,8 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
       } else {
         // It's safe to assume the mask buffer can be unpacked if the data
         // buffer was unpacked.
-        auto castedMaskType = *unpackOneDim(maskBufferType);
+        auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+        MemRefType castedMaskType = *unpackOneDim(maskBufferType);
         castedMaskBuffer =
             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
       }
@@ -929,21 +954,16 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {

                 // If old transfer op has a mask: Set mask on new transfer op.
                 // Special case: If the mask of the old transfer op is 1D and
-                // the
-                //               unpacked dim is not a broadcast, no mask is
-                //               needed on the new transfer op.
+                // the unpacked dim is not a broadcast, no mask is needed on
+                // the new transfer op.
                 if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
                                          xferOp.getMaskType().getRank() > 1)) {
                   OpBuilder::InsertionGuard guard(b);
                   b.setInsertionPoint(newXfer); // Insert load before newXfer.

                   SmallVector<Value, 8> loadIndices;
-                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
-                  // In case of broadcast: Use same indices to load from memref
-                  // as before.
-                  if (!xferOp.isBroadcastDim(0))
-                    loadIndices.push_back(iv);
-
+                  getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
+                                           loadIndices, iv);
                   auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                        loadIndices);
                   rewriter.updateRootInPlace(newXfer, [&]() {
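
For context on what the helper has to produce: the lowering stages the mask by spilling it into an alloca, unpacking one dimension with vector.type_cast, and reading the smaller mask back with a memref.load at every unpacking level. That load must supply exactly as many indices as the casted mask buffer has dimensions, which is what getMaskBufferLoadIndices now guarantees. Below is a small hand-written MLIR sketch of that staging; it mirrors the CHECK lines of the first new test further down, but the function and value names are invented for illustration and this is not verbatim output of the pass.

func.func @mask_buffer_staging(%mask: vector<1x1xi1>, %i: index) -> vector<1xi1> {
  // Spill the mask into a buffer, as the lowering does.
  %alloca = memref.alloca() : memref<vector<1x1xi1>>
  memref.store %mask, %alloca[] : memref<vector<1x1xi1>>
  // Unpack one dimension of the mask buffer.
  %casted = vector.type_cast %alloca : memref<vector<1x1xi1>> to memref<1xvector<1xi1>>
  // The load takes exactly one index, because the casted buffer is
  // one-dimensional; an index list of the wrong length is what previously
  // produced invalid IR.
  %m = memref.load %casted[%i] : memref<1xvector<1xi1>>
  return %m : vector<1xi1>
}

Before this change the indices came from Strategy<OpTy>::getBufferIndices, which follows the data buffer rather than the mask buffer; the new helper derives them from the mask buffer's own previous load instead.
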
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index ad78f0c945b24d..8316b4005cc168 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -740,6 +740,43 @@ func.func @cannot_lower_transfer_read_with_leading_scalable(%arg0: memref<?x4xf3
// -----
+// Check that the `TransferOpConversion` generates valid indices for the LoadOp.
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>
+func.func @does_not_crash_on_unpack_one_dim(%subview: memref<1x1x1x1xi32>, %mask: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %3 = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %c0_i32, %mask {permutation_map = #map1}
+         : memref<1x1x1x1xi32>, vector<1x1x1x1xi32>
+  return %3 : vector<1x1x1x1xi32>
+}
+// CHECK-LABEL: func.func @does_not_crash_on_unpack_one_dim
+// CHECK: %[[ALLOCA_0:.*]] = memref.alloca() : memref<vector<1x1xi1>>
+// CHECK: %[[MASK:.*]] = vector.type_cast %[[ALLOCA_0]] : memref<vector<1x1xi1>> to memref<1xvector<1xi1>>
+// CHECK: memref.load %[[MASK]][%{{.*}}] : memref<1xvector<1xi1>>
+
+// -----
+
+// Check that the `TransferOpConversion` generates valid indices for the StoreOp.
+// This test is pulled from an integration test for ArmSVE.
+
+func.func @add_arrays_of_scalable_vectors(%a: memref<1x2x?xf32>, %b: memref<1x2x?xf32>) -> vector<1x2x[4]xf32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 2 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %dim_a = memref.dim %a, %c2 : memref<1x2x?xf32>
+  %mask_a = vector.create_mask %c2, %c3, %dim_a : vector<1x2x[4]xi1>
+  %vector_a = vector.transfer_read %a[%c0, %c0, %c0], %cst, %mask_a {in_bounds = [true, true, true]} : memref<1x2x?xf32>, vector<1x2x[4]xf32>
+  return %vector_a : vector<1x2x[4]xf32>
+}
+// CHECK-LABEL: func.func @add_arrays_of_scalable_vectors
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: memref.load
+
+// -----
+
// FULL-UNROLL-LABEL: @cannot_fully_unroll_transfer_write_of_nd_scalable_vector
func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%vec: vector<[4]x[4]xf32>, %memref: memref<?x?xf32>) {
// FULL-UNROLL-NOT: vector.extract
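
Regarding the second new test: the two leading fixed-size dimensions are unpacked into loops while the innermost scalable dimension stays vectorized, so the CHECK lines only pin down the overall shape: two nested scf.for loops and a memref.load that fetches the mask slice. A rough hand-written sketch of that shape follows; the buffer type, bounds, and names are illustrative assumptions, not the exact IR emitted by the pass.

func.func @nested_mask_loads(%mask_buf: memref<1x2xvector<[4]xi1>>,
                             %c0: index, %c1: index, %c2: index) {
  scf.for %i = %c0 to %c1 step %c1 {
    scf.for %j = %c0 to %c2 step %c1 {
      // The load of the innermost mask slice needs one index per enclosing
      // loop, in order; collecting these indices across unpacking levels is
      // what getMaskBufferLoadIndices does in the pattern.
      %m = memref.load %mask_buf[%i, %j] : memref<1x2xvector<[4]xi1>>
    }
  }
  return
}
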