[Mlir-commits] [mlir] [mlir][tensor] Fix slice canonicalizer for out-of-bounds cases (PR #132534)
Matthias Springer
llvmlistbot at llvm.org
Sat Mar 22 01:55:45 PDT 2025
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/132534
Since #130487, `tensor.extract_slice` and `tensor.insert_slice` ops that are statically detected to go out of bounds are rejected by the verifier.
This commit fixes canonicalization patterns that would fold dynamically out-of-bounds ops (valid IR) into statically out-of-bounds ops (invalid IR).
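To illustrate the problem, here is a minimal sketch based on the new test cases added to canonicalize.mlir in this patch (the function name is made up for this example). The size %c10 is dynamic in the IR, so the op passes verification; folding the constant into the static size list would produce a slice that is statically out of bounds and, since #130487, rejected by the verifier. With this fix, the canonicalizer leaves such ops unchanged.

  func.func @dynamic_out_of_bounds_example(%t: tensor<5xf32>) -> tensor<?xf32> {
    %c10 = arith.constant 10 : index
    // Valid IR: the size is dynamic, so the verifier cannot prove that the
    // slice runs out of bounds.
    %r = tensor.extract_slice %t[0] [%c10] [1] : tensor<5xf32> to tensor<?xf32>
    // Before this fix, canonicalization would fold %c10 into a static size:
    //   tensor.extract_slice %t[0] [10] [1] : tensor<5xf32> to tensor<10xf32>
    // which is statically out of bounds and therefore invalid IR.
    return %r : tensor<?xf32>
  }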
From 28fa4ffa7d91c776dccf8145bf1867a9db54953f Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 22 Mar 2025 09:52:09 +0100
Subject: [PATCH] [mlir][tensor] Fix slice canonicalizer for out-of-bounds
cases
---
.../mlir/Interfaces/ViewLikeInterface.h | 38 +++++++-
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 93 ++++++++++---------
mlir/lib/Interfaces/ViewLikeInterface.cpp | 58 ++++++++++++
mlir/test/Dialect/Tensor/canonicalize.mlir | 50 ++++++++++
4 files changed, 194 insertions(+), 45 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 8f07e43f847ae..e74326dba7c80 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -45,6 +45,28 @@ unsigned getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals,
namespace mlir {
+/// Result for slice bounds verification.
+struct SliceBoundsVerificationResult {
+ /// If set to "true", the slice bounds verification was successful.
+ bool isValid;
+ /// An error message that can be printed during op verification.
+ std::string errorMessage;
+};
+
+/// Verify that the offsets/sizes/strides-style access into the given shape
+/// is in-bounds. Only static values are verified. If `generateErrorMessage`
+/// is set to "true", an error message is produced that can be printed by the
+/// op verifier.
+SliceBoundsVerificationResult
+verifyInBoundsSlice(ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets,
+ ArrayRef<int64_t> staticSizes,
+ ArrayRef<int64_t> staticStrides,
+ bool generateErrorMessage = false);
+SliceBoundsVerificationResult verifyInBoundsSlice(
+ ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
+ ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
+ bool generateErrorMessage = false);
+
/// Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as
/// constant arguments. This pattern assumes that the op has a suitable builder
/// that takes a result type, a "source" operand and mixed offsets, sizes and
@@ -54,7 +76,8 @@ namespace mlir {
/// returns the new result type of the op, based on the new offsets, sizes and
/// strides. `CastOpFunc` is used to generate a cast op if the result type of
/// the op has changed.
-template <typename OpType, typename ResultTypeFn, typename CastOpFunc>
+template <typename OpType, typename ResultTypeFn, typename CastOpFunc,
+ bool CheckInBounds = false>
class OpWithOffsetSizesAndStridesConstantArgumentFolder final
: public OpRewritePattern<OpType> {
public:
@@ -72,11 +95,22 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
failed(foldDynamicIndexList(mixedStrides)))
return failure();
- // Create the new op in canonical form.
+ if (CheckInBounds) {
+ // Pattern does not apply if the produced op would not verify.
+ SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
+ cast<ShapedType>(op.getSource().getType()).getShape(), mixedOffsets,
+ mixedSizes, mixedStrides);
+ if (!sliceResult.isValid)
+ return failure();
+ }
+
+ // Compute the new result type.
auto resultType =
ResultTypeFn()(op, mixedOffsets, mixedSizes, mixedStrides);
if (!resultType)
return failure();
+
+ // Create the new op in canonical form.
auto newOp =
rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(),
mixedOffsets, mixedSizes, mixedStrides);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 2d5df07f8af4b..5f8493de991f3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -27,6 +27,7 @@
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
@@ -2352,37 +2353,6 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
}
}
-/// Verify that the offsets/sizes/strides-style access into the given tensor
-/// is in-bounds. Only static information is verified.
-static LogicalResult verifyInBoundsSlice(Operation *op,
- RankedTensorType tensorType,
- ArrayRef<int64_t> staticOffsets,
- ArrayRef<int64_t> staticSizes,
- ArrayRef<int64_t> staticStrides) {
- for (int64_t i = 0, e = tensorType.getRank(); i < e; ++i) {
- // Nothing to verify for dynamic source dims.
- if (tensorType.isDynamicDim(i))
- continue;
- // Nothing to verify if the offset is dynamic.
- if (ShapedType::isDynamic(staticOffsets[i]))
- continue;
- if (staticOffsets[i] >= tensorType.getDimSize(i))
- return op->emitOpError("offset ")
- << i << " is out-of-bounds: " << staticOffsets[i]
- << " >= " << tensorType.getDimSize(i);
- if (ShapedType::isDynamic(staticSizes[i]) ||
- ShapedType::isDynamic(staticStrides[i]))
- continue;
- int64_t lastPos =
- staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
- if (lastPos >= tensorType.getDimSize(i))
- return op->emitOpError("slice along dimension ")
- << i << " runs out-of-bounds: " << lastPos
- << " >= " << tensorType.getDimSize(i);
- }
- return success();
-}
-
/// Verifier for ExtractSliceOp.
LogicalResult ExtractSliceOp::verify() {
RankedTensorType sourceType = getSourceType();
@@ -2396,8 +2366,13 @@ LogicalResult ExtractSliceOp::verify() {
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
// to the source tensor.
- return verifyInBoundsSlice(getOperation(), sourceType, getStaticOffsets(),
- getStaticSizes(), getStaticStrides());
+ SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
+ sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
+ getStaticStrides(), /*generateErrorMessage=*/true);
+ if (!boundsResult.isValid)
+ return getOperation()->emitError(boundsResult.errorMessage);
+
+ return success();
}
llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
@@ -2470,6 +2445,14 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
if (!canFoldIntoConsumerOp(castOp))
return failure();
+ // Pattern does not apply if the produced op would not verify.
+ SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
+ cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
+ sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
+ sliceOp.getStaticStrides());
+ if (!sliceResult.isValid)
+ return failure();
+
// Create folded extract.
Location loc = sliceOp.getLoc();
Value newResult = rewriter.create<ExtractSliceOp>(
@@ -2634,10 +2617,10 @@ struct SliceCanonicalizer {
void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<
- OpWithOffsetSizesAndStridesConstantArgumentFolder<
- ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
- ExtractSliceOpCastFolder>(context);
+ results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
+ ExtractSliceOp, SliceReturnTypeCanonicalizer,
+ SliceCanonicalizer, /*CheckInBounds=*/true>,
+ ExtractSliceOpCastFolder>(context);
}
//
@@ -2775,9 +2758,14 @@ LogicalResult InsertSliceOp::verify() {
return produceSliceErrorMsg(result, *this, expectedType);
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
- // to the source tensor.
- return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
- getStaticSizes(), getStaticStrides());
+ // to the destination tensor.
+ SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
+ getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
+ getStaticStrides(), /*generateErrorMessage=*/true);
+ if (!boundsResult.isValid)
+ return getOperation()->emitError(boundsResult.errorMessage);
+
+ return success();
}
/// If we have two consecutive InsertSliceOp writing to the same slice, we
@@ -2872,6 +2860,13 @@ class InsertSliceOpConstantArgumentFolder final
failed(foldDynamicStrideList(mixedStrides)))
return failure();
+ // Pattern does not apply if the produced op would not verify.
+ SliceBoundsVerificationResult sliceResult =
+ verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
+ mixedOffsets, mixedSizes, mixedStrides);
+ if (!sliceResult.isValid)
+ return failure();
+
// Create the new op in canonical form.
auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
@@ -2969,10 +2964,17 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
size = srcType.getDimSize(rankReducedIdx++);
}
}
+
+ // Pattern does not apply if the produced op would not verify.
if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
staticSizes, insertSliceOp.getStaticStrides()) !=
SliceVerificationResult::Success)
return failure();
+ SliceBoundsVerificationResult sliceResult =
+ verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
+ mixedSizes, insertSliceOp.getMixedStrides());
+ if (!sliceResult.isValid)
+ return failure();
Operation *replacement = rewriter.create<InsertOpTy>(
insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
@@ -3800,9 +3802,14 @@ LogicalResult ParallelInsertSliceOp::verify() {
return produceSliceErrorMsg(result, *this, expectedType);
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
- // to the source tensor.
- return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
- getStaticSizes(), getStaticStrides());
+ // to the destination tensor.
+ SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
+ getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
+ getStaticStrides(), /*generateErrorMessage=*/true);
+ if (!boundsResult.isValid)
+ return getOperation()->emitError(boundsResult.errorMessage);
+
+ return success();
}
void ParallelInsertSliceOp::getCanonicalizationPatterns(
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index 57b5cce7bb13b..70dd7b4aec88c 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -36,6 +36,64 @@ LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op,
return success();
}
+SliceBoundsVerificationResult mlir::verifyInBoundsSlice(
+ ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets,
+ ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides,
+ bool generateErrorMessage) {
+ SliceBoundsVerificationResult result;
+ result.isValid = true;
+ for (int64_t i = 0, e = shape.size(); i < e; ++i) {
+ // Nothing to verify for dynamic source dims.
+ if (ShapedType::isDynamic(shape[i]))
+ continue;
+ // Nothing to verify if the offset is dynamic.
+ if (ShapedType::isDynamic(staticOffsets[i]))
+ continue;
+ if (staticOffsets[i] >= shape[i]) {
+ result.errorMessage =
+ std::string("offset ") + std::to_string(i) +
+ " is out-of-bounds: " + std::to_string(staticOffsets[i]) +
+ " >= " + std::to_string(shape[i]);
+ result.isValid = false;
+ return result;
+ }
+ if (ShapedType::isDynamic(staticSizes[i]) ||
+ ShapedType::isDynamic(staticStrides[i]))
+ continue;
+ int64_t lastPos =
+ staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
+ if (lastPos >= shape[i]) {
+ result.errorMessage = std::string("slice along dimension ") +
+ std::to_string(i) +
+ " runs out-of-bounds: " + std::to_string(lastPos) +
+ " >= " + std::to_string(shape[i]);
+ result.isValid = false;
+ return result;
+ }
+ }
+ return result;
+}
+
+SliceBoundsVerificationResult mlir::verifyInBoundsSlice(
+ ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
+ ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
+ bool generateErrorMessage) {
+ auto getStaticValues = [](ArrayRef<OpFoldResult> ofrs) {
+ SmallVector<int64_t> staticValues;
+ for (OpFoldResult ofr : ofrs) {
+ if (auto attr = dyn_cast<Attribute>(ofr)) {
+ staticValues.push_back(cast<IntegerAttr>(attr).getInt());
+ } else {
+ staticValues.push_back(ShapedType::kDynamic);
+ }
+ }
+ return staticValues;
+ };
+ return verifyInBoundsSlice(
+ shape, getStaticValues(mixedOffsets), getStaticValues(mixedSizes),
+ getStaticValues(mixedStrides), generateErrorMessage);
+}
+
LogicalResult
mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 90cc0ca658ffb..fd96328c6033d 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -582,6 +582,56 @@ func.func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<1
// -----
+// CHECK-LABEL: func @out_of_bounds_extract_slice
+// CHECK: tensor.extract_slice %{{.*}}[0] [%{{.*}}] [1] : tensor<5xf32> to tensor<?xf32>
+func.func @out_of_bounds_extract_slice(%t: tensor<5xf32>) -> tensor<?xf32> {
+ %c10 = arith.constant 10 : index
+ %r = tensor.extract_slice %t[0] [%c10] [1] : tensor<5xf32> to tensor<?xf32>
+ return %r : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @out_of_bounds_extract_slice
+// CHECK: tensor.extract_slice %{{.*}}[0] [10] [1] : tensor<?xf32> to tensor<10xf32>
+func.func @out_of_bounds_extract_slice(%t: tensor<5xf32>) -> tensor<10xf32> {
+ %t2 = tensor.cast %t : tensor<5xf32> to tensor<?xf32>
+ %r = tensor.extract_slice %t2 [0][10][1] : tensor<?xf32> to tensor<10xf32>
+ return %r : tensor<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @out_of_bounds_insert_slice
+// CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [5] [1] : tensor<5xf32> into tensor<10xf32>
+func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>) -> tensor<10xf32> {
+ %c10 = arith.constant 10 : index
+ %r = tensor.insert_slice %src into %dst[%c10] [5] [1] : tensor<5xf32> into tensor<10xf32>
+ return %r : tensor<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @out_of_bounds_insert_slice
+// CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[7] [%{{.*}}] [1] : tensor<?xf32> into tensor<10xf32>
+func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>, %sz: index) -> tensor<10xf32> {
+ %src2 = tensor.cast %src : tensor<5xf32> to tensor<?xf32>
+ %r = tensor.insert_slice %src2 into %dst[7] [%sz] [1] : tensor<?xf32> into tensor<10xf32>
+ return %r : tensor<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @out_of_bounds_insert_slice
+// CHECK: tensor.insert_slice %{{.*}} into %{{.*}}[7] [5] [1] : tensor<5xf32> into tensor<?xf32>
+func.func @out_of_bounds_insert_slice(%src: tensor<5xf32>, %dst: tensor<10xf32>, %sz: index) -> tensor<?xf32> {
+ %dst2 = tensor.cast %dst : tensor<10xf32> to tensor<?xf32>
+ %r = tensor.insert_slice %src into %dst2[7] [5] [1] : tensor<5xf32> into tensor<?xf32>
+ return %r : tensor<?xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @rank_reducing_insert_slice_of_cast
// CHECK-SAME: %[[A:.[a-z0-9A-Z_]+]]: tensor<16x32xi8>
// CHECK-SAME: %[[B:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>