[Mlir-commits] [mlir] 3247f1e - [mlir][affine] Fix dim index out of bounds crash (#73266)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 27 22:35:14 PST 2023


Author: Rik Huijzer
Date: 2023-11-28T07:35:09+01:00
New Revision: 3247f1e7a281184ac0db4fc6df35232e8f1a4f12

URL: https://github.com/llvm/llvm-project/commit/3247f1e7a281184ac0db4fc6df35232e8f1a4f12
DIFF: https://github.com/llvm/llvm-project/commit/3247f1e7a281184ac0db4fc6df35232e8f1a4f12.diff

LOG: [mlir][affine] Fix dim index out of bounds crash (#73266)

This PR suggests a way to fix
https://github.com/llvm/llvm-project/issues/70418. An error is now
reported if the `index` operand for `memref.dim` is out of bounds.
Catching it in the verifier was not possible because the constant value
is not yet available at that point. Unfortunately, the error is not very
descriptive, since it was only possible to propagate a boolean up.

Added: 
    

Modified: 
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
    mlir/test/Dialect/Affine/invalid.mlir
    mlir/test/Dialect/Affine/load-store-invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index d22a7539fb75018..a7fc7ddec26e618 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -319,7 +319,13 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
 template <typename AnyMemRefDefOp>
 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
                                     Region *region) {
-  auto memRefType = memrefDefOp.getType();
+  MemRefType memRefType = memrefDefOp.getType();
+
+  // Dimension index is out of bounds.
+  if (index >= memRefType.getRank()) {
+    return false;
+  }
+
   // Statically shaped.
   if (!memRefType.isDynamicDim(index))
     return true;
@@ -1651,19 +1657,22 @@ LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
     if (!idx.getType().isIndex())
       return emitOpError("src index to dma_start must have 'index' type");
     if (!isValidAffineIndexOperand(idx, scope))
-      return emitOpError("src index must be a dimension or symbol identifier");
+      return emitOpError(
+          "src index must be a valid dimension or symbol identifier");
   }
   for (auto idx : getDstIndices()) {
     if (!idx.getType().isIndex())
       return emitOpError("dst index to dma_start must have 'index' type");
     if (!isValidAffineIndexOperand(idx, scope))
-      return emitOpError("dst index must be a dimension or symbol identifier");
+      return emitOpError(
+          "dst index must be a valid dimension or symbol identifier");
   }
   for (auto idx : getTagIndices()) {
     if (!idx.getType().isIndex())
       return emitOpError("tag index to dma_start must have 'index' type");
     if (!isValidAffineIndexOperand(idx, scope))
-      return emitOpError("tag index must be a dimension or symbol identifier");
+      return emitOpError(
+          "tag index must be a valid dimension or symbol identifier");
   }
   return success();
 }
@@ -1752,7 +1761,8 @@ LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
     if (!idx.getType().isIndex())
       return emitOpError("index to dma_wait must have 'index' type");
     if (!isValidAffineIndexOperand(idx, scope))
-      return emitOpError("index must be a dimension or symbol identifier");
+      return emitOpError(
+          "index must be a valid dimension or symbol identifier");
   }
   return success();
 }
@@ -2913,8 +2923,7 @@ static void composeSetAndOperands(IntegerSet &set,
 }
 
 /// Canonicalize an affine if op's conditional (integer set + operands).
-LogicalResult AffineIfOp::fold(FoldAdaptor,
-                               SmallVectorImpl<OpFoldResult> &) {
+LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
   auto set = getIntegerSet();
   SmallVector<Value, 4> operands(getOperands());
   composeSetAndOperands(set, operands);
@@ -3005,18 +3014,19 @@ static LogicalResult
 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
                        Operation::operand_range mapOperands,
                        MemRefType memrefType, unsigned numIndexOperands) {
-    AffineMap map = mapAttr.getValue();
-    if (map.getNumResults() != memrefType.getRank())
-      return op->emitOpError("affine map num results must equal memref rank");
-    if (map.getNumInputs() != numIndexOperands)
-      return op->emitOpError("expects as many subscripts as affine map inputs");
+  AffineMap map = mapAttr.getValue();
+  if (map.getNumResults() != memrefType.getRank())
+    return op->emitOpError("affine map num results must equal memref rank");
+  if (map.getNumInputs() != numIndexOperands)
+    return op->emitOpError("expects as many subscripts as affine map inputs");
 
   Region *scope = getAffineScope(op);
   for (auto idx : mapOperands) {
     if (!idx.getType().isIndex())
       return op->emitOpError("index to load must have 'index' type");
     if (!isValidAffineIndexOperand(idx, scope))
-      return op->emitOpError("index must be a dimension or symbol identifier");
+      return op->emitOpError(
+          "index must be a valid dimension or symbol identifier");
   }
 
   return success();
@@ -3605,7 +3615,8 @@ LogicalResult AffinePrefetchOp::verify() {
   Region *scope = getAffineScope(*this);
   for (auto idx : getMapOperands()) {
     if (!isValidAffineIndexOperand(idx, scope))
-      return emitOpError("index must be a dimension or symbol identifier");
+      return emitOpError(
+          "index must be a valid dimension or symbol identifier");
   }
   return success();
 }

diff  --git a/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
index 759ab2d6c358c8a..a09f1697fd72494 100644
--- a/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
@@ -49,3 +49,15 @@ func.func @call_functions(%arg0: index) -> index {
 }
 
 // -----
+
+func.func @dim_index_out_of_bounds() {
+  %c6 = arith.constant 6 : index
+  %alloc_4 = memref.alloc() : memref<4xi64>
+  %dim = memref.dim %alloc_4, %c6 : memref<4xi64>
+  %alloca_100 = memref.alloca() : memref<100xi64>
+  // expected-error@+1 {{'affine.vector_load' op index must be a valid dimension or symbol identifier}}
+  %70 = affine.vector_load %alloca_100[%dim] : memref<100xi64>, vector<31xi64>
+  return
+}
+
+// -----

diff  --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
index 72864516b459a51..60f13102f551569 100644
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -55,7 +55,7 @@ func.func @affine_load_invalid_dim(%M : memref<10xi32>) {
   "unknown"() ({
   ^bb0(%arg: index):
     affine.load %M[%arg] : memref<10xi32>
-    // expected-error@-1 {{index must be a dimension or symbol identifier}}
+    // expected-error@-1 {{index must be a valid dimension or symbol identifier}}
     cf.br ^bb1
   ^bb1:
     cf.br ^bb1
@@ -521,7 +521,7 @@ func.func @dynamic_dimension_index() {
     %idx = "unknown.test"() : () -> (index)
     %memref = "unknown.test"() : () -> memref<?x?xf32>
     %dim = memref.dim %memref, %idx : memref<?x?xf32>
-    // expected-error @below {{op index must be a dimension or symbol identifier}}
+    // expected-error @below {{op index must be a valid dimension or symbol identifier}}
     affine.load %memref[%dim, %dim] : memref<?x?xf32>
     "unknown.terminator"() : () -> ()
   }) : () -> ()

diff  --git a/mlir/test/Dialect/Affine/load-store-invalid.mlir b/mlir/test/Dialect/Affine/load-store-invalid.mlir
index 482d2f35e094923..01d6b25dee695bb 100644
--- a/mlir/test/Dialect/Affine/load-store-invalid.mlir
+++ b/mlir/test/Dialect/Affine/load-store-invalid.mlir
@@ -37,7 +37,7 @@ func.func @load_non_affine_index(%arg0 : index) {
   %0 = memref.alloc() : memref<10xf32>
   affine.for %i0 = 0 to 10 {
     %1 = arith.muli %i0, %arg0 : index
-    // expected-error@+1 {{op index must be a dimension or symbol identifier}}
+    // expected-error@+1 {{op index must be a valid dimension or symbol identifier}}
     %v = affine.load %0[%1] : memref<10xf32>
   }
   return
@@ -50,7 +50,7 @@ func.func @store_non_affine_index(%arg0 : index) {
   %1 = arith.constant 11.0 : f32
   affine.for %i0 = 0 to 10 {
     %2 = arith.muli %i0, %arg0 : index
-    // expected-error@+1 {{op index must be a dimension or symbol identifier}}
+    // expected-error@+1 {{op index must be a valid dimension or symbol identifier}}
     affine.store %1, %0[%2] : memref<10xf32>
   }
   return
@@ -84,7 +84,7 @@ func.func @dma_start_non_affine_src_index(%arg0 : index) {
   %c64 = arith.constant 64 : index
   affine.for %i0 = 0 to 10 {
     %3 = arith.muli %i0, %arg0 : index
-    // expected-error@+1 {{op src index must be a dimension or symbol identifier}}
+    // expected-error@+1 {{op src index must be a valid dimension or symbol identifier}}
     affine.dma_start %0[%3], %1[%i0], %2[%c0], %c64
         : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
   }
@@ -101,7 +101,7 @@ func.func @dma_start_non_affine_dst_index(%arg0 : index) {
   %c64 = arith.constant 64 : index
   affine.for %i0 = 0 to 10 {
     %3 = arith.muli %i0, %arg0 : index
-    // expected-error@+1 {{op dst index must be a dimension or symbol identifier}}
+    // expected-error@+1 {{op dst index must be a valid dimension or symbol identifier}}
     affine.dma_start %0[%i0], %1[%3], %2[%c0], %c64
         : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
   }
@@ -118,7 +118,7 @@ func.func @dma_start_non_affine_tag_index(%arg0 : index) {
   %c64 = arith.constant 64 : index
   affine.for %i0 = 0 to 10 {
     %3 = arith.muli %i0, %arg0 : index
-    // expected-error@+1 {{op tag index must be a dimension or symbol identifier}}
+    // expected-error@+1 {{op tag index must be a valid dimension or symbol identifier}}
     affine.dma_start %0[%i0], %1[%arg0], %2[%3], %c64
         : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
   }
@@ -135,7 +135,7 @@ func.func @dma_wait_non_affine_tag_index(%arg0 : index) {
   %c64 = arith.constant 64 : index
   affine.for %i0 = 0 to 10 {
     %3 = arith.muli %i0, %arg0 : index
-    // expected-error@+1 {{op index must be a dimension or symbol identifier}}
+    // expected-error@+1 {{op index must be a valid dimension or symbol identifier}}
     affine.dma_wait %2[%3], %c64 : memref<1xi32, 4>
   }
   return


        


More information about the Mlir-commits mailing list