[Mlir-commits] [mlir] [mlir][tensor] Fix insert and extract slice canonicalization (PR #72885)
Rik Huijzer
llvmlistbot at llvm.org
Mon Nov 20 07:33:38 PST 2023
https://github.com/rikhuijzer created https://github.com/llvm/llvm-project/pull/72885
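Fixes #71150 by rejecting negative static offsets and sizes in the `InsertSliceOpSourceCastInserter` and `ExtractSliceOp` canonicalizations. Such negative values can appear once dynamic operands have been constant folded, so the verifier alone cannot catch them. Also refactors the duplicated check into a single helper, `hasNegativeDimension`, so the explanatory comment no longer has to be repeated at every call site.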
>From 22928e7e5da508d8d9dc8d4b7e54f84cccadef06 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Mon, 20 Nov 2023 09:02:41 +0100
Subject: [PATCH 1/2] [mlir][tensor] Fix canonicalization via `hasNegativeDimension`
---
mlir/include/mlir/Dialect/Tensor/IR/Tensor.h | 6 ++++++
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 15 +++++++++++++++
mlir/test/Dialect/Tensor/canonicalize.mlir | 10 ++++++++++
3 files changed, 31 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 06642adda42b381..0d027057b3a9524 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -150,6 +150,12 @@ LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op,
/// Tests if types are the same when ignoring encoding on ranked tensors.
bool isSameTypeWithoutEncoding(Type tp1, Type tp2);
+/// Helper function to check whether the dimensions are non-negative. This
+/// check also occurs in the verifier, but we need it at later stages too
+/// because the verifier ignores dynamic dimensions, but later stages might
+/// have constant folded those to (negative) constants.
+bool hasNegativeDimension(SmallVector<int64_t> shape);
+
/// Function to control the folding of constant and extract slice.
using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e469815496e1832..3297ef673ca2e0e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -125,6 +125,12 @@ bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
return tp1 == tp2; // default implementation
}
+bool tensor::hasNegativeDimension(SmallVector<int64_t> shape) {
+ return llvm::any_of(shape, [](int64_t dim) {
+ return !ShapedType::isDynamic(dim) && dim < 0;
+ });
+}
+
/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
/// rank-extending tensor.insert_slice op.
static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
@@ -1801,6 +1807,10 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+ if (hasNegativeDimension(staticOffsets))
+ return {};
+ if (hasNegativeDimension(staticSizes))
+ return {};
return ExtractSliceOp::inferCanonicalRankReducedResultType(
desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
staticStrides);
@@ -2370,6 +2380,8 @@ class InsertSliceOpConstantArgumentFolder final
auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
mixedOffsets, mixedSizes, mixedStrides);
+ if (!sourceType)
+ return failure();
Value toInsert = insertSliceOp.getSource();
if (sourceType != insertSliceOp.getSourceType()) {
OpBuilder::InsertionGuard g(rewriter);
@@ -2500,6 +2512,8 @@ struct InsertSliceOpSourceCastInserter final
getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
newSrcShape[i] = *constInt;
}
+ // if (hasNegativeDimension(newSrcShape))
+ // return failure();
RankedTensorType newSrcType =
RankedTensorType::get(newSrcShape, srcType.getElementType());
@@ -2521,6 +2535,7 @@ struct InsertSliceOpSourceCastInserter final
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
Value cast = rewriter.create<tensor::CastOp>(
insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
+
rewriter.replaceOpWithNewOp<InsertOpTy>(
insertSliceOp, cast, insertSliceOp.getDest(),
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ea8c17640d7c143..88f27d3d36b0471 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1102,6 +1102,16 @@ func.func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
// -----
+func.func @no_fold_extract_slice_negative_size(%arg0: tensor<8xf32>) -> tensor<?xf32> {
+  %c-1 = arith.constant -1 : index
+  %e = tensor.extract_slice %arg0[1] [%c-1] [1] : tensor<8xf32> to tensor<?xf32>
+  return %e : tensor<?xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_negative_size
+// CHECK: tensor.extract_slice
+
+// -----
+
func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> {
%c0 = arith.constant dense<42> : tensor<2x8xi32>
%0 = tensor.expand_shape %c0 [[0], [1, 2]]
>From ecef5428c160cb72103e06a160c450440ce1f416 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Mon, 20 Nov 2023 16:27:53 +0100
Subject: [PATCH 2/2] Fix `insert_slice` cast inserter and refactor
---
mlir/include/mlir/Dialect/Tensor/IR/Tensor.h | 6 ------
.../mlir/Dialect/Utils/StaticValueUtils.h | 6 ++++++
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 15 ++++-----------
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 18 +++---------------
mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 6 ++++++
mlir/test/Dialect/Tensor/canonicalize.mlir | 14 ++++++++++++++
6 files changed, 33 insertions(+), 32 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 0d027057b3a9524..06642adda42b381 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -150,12 +150,6 @@ LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op,
/// Tests if types are the same when ignoring encoding on ranked tensors.
bool isSameTypeWithoutEncoding(Type tp1, Type tp2);
-/// Helper function to check whether the dimensions are non-negative. This
-/// check also occurs in the verifier, but we need it at later stages too
-/// because the verifier ignores dynamic dimensions, but later stages might
-/// have constant folded those to (negative) constants.
-bool hasNegativeDimension(SmallVector<int64_t> shape);
-
/// Function to control the folding of constant and extract slice.
using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 23a366036b9dd6f..9e39d81e5c4f96a 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -128,6 +128,12 @@ std::pair<ArrayAttr, SmallVector<Value>>
decomposeMixedValues(Builder &b,
const SmallVectorImpl<OpFoldResult> &mixedValues);
+/// Returns true if any non-dynamic entry of `values` is negative.
+///
+/// Used to re-check offsets, sizes and shapes after dynamic values have been
+/// constant folded, since the verifier only rejects negative static values.
+bool hasNegativeDimension(SmallVector<int64_t> values);
+
/// Helper to sort `values` according to matching `keys`.
SmallVector<Value>
getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a2fc954ad07fae8..dd75ed2500306b2 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2621,17 +2621,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
- // If one of the offsets or sizes is invalid, fail the canonicalization.
- // These checks also occur in the verifier, but they are needed here
- // because some dynamic dimensions may have been constant folded.
- for (int64_t offset : staticOffsets)
- if (offset < 0 && !ShapedType::isDynamic(offset))
- return {};
- for (int64_t size : staticSizes)
- if (size < 0 && !ShapedType::isDynamic(size))
- return {};
-
+ if (hasNegativeDimension(staticOffsets))
+ return {};
+ if (hasNegativeDimension(staticSizes))
+ return {};
return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
staticSizes, staticStrides);
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 3297ef673ca2e0e..986e40a2e4eb34f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -125,12 +125,6 @@ bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
return tp1 == tp2; // default implementation
}
-bool tensor::hasNegativeDimension(SmallVector<int64_t> shape) {
- return llvm::any_of(shape, [](int64_t dim) {
- return !ShapedType::isDynamic(dim) && dim < 0;
- });
-}
-
/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
/// rank-extending tensor.insert_slice op.
static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
@@ -1265,13 +1259,8 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
SmallVector<int64_t> newShape;
operandsAndShape(resultType, dynamicExtents, newOperands, newShape);
- for (int64_t newdim : newShape) {
- // This check also occurs in the verifier, but we need it here too
- // since intermediate passes may have replaced some dynamic dimensions
- // by constants.
- if (newdim < 0 && !ShapedType::isDynamic(newdim))
+ if (hasNegativeDimension(newShape))
return failure();
- }
if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
return failure();
@@ -2512,8 +2501,8 @@ struct InsertSliceOpSourceCastInserter final
getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
newSrcShape[i] = *constInt;
}
- // if (hasNegativeDimension(newSrcShape))
- // return failure();
+ if (hasNegativeDimension(newSrcShape))
+ return failure();
RankedTensorType newSrcType =
RankedTensorType::get(newSrcShape, srcType.getElementType());
@@ -2535,7 +2524,6 @@ struct InsertSliceOpSourceCastInserter final
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
Value cast = rewriter.create<tensor::CastOp>(
insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
-
rewriter.replaceOpWithNewOp<InsertOpTy>(
insertSliceOp, cast, insertSliceOp.getDest(),
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 8a4ccc990331a7f..5d777ad74e9e852 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -200,6 +200,12 @@ decomposeMixedValues(Builder &b,
return {b.getI64ArrayAttr(staticValues), dynamicValues};
}
+bool hasNegativeDimension(SmallVector<int64_t> values) {
+  return llvm::any_of(values, [](int64_t value) {
+    return !ShapedType::isDynamic(value) && value < 0; // kDynamic is < 0 too
+  });
+}
+
/// Helper to sort `values` according to matching `keys`.
template <typename K, typename V>
static SmallVector<V>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 88f27d3d36b0471..1c0a2e868475f24 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1112,6 +1112,20 @@ func.func @no_fold_extract_slice_negative_offset(%arg0: tensor<8xf32>) -> tensor
// -----
+func.func @no_fold_insert_slice_cast_inserter_negative_size() -> tensor<?xf32> {
+  %c = arith.constant 0 : index
+  %const = tensor.empty(%c) : tensor<?xf32>
+  %insert_val = tensor.empty(%c) : tensor<?xf32>
+  %c-1 = arith.constant -1 : index
+  %inserted = tensor.insert_slice %insert_val into %const[0][%c-1][1] : tensor<?xf32> into tensor<?xf32>
+  return %inserted : tensor<?xf32>
+}
+// CHECK-LABEL: func @no_fold_insert_slice_cast_inserter_negative_size
+// CHECK: %[[CAST:.*]] = tensor.cast
+// CHECK: tensor.insert_slice %{{.+}}
+
+// -----
+
func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> {
%c0 = arith.constant dense<42> : tensor<2x8xi32>
%0 = tensor.expand_shape %c0 [[0], [1, 2]]
More information about the Mlir-commits
mailing list