[Mlir-commits] [mlir] [mlir][affine] Fix dim index out of bounds crash (PR #73266)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 23 13:35:06 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-affine
Author: Rik Huijzer (rikhuijzer)
<details>
<summary>Changes</summary>
This PR suggests a way to fix https://github.com/llvm/llvm-project/issues/70418. Instead of crashing, MLIR now reports an error when the `index` operand of `memref.dim` is out of bounds. Catching this in the verifier was not possible because the constant value of the index is not yet available at that point. Unfortunately, the resulting error message is not very descriptive, because the helper that performs the check can only propagate a boolean result up to its caller.
---
Full diff: https://github.com/llvm/llvm-project/pull/73266.diff
4 Files Affected:
- (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+24-16)
- (modified) mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir (+12)
- (modified) mlir/test/Dialect/Affine/invalid.mlir (+2-2)
- (modified) mlir/test/Dialect/Affine/load-store-invalid.mlir (+6-6)
``````````diff
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index d22a7539fb75018..d6e640ddd8f25d5 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -317,9 +317,16 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
/// `memrefDefOp` is a statically shaped one or defined using a valid symbol
/// for `region`.
template <typename AnyMemRefDefOp>
-static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
+static bool isMemRefSizeValidSymbol(ShapedDimOpInterface dimOp,
+ AnyMemRefDefOp memrefDefOp, unsigned index,
Region *region) {
- auto memRefType = memrefDefOp.getType();
+ MemRefType memRefType = memrefDefOp.getType();
+
+ // Dimension index is out of bounds.
+ if (index >= memRefType.getRank()) {
+ return false;
+ }
+
// Statically shaped.
if (!memRefType.isDynamicDim(index))
return true;
@@ -351,7 +358,9 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
int64_t i = index.value();
return TypeSwitch<Operation *, bool>(dimOp.getShapedValue().getDefiningOp())
.Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
- [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
+ [&](auto memRefDefOp) {
+ return isMemRefSizeValidSymbol(dimOp, memRefDefOp, i, region);
+ })
.Default([](Operation *) { return false; });
}
@@ -1651,19 +1660,19 @@ LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
if (!idx.getType().isIndex())
return emitOpError("src index to dma_start must have 'index' type");
if (!isValidAffineIndexOperand(idx, scope))
- return emitOpError("src index must be a dimension or symbol identifier");
+ return emitOpError("src index must be a valid dimension or symbol identifier");
}
for (auto idx : getDstIndices()) {
if (!idx.getType().isIndex())
return emitOpError("dst index to dma_start must have 'index' type");
if (!isValidAffineIndexOperand(idx, scope))
- return emitOpError("dst index must be a dimension or symbol identifier");
+ return emitOpError("dst index must be a valid dimension or symbol identifier");
}
for (auto idx : getTagIndices()) {
if (!idx.getType().isIndex())
return emitOpError("tag index to dma_start must have 'index' type");
if (!isValidAffineIndexOperand(idx, scope))
- return emitOpError("tag index must be a dimension or symbol identifier");
+ return emitOpError("tag index must be a valid dimension or symbol identifier");
}
return success();
}
@@ -1752,7 +1761,7 @@ LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
if (!idx.getType().isIndex())
return emitOpError("index to dma_wait must have 'index' type");
if (!isValidAffineIndexOperand(idx, scope))
- return emitOpError("index must be a dimension or symbol identifier");
+ return emitOpError("index must be a valid dimension or symbol identifier");
}
return success();
}
@@ -2913,8 +2922,7 @@ static void composeSetAndOperands(IntegerSet &set,
}
/// Canonicalize an affine if op's conditional (integer set + operands).
-LogicalResult AffineIfOp::fold(FoldAdaptor,
- SmallVectorImpl<OpFoldResult> &) {
+LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
auto set = getIntegerSet();
SmallVector<Value, 4> operands(getOperands());
composeSetAndOperands(set, operands);
@@ -3005,18 +3013,18 @@ static LogicalResult
verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
Operation::operand_range mapOperands,
MemRefType memrefType, unsigned numIndexOperands) {
- AffineMap map = mapAttr.getValue();
- if (map.getNumResults() != memrefType.getRank())
- return op->emitOpError("affine map num results must equal memref rank");
- if (map.getNumInputs() != numIndexOperands)
- return op->emitOpError("expects as many subscripts as affine map inputs");
+ AffineMap map = mapAttr.getValue();
+ if (map.getNumResults() != memrefType.getRank())
+ return op->emitOpError("affine map num results must equal memref rank");
+ if (map.getNumInputs() != numIndexOperands)
+ return op->emitOpError("expects as many subscripts as affine map inputs");
Region *scope = getAffineScope(op);
for (auto idx : mapOperands) {
if (!idx.getType().isIndex())
return op->emitOpError("index to load must have 'index' type");
if (!isValidAffineIndexOperand(idx, scope))
- return op->emitOpError("index must be a dimension or symbol identifier");
+ return op->emitOpError("index must be a valid dimension or symbol identifier");
}
return success();
@@ -3605,7 +3613,7 @@ LogicalResult AffinePrefetchOp::verify() {
Region *scope = getAffineScope(*this);
for (auto idx : getMapOperands()) {
if (!isValidAffineIndexOperand(idx, scope))
- return emitOpError("index must be a dimension or symbol identifier");
+ return emitOpError("index must be a valid dimension or symbol identifier");
}
return success();
}
diff --git a/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
index 759ab2d6c358c8a..b94d271fc197014 100644
--- a/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
@@ -49,3 +49,15 @@ func.func @call_functions(%arg0: index) -> index {
}
// -----
+
+func.func @dim_out_of_bounds() {
+ %c6 = arith.constant 6 : index
+ %alloc_4 = memref.alloc() : memref<4xi64>
+ %dim = memref.dim %alloc_4, %c6 : memref<4xi64> // Out of bounds; UB.
+ %alloca_100 = memref.alloca() : memref<100xi64>
+ // expected-error at +1 {{'affine.vector_load' op index must be a valid dimension or symbol identifier}}
+ %70 = affine.vector_load %alloca_100[%dim] : memref<100xi64>, vector<31xi64>
+ return
+}
+
+// -----
diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
index 72864516b459a51..60f13102f551569 100644
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -55,7 +55,7 @@ func.func @affine_load_invalid_dim(%M : memref<10xi32>) {
"unknown"() ({
^bb0(%arg: index):
affine.load %M[%arg] : memref<10xi32>
- // expected-error at -1 {{index must be a dimension or symbol identifier}}
+ // expected-error at -1 {{index must be a valid dimension or symbol identifier}}
cf.br ^bb1
^bb1:
cf.br ^bb1
@@ -521,7 +521,7 @@ func.func @dynamic_dimension_index() {
%idx = "unknown.test"() : () -> (index)
%memref = "unknown.test"() : () -> memref<?x?xf32>
%dim = memref.dim %memref, %idx : memref<?x?xf32>
- // expected-error @below {{op index must be a dimension or symbol identifier}}
+ // expected-error @below {{op index must be a valid dimension or symbol identifier}}
affine.load %memref[%dim, %dim] : memref<?x?xf32>
"unknown.terminator"() : () -> ()
}) : () -> ()
diff --git a/mlir/test/Dialect/Affine/load-store-invalid.mlir b/mlir/test/Dialect/Affine/load-store-invalid.mlir
index 482d2f35e094923..01d6b25dee695bb 100644
--- a/mlir/test/Dialect/Affine/load-store-invalid.mlir
+++ b/mlir/test/Dialect/Affine/load-store-invalid.mlir
@@ -37,7 +37,7 @@ func.func @load_non_affine_index(%arg0 : index) {
%0 = memref.alloc() : memref<10xf32>
affine.for %i0 = 0 to 10 {
%1 = arith.muli %i0, %arg0 : index
- // expected-error at +1 {{op index must be a dimension or symbol identifier}}
+ // expected-error at +1 {{op index must be a valid dimension or symbol identifier}}
%v = affine.load %0[%1] : memref<10xf32>
}
return
@@ -50,7 +50,7 @@ func.func @store_non_affine_index(%arg0 : index) {
%1 = arith.constant 11.0 : f32
affine.for %i0 = 0 to 10 {
%2 = arith.muli %i0, %arg0 : index
- // expected-error at +1 {{op index must be a dimension or symbol identifier}}
+ // expected-error at +1 {{op index must be a valid dimension or symbol identifier}}
affine.store %1, %0[%2] : memref<10xf32>
}
return
@@ -84,7 +84,7 @@ func.func @dma_start_non_affine_src_index(%arg0 : index) {
%c64 = arith.constant 64 : index
affine.for %i0 = 0 to 10 {
%3 = arith.muli %i0, %arg0 : index
- // expected-error at +1 {{op src index must be a dimension or symbol identifier}}
+ // expected-error at +1 {{op src index must be a valid dimension or symbol identifier}}
affine.dma_start %0[%3], %1[%i0], %2[%c0], %c64
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
}
@@ -101,7 +101,7 @@ func.func @dma_start_non_affine_dst_index(%arg0 : index) {
%c64 = arith.constant 64 : index
affine.for %i0 = 0 to 10 {
%3 = arith.muli %i0, %arg0 : index
- // expected-error at +1 {{op dst index must be a dimension or symbol identifier}}
+ // expected-error at +1 {{op dst index must be a valid dimension or symbol identifier}}
affine.dma_start %0[%i0], %1[%3], %2[%c0], %c64
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
}
@@ -118,7 +118,7 @@ func.func @dma_start_non_affine_tag_index(%arg0 : index) {
%c64 = arith.constant 64 : index
affine.for %i0 = 0 to 10 {
%3 = arith.muli %i0, %arg0 : index
- // expected-error at +1 {{op tag index must be a dimension or symbol identifier}}
+ // expected-error at +1 {{op tag index must be a valid dimension or symbol identifier}}
affine.dma_start %0[%i0], %1[%arg0], %2[%3], %c64
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
}
@@ -135,7 +135,7 @@ func.func @dma_wait_non_affine_tag_index(%arg0 : index) {
%c64 = arith.constant 64 : index
affine.for %i0 = 0 to 10 {
%3 = arith.muli %i0, %arg0 : index
- // expected-error at +1 {{op index must be a dimension or symbol identifier}}
+ // expected-error at +1 {{op index must be a valid dimension or symbol identifier}}
affine.dma_wait %2[%3], %c64 : memref<1xi32, 4>
}
return
``````````
</details>
https://github.com/llvm/llvm-project/pull/73266
More information about the Mlir-commits mailing list