[Mlir-commits] [mlir] [mlir][affine] Fix dim index out of bounds crash (PR #73266)
Rik Huijzer
llvmlistbot at llvm.org
Thu Nov 23 13:35:36 PST 2023
https://github.com/rikhuijzer updated https://github.com/llvm/llvm-project/pull/73266
>From b2147b28969457af0a4229bb0e6d0f00c6294797 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Thu, 23 Nov 2023 22:26:35 +0100
Subject: [PATCH 1/2] [mlir][affine] Fix dim index out of bounds crash
---
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 40 +++++++++++--------
.../FuncToSPIRV/func-ops-to-spirv.mlir | 12 ++++++
mlir/test/Dialect/Affine/invalid.mlir | 4 +-
.../Dialect/Affine/load-store-invalid.mlir | 12 +++---
4 files changed, 44 insertions(+), 24 deletions(-)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index d22a7539fb75018..d6e640ddd8f25d5 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -317,9 +317,16 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
/// `memrefDefOp` is a statically shaped one or defined using a valid symbol
/// for `region`.
template <typename AnyMemRefDefOp>
-static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
+static bool isMemRefSizeValidSymbol(ShapedDimOpInterface dimOp,
+ AnyMemRefDefOp memrefDefOp, unsigned index,
Region *region) {
- auto memRefType = memrefDefOp.getType();
+ MemRefType memRefType = memrefDefOp.getType();
+
+ // Dimension index is out of bounds.
+ if (index >= memRefType.getRank()) {
+ return false;
+ }
+
// Statically shaped.
if (!memRefType.isDynamicDim(index))
return true;
@@ -351,7 +358,9 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
int64_t i = index.value();
return TypeSwitch<Operation *, bool>(dimOp.getShapedValue().getDefiningOp())
.Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
- [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
+ [&](auto memRefDefOp) {
+ return isMemRefSizeValidSymbol(dimOp, memRefDefOp, i, region);
+ })
.Default([](Operation *) { return false; });
}
@@ -1651,19 +1660,19 @@ LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
if (!idx.getType().isIndex())
return emitOpError("src index to dma_start must have 'index' type");
if (!isValidAffineIndexOperand(idx, scope))
- return emitOpError("src index must be a dimension or symbol identifier");
+ return emitOpError("src index must be a valid dimension or symbol identifier");
}
for (auto idx : getDstIndices()) {
if (!idx.getType().isIndex())
return emitOpError("dst index to dma_start must have 'index' type");
if (!isValidAffineIndexOperand(idx, scope))
- return emitOpError("dst index must be a dimension or symbol identifier");
+ return emitOpError("dst index must be a valid dimension or symbol identifier");
}
for (auto idx : getTagIndices()) {
if (!idx.getType().isIndex())
return emitOpError("tag index to dma_start must have 'index' type");
if (!isValidAffineIndexOperand(idx, scope))
- return emitOpError("tag index must be a dimension or symbol identifier");
+ return emitOpError("tag index must be a valid dimension or symbol identifier");
}
return success();
}
@@ -1752,7 +1761,7 @@ LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
if (!idx.getType().isIndex())
return emitOpError("index to dma_wait must have 'index' type");
if (!isValidAffineIndexOperand(idx, scope))
- return emitOpError("index must be a dimension or symbol identifier");
+ return emitOpError("index must be a valid dimension or symbol identifier");
}
return success();
}
@@ -2913,8 +2922,7 @@ static void composeSetAndOperands(IntegerSet &set,
}
/// Canonicalize an affine if op's conditional (integer set + operands).
-LogicalResult AffineIfOp::fold(FoldAdaptor,
- SmallVectorImpl<OpFoldResult> &) {
+LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
auto set = getIntegerSet();
SmallVector<Value, 4> operands(getOperands());
composeSetAndOperands(set, operands);
@@ -3005,18 +3013,18 @@ static LogicalResult
verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
Operation::operand_range mapOperands,
MemRefType memrefType, unsigned numIndexOperands) {
- AffineMap map = mapAttr.getValue();
- if (map.getNumResults() != memrefType.getRank())
- return op->emitOpError("affine map num results must equal memref rank");
- if (map.getNumInputs() != numIndexOperands)
- return op->emitOpError("expects as many subscripts as affine map inputs");
+ AffineMap map = mapAttr.getValue();
+ if (map.getNumResults() != memrefType.getRank())
+ return op->emitOpError("affine map num results must equal memref rank");
+ if (map.getNumInputs() != numIndexOperands)
+ return op->emitOpError("expects as many subscripts as affine map inputs");
Region *scope = getAffineScope(op);
for (auto idx : mapOperands) {
if (!idx.getType().isIndex())
return op->emitOpError("index to load must have 'index' type");
if (!isValidAffineIndexOperand(idx, scope))
- return op->emitOpError("index must be a dimension or symbol identifier");
+ return op->emitOpError("index must be a valid dimension or symbol identifier");
}
return success();
@@ -3605,7 +3613,7 @@ LogicalResult AffinePrefetchOp::verify() {
Region *scope = getAffineScope(*this);
for (auto idx : getMapOperands()) {
if (!isValidAffineIndexOperand(idx, scope))
- return emitOpError("index must be a dimension or symbol identifier");
+ return emitOpError("index must be a valid dimension or symbol identifier");
}
return success();
}
diff --git a/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
index 759ab2d6c358c8a..b94d271fc197014 100644
--- a/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
@@ -49,3 +49,15 @@ func.func @call_functions(%arg0: index) -> index {
}
// -----
+
+func.func @dim_out_of_bounds() {
+ %c6 = arith.constant 6 : index
+ %alloc_4 = memref.alloc() : memref<4xi64>
+ %dim = memref.dim %alloc_4, %c6 : memref<4xi64> // Out of bounds; UB.
+ %alloca_100 = memref.alloca() : memref<100xi64>
+ // expected-error at +1 {{'affine.vector_load' op index must be a valid dimension or symbol identifier}}
+ %70 = affine.vector_load %alloca_100[%dim] : memref<100xi64>, vector<31xi64>
+ return
+}
+
+// -----
diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
index 72864516b459a51..60f13102f551569 100644
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -55,7 +55,7 @@ func.func @affine_load_invalid_dim(%M : memref<10xi32>) {
"unknown"() ({
^bb0(%arg: index):
affine.load %M[%arg] : memref<10xi32>
- // expected-error at -1 {{index must be a dimension or symbol identifier}}
+ // expected-error at -1 {{index must be a valid dimension or symbol identifier}}
cf.br ^bb1
^bb1:
cf.br ^bb1
@@ -521,7 +521,7 @@ func.func @dynamic_dimension_index() {
%idx = "unknown.test"() : () -> (index)
%memref = "unknown.test"() : () -> memref<?x?xf32>
%dim = memref.dim %memref, %idx : memref<?x?xf32>
- // expected-error @below {{op index must be a dimension or symbol identifier}}
+ // expected-error @below {{op index must be a valid dimension or symbol identifier}}
affine.load %memref[%dim, %dim] : memref<?x?xf32>
"unknown.terminator"() : () -> ()
}) : () -> ()
diff --git a/mlir/test/Dialect/Affine/load-store-invalid.mlir b/mlir/test/Dialect/Affine/load-store-invalid.mlir
index 482d2f35e094923..01d6b25dee695bb 100644
--- a/mlir/test/Dialect/Affine/load-store-invalid.mlir
+++ b/mlir/test/Dialect/Affine/load-store-invalid.mlir
@@ -37,7 +37,7 @@ func.func @load_non_affine_index(%arg0 : index) {
%0 = memref.alloc() : memref<10xf32>
affine.for %i0 = 0 to 10 {
%1 = arith.muli %i0, %arg0 : index
- // expected-error at +1 {{op index must be a dimension or symbol identifier}}
+ // expected-error at +1 {{op index must be a valid dimension or symbol identifier}}
%v = affine.load %0[%1] : memref<10xf32>
}
return
@@ -50,7 +50,7 @@ func.func @store_non_affine_index(%arg0 : index) {
%1 = arith.constant 11.0 : f32
affine.for %i0 = 0 to 10 {
%2 = arith.muli %i0, %arg0 : index
- // expected-error at +1 {{op index must be a dimension or symbol identifier}}
+ // expected-error at +1 {{op index must be a valid dimension or symbol identifier}}
affine.store %1, %0[%2] : memref<10xf32>
}
return
@@ -84,7 +84,7 @@ func.func @dma_start_non_affine_src_index(%arg0 : index) {
%c64 = arith.constant 64 : index
affine.for %i0 = 0 to 10 {
%3 = arith.muli %i0, %arg0 : index
- // expected-error at +1 {{op src index must be a dimension or symbol identifier}}
+ // expected-error at +1 {{op src index must be a valid dimension or symbol identifier}}
affine.dma_start %0[%3], %1[%i0], %2[%c0], %c64
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
}
@@ -101,7 +101,7 @@ func.func @dma_start_non_affine_dst_index(%arg0 : index) {
%c64 = arith.constant 64 : index
affine.for %i0 = 0 to 10 {
%3 = arith.muli %i0, %arg0 : index
- // expected-error at +1 {{op dst index must be a dimension or symbol identifier}}
+ // expected-error at +1 {{op dst index must be a valid dimension or symbol identifier}}
affine.dma_start %0[%i0], %1[%3], %2[%c0], %c64
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
}
@@ -118,7 +118,7 @@ func.func @dma_start_non_affine_tag_index(%arg0 : index) {
%c64 = arith.constant 64 : index
affine.for %i0 = 0 to 10 {
%3 = arith.muli %i0, %arg0 : index
- // expected-error at +1 {{op tag index must be a dimension or symbol identifier}}
+ // expected-error at +1 {{op tag index must be a valid dimension or symbol identifier}}
affine.dma_start %0[%i0], %1[%arg0], %2[%3], %c64
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
}
@@ -135,7 +135,7 @@ func.func @dma_wait_non_affine_tag_index(%arg0 : index) {
%c64 = arith.constant 64 : index
affine.for %i0 = 0 to 10 {
%3 = arith.muli %i0, %arg0 : index
- // expected-error at +1 {{op index must be a dimension or symbol identifier}}
+ // expected-error at +1 {{op index must be a valid dimension or symbol identifier}}
affine.dma_wait %2[%3], %c64 : memref<1xi32, 4>
}
return
>From b771a6db05fcaed7c1b09b64279ebd8c97440c72 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Thu, 23 Nov 2023 22:35:21 +0100
Subject: [PATCH 2/2] Move comment into test func name
---
mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
index b94d271fc197014..a09f1697fd72494 100644
--- a/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
@@ -50,10 +50,10 @@ func.func @call_functions(%arg0: index) -> index {
// -----
-func.func @dim_out_of_bounds() {
+func.func @dim_index_out_of_bounds() {
%c6 = arith.constant 6 : index
%alloc_4 = memref.alloc() : memref<4xi64>
- %dim = memref.dim %alloc_4, %c6 : memref<4xi64> // Out of bounds; UB.
+ %dim = memref.dim %alloc_4, %c6 : memref<4xi64>
%alloca_100 = memref.alloca() : memref<100xi64>
// expected-error at +1 {{'affine.vector_load' op index must be a valid dimension or symbol identifier}}
%70 = affine.vector_load %alloca_100[%dim] : memref<100xi64>, vector<31xi64>
More information about the Mlir-commits
mailing list