[Mlir-commits] [mlir] [mlir][tensor] Fix slice canonicalizer for out-of-bounds cases (PR #132534)

Matthias Springer llvmlistbot at llvm.org
Sat Mar 22 01:55:45 PDT 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/132534

Since #130487, `tensor.extract_slice` and `tensor.insert_slice` ops that are statically detected to go out of bounds are rejected by the verifier.

This commit fixes canonicalization patterns that currently fold dynamically out-of-bounds ops (valid IR) to statically out-of-bounds ops (invalid IR).


>From 28fa4ffa7d91c776dccf8145bf1867a9db54953f Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 22 Mar 2025 09:52:09 +0100
Subject: [PATCH] [mlir][tensor] Fix slice canonicalizer for out-of-bounds
 cases

---
 .../mlir/Interfaces/ViewLikeInterface.h       | 38 +++++++-
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 93 ++++++++++---------
 mlir/lib/Interfaces/ViewLikeInterface.cpp     | 58 ++++++++++++
 mlir/test/Dialect/Tensor/canonicalize.mlir    | 50 ++++++++++
 4 files changed, 194 insertions(+), 45 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 8f07e43f847ae..e74326dba7c80 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -45,6 +45,28 @@ unsigned getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals,
 
 namespace mlir {
 
+/// Result of a slice bounds verification, as returned by `verifyInBoundsSlice`.
+struct SliceBoundsVerificationResult {
+  /// If set to "true", the slice bounds verification was successful.
+  bool isValid;
+  /// An error message that can be printed during op verification.
+  std::string errorMessage;
+};
+
+/// Verify that the offsets/sizes/strides-style access into the given shape
+/// is in-bounds. Only static values are verified. If `generateErrorMessage`
+/// is set to "true", an error message is produced that can be printed by the
+/// op verifier.
+SliceBoundsVerificationResult
+verifyInBoundsSlice(ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets,
+                    ArrayRef<int64_t> staticSizes,
+                    ArrayRef<int64_t> staticStrides,
+                    bool generateErrorMessage = false);
+SliceBoundsVerificationResult verifyInBoundsSlice(
+    ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
+    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
+    bool generateErrorMessage = false);
+
 /// Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as
 /// constant arguments. This pattern assumes that the op has a suitable builder
 /// that takes a result type, a "source" operand and mixed offsets, sizes and
@@ -54,7 +76,8 @@ namespace mlir {
 /// returns the new result type of the op, based on the new offsets, sizes and
 /// strides. `CastOpFunc` is used to generate a cast op if the result type of
 /// the op has changed.
-template <typename OpType, typename ResultTypeFn, typename CastOpFunc>
+template <typename OpType, typename ResultTypeFn, typename CastOpFunc,
+          bool CheckInBounds = false>
 class OpWithOffsetSizesAndStridesConstantArgumentFolder final
     : public OpRewritePattern<OpType> {
 public:
@@ -72,11 +95,22 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
         failed(foldDynamicIndexList(mixedStrides)))
       return failure();
 
-    // Create the new op in canonical form.
+    if (CheckInBounds) {
+      // Pattern does not apply if the produced op would not verify.
+      SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
+          cast<ShapedType>(op.getSource().getType()).getShape(), mixedOffsets,
+          mixedSizes, mixedStrides);
+      if (!sliceResult.isValid)
+        return failure();
+    }
+
+    // Compute the new result type.
     auto resultType =
         ResultTypeFn()(op, mixedOffsets, mixedSizes, mixedStrides);
     if (!resultType)
       return failure();
+
+    // Create the new op in canonical form.
     auto newOp =
         rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(),
                                 mixedOffsets, mixedSizes, mixedStrides);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 2d5df07f8af4b..5f8493de991f3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
@@ -2352,37 +2353,6 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
   }
 }
 
-/// Verify that the offsets/sizes/strides-style access into the given tensor
-/// is in-bounds. Only static information is verified.
-static LogicalResult verifyInBoundsSlice(Operation *op,
-                                         RankedTensorType tensorType,
-                                         ArrayRef<int64_t> staticOffsets,
-                                         ArrayRef<int64_t> staticSizes,
-                                         ArrayRef<int64_t> staticStrides) {
-  for (int64_t i = 0, e = tensorType.getRank(); i < e; ++i) {
-    // Nothing to verify for dynamic source dims.
-    if (tensorType.isDynamicDim(i))
-      continue;
-    // Nothing to verify if the offset is dynamic.
-    if (ShapedType::isDynamic(staticOffsets[i]))
-      continue;
-    if (staticOffsets[i] >= tensorType.getDimSize(i))
-      return op->emitOpError("offset ")
-             << i << " is out-of-bounds: " << staticOffsets[i]
-             << " >= " << tensorType.getDimSize(i);
-    if (ShapedType::isDynamic(staticSizes[i]) ||
-        ShapedType::isDynamic(staticStrides[i]))
-      continue;
-    int64_t lastPos =
-        staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
-    if (lastPos >= tensorType.getDimSize(i))
-      return op->emitOpError("slice along dimension ")
-             << i << " runs out-of-bounds: " << lastPos
-             << " >= " << tensorType.getDimSize(i);
-  }
-  return success();
-}
-
 /// Verifier for ExtractSliceOp.
 LogicalResult ExtractSliceOp::verify() {
   RankedTensorType sourceType = getSourceType();
@@ -2396,8 +2366,13 @@ LogicalResult ExtractSliceOp::verify() {
 
   // Verify that offsets, sizes, strides do not run out-of-bounds with respect
   // to the source tensor.
-  return verifyInBoundsSlice(getOperation(), sourceType, getStaticOffsets(),
-                             getStaticSizes(), getStaticStrides());
+  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
+      sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
+      getStaticStrides(), /*generateErrorMessage=*/true);
+  if (!boundsResult.isValid)
+    return getOperation()->emitError(boundsResult.errorMessage);
+
+  return success();
 }
 
 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
@@ -2470,6 +2445,14 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
     if (!canFoldIntoConsumerOp(castOp))
       return failure();
 
+    // Pattern does not apply if the produced op would not verify.
+    SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
+        cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
+        sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
+        sliceOp.getStaticStrides());
+    if (!sliceResult.isValid)
+      return failure();
+
     // Create folded extract.
     Location loc = sliceOp.getLoc();
     Value newResult = rewriter.create<ExtractSliceOp>(
@@ -2634,10 +2617,10 @@ struct SliceCanonicalizer {
 
 void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  results.add<
-      OpWithOffsetSizesAndStridesConstantArgumentFolder<
-          ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
-      ExtractSliceOpCastFolder>(context);
+  results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
+                  ExtractSliceOp, SliceReturnTypeCanonicalizer,
+                  SliceCanonicalizer, /*CheckInBounds=*/true>,
+              ExtractSliceOpCastFolder>(context);
 }
 
 //
@@ -2775,9 +2758,14 @@ LogicalResult InsertSliceOp::verify() {
     return produceSliceErrorMsg(result, *this, expectedType);
 
   // Verify that offsets, sizes, strides do not run out-of-bounds with respect
-  // to the source tensor.
-  return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
-                             getStaticSizes(), getStaticStrides());
+  // to the destination tensor.
+  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
+      getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
+      getStaticStrides(), /*generateErrorMessage=*/true);
+  if (!boundsResult.isValid)
+    return getOperation()->emitError(boundsResult.errorMessage);
+
+  return success();
 }
 
 /// If we have two consecutive InsertSliceOp writing to the same slice, we
@@ -2872,6 +2860,13 @@ class InsertSliceOpConstantArgumentFolder final
         failed(foldDynamicStrideList(mixedStrides)))
       return failure();
 
+    // Pattern does not apply if the produced op would not verify.
+    SliceBoundsVerificationResult sliceResult =
+        verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
+                            mixedOffsets, mixedSizes, mixedStrides);
+    if (!sliceResult.isValid)
+      return failure();
+
     // Create the new op in canonical form.
     auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
         insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
@@ -2969,10 +2964,17 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
         size = srcType.getDimSize(rankReducedIdx++);
       }
     }
+
+    // Pattern does not apply if the produced op would not verify.
     if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                             staticSizes, insertSliceOp.getStaticStrides()) !=
         SliceVerificationResult::Success)
       return failure();
+    SliceBoundsVerificationResult sliceResult =
+        verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
+                            mixedSizes, insertSliceOp.getMixedStrides());
+    if (!sliceResult.isValid)
+      return failure();
 
     Operation *replacement = rewriter.create<InsertOpTy>(
         insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
@@ -3800,9 +3802,14 @@ LogicalResult ParallelInsertSliceOp::verify() {
     return produceSliceErrorMsg(result, *this, expectedType);
 
   // Verify that offsets, sizes, strides do not run out-of-bounds with respect
-  // to the source tensor.
-  return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
-                             getStaticSizes(), getStaticStrides());
+  // to the destination tensor.
+  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
+      getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
+      getStaticStrides(), /*generateErrorMessage=*/true);
+  if (!boundsResult.isValid)
+    return getOperation()->emitError(boundsResult.errorMessage);
+
+  return success();
 }
 
 void ParallelInsertSliceOp::getCanonicalizationPatterns(
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index 57b5cce7bb13b..70dd7b4aec88c 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -36,6 +36,64 @@ LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op,
   return success();
 }
 
+/// Static-only check; `errorMessage` is filled only when the caller asks.
+SliceBoundsVerificationResult mlir::verifyInBoundsSlice(
+    ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets,
+    ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides,
+    bool generateErrorMessage) {
+  SliceBoundsVerificationResult result{/*isValid=*/true, /*errorMessage=*/""};
+  for (int64_t i = 0, e = shape.size(); i < e; ++i) {
+    // Nothing to verify for dynamic source dims or dynamic offsets.
+    if (ShapedType::isDynamic(shape[i]) ||
+        ShapedType::isDynamic(staticOffsets[i]))
+      continue;
+    if (staticOffsets[i] >= shape[i]) {
+      result.isValid = false;
+      if (generateErrorMessage)
+        result.errorMessage =
+            std::string("offset ") + std::to_string(i) +
+            " is out-of-bounds: " + std::to_string(staticOffsets[i]) +
+            " >= " + std::to_string(shape[i]);
+      return result;
+    }
+    if (ShapedType::isDynamic(staticSizes[i]) ||
+        ShapedType::isDynamic(staticStrides[i]))
+      continue;
+    int64_t lastPos =
+        staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
+    if (lastPos >= shape[i]) {
+      result.isValid = false;
+      if (generateErrorMessage)
+        result.errorMessage =
+            std::string("slice along dimension ") + std::to_string(i) +
+            " runs out-of-bounds: " + std::to_string(lastPos) +
+            " >= " + std::to_string(shape[i]);
+      return result;
+    }
+  }
+  return result;
+}
+
+/// Mixed-value overload: converts to static lists and forwards to the above.
+SliceBoundsVerificationResult mlir::verifyInBoundsSlice(
+    ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
+    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
+    bool generateErrorMessage) {
+  // Non-attribute entries become `kDynamic`, which the static overload skips.
+  auto toStaticList = [](ArrayRef<OpFoldResult> mixedValues) {
+    SmallVector<int64_t> flattened;
+    for (OpFoldResult value : mixedValues) {
+      auto constant = dyn_cast<Attribute>(value);
+      flattened.push_back(constant ? cast<IntegerAttr>(constant).getInt()
+                                   : ShapedType::kDynamic);
+    }
+    return flattened;
+  };
+  return verifyInBoundsSlice(
+      shape, toStaticList(mixedOffsets), toStaticList(mixedSizes),
+      toStaticList(mixedStrides), generateErrorMessage);
+}
+
 LogicalResult
 mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
   std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 90cc0ca658ffb..fd96328c6033d 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -582,6 +582,56 @@ func.func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<1
 
 // -----
 
+// CHECK-LABEL: func @out_of_bounds_extract_slice
+//       CHECK:   tensor.extract_slice %{{.*}}[0] [%{{.*}}] [1] : tensor<5xf32> to tensor<?xf32>
+func.func @out_of_bounds_extract_slice(%t: tensor<5xf32>) -> tensor<?xf32> {
+  %c10 = arith.constant 10 : index
+  %r = tensor.extract_slice %t[0] [%c10] [1] : tensor<5xf32> to tensor<?xf32>
+  return %r : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @out_of_bounds_extract_slice
+//       CHECK:   tensor.extract_slice %{{.*}}[0] [10] [1] : tensor<?xf32> to tensor<10xf32>
+func.func @out_of_bounds_extract_slice(%t: tensor<5xf32>) -> tensor<10xf32> {
+  %t2 = tensor.cast %t : tensor<5xf32> to tensor<?xf32>
+  %r = tensor.extract_slice %t2 [0][10][1] : tensor<?xf32> to tensor<10xf32>
+  return %r : tensor<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @out_of_bounds_insert_slice
+//       CHECK:   tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [5] [1] : tensor<5xf32> into tensor<10xf32>
+func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>) -> tensor<10xf32> {
+  %c10 = arith.constant 10 : index
+  %r = tensor.insert_slice %src into %dst[%c10] [5] [1] : tensor<5xf32> into tensor<10xf32>
+  return %r : tensor<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @out_of_bounds_insert_slice
+//       CHECK:   tensor.insert_slice %{{.*}} into %{{.*}}[7] [%{{.*}}] [1] : tensor<?xf32> into tensor<10xf32>
+func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>, %sz: index) -> tensor<10xf32> {
+  %src2 = tensor.cast %src : tensor<5xf32> to tensor<?xf32>
+  %r = tensor.insert_slice %src2 into %dst[7] [%sz] [1] : tensor<?xf32> into tensor<10xf32>
+  return %r : tensor<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @out_of_bounds_insert_slice
+//       CHECK:   tensor.insert_slice %{{.*}} into %{{.*}}[7] [5] [1] : tensor<5xf32> into tensor<?xf32>
+func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>, %sz: index) -> tensor<?xf32> {
+  %dst2 = tensor.cast %dst : tensor<10xf32> to tensor<?xf32>
+  %r = tensor.insert_slice %src into %dst2[7] [5] [1] : tensor<5xf32> into tensor<?xf32>
+  return %r : tensor<?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @rank_reducing_insert_slice_of_cast
 //  CHECK-SAME:   %[[A:.[a-z0-9A-Z_]+]]: tensor<16x32xi8>
 //  CHECK-SAME:   %[[B:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>



More information about the Mlir-commits mailing list