[flang] [mlir] [compiler-rt] [clang-tools-extra] [clang] [libcxx] [llvm] [mlir] Fix a zero stride canonicalizer crash (PR #74200)

Rik Huijzer via cfe-commits cfe-commits at lists.llvm.org
Tue Dec 5 01:37:38 PST 2023


https://github.com/rikhuijzer updated https://github.com/llvm/llvm-project/pull/74200

>From 22928e7e5da508d8d9dc8d4b7e54f84cccadef06 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Mon, 20 Nov 2023 09:02:41 +0100
Subject: [PATCH 1/6] [mlir][tensor] Fix canonicalization via `hasNegativeDimension`

---
 mlir/include/mlir/Dialect/Tensor/IR/Tensor.h |  6 ++++++
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp     | 15 +++++++++++++++
 mlir/test/Dialect/Tensor/canonicalize.mlir   | 10 ++++++++++
 3 files changed, 31 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 06642adda42b3..0d027057b3a95 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -150,6 +150,12 @@ LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op,
 /// Tests if types are the same when ignoring encoding on ranked tensors.
 bool isSameTypeWithoutEncoding(Type tp1, Type tp2);
 
+/// Helper function to check whether the dimensions are non-negative. This
+/// check also occurs in the verifier, but we need it at later stages too
+/// because the verifier ignores dynamic dimensions, but later stages might
+/// have constant folded those to (negative) constants.
+bool hasNegativeDimension(SmallVector<int64_t> shape);
+
 /// Function to control the folding of constant and extract slice.
 using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e469815496e18..3297ef673ca2e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -125,6 +125,12 @@ bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
   return tp1 == tp2; // default implementation
 }
 
+bool tensor::hasNegativeDimension(SmallVector<int64_t> shape) {
+  return llvm::any_of(shape, [](int64_t dim) {
+    return !ShapedType::isDynamic(dim) && dim < 0;
+  });
+}
+
 /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
 /// rank-extending tensor.insert_slice op.
 static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
@@ -1801,6 +1807,10 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+  if (hasNegativeDimension(staticOffsets))
+    return {};
+  if (hasNegativeDimension(staticSizes))
+    return {};
   return ExtractSliceOp::inferCanonicalRankReducedResultType(
       desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
       staticStrides);
@@ -2370,6 +2380,8 @@ class InsertSliceOpConstantArgumentFolder final
     auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
         insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
         mixedOffsets, mixedSizes, mixedStrides);
+    if (!sourceType)
+      return failure();
     Value toInsert = insertSliceOp.getSource();
     if (sourceType != insertSliceOp.getSourceType()) {
       OpBuilder::InsertionGuard g(rewriter);
@@ -2500,6 +2512,8 @@ struct InsertSliceOpSourceCastInserter final
               getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
         newSrcShape[i] = *constInt;
     }
+    // if (hasNegativeDimension(newSrcShape))
+    //  return failure();
 
     RankedTensorType newSrcType =
         RankedTensorType::get(newSrcShape, srcType.getElementType());
@@ -2521,6 +2535,7 @@ struct InsertSliceOpSourceCastInserter final
       rewriter.setInsertionPoint(insertSliceOp->getParentOp());
     Value cast = rewriter.create<tensor::CastOp>(
         insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
+
     rewriter.replaceOpWithNewOp<InsertOpTy>(
         insertSliceOp, cast, insertSliceOp.getDest(),
         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ea8c17640d7c1..88f27d3d36b04 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1102,6 +1102,16 @@ func.func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
 
 // -----
 
+func.func @no_fold_extract_slice_negative_offset(%arg0: tensor<8xf32>) -> tensor<?xf32> {
+  %c-1 = arith.constant -1 : index
+  %e = tensor.extract_slice %arg0[1] [%c-1] [1] : tensor<8xf32> to tensor<?xf32>
+  return %e : tensor<?xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_negative_offset
+// CHECK: tensor.extract_slice
+
+// -----
+
 func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> {
   %c0 = arith.constant dense<42> : tensor<2x8xi32>
   %0 = tensor.expand_shape %c0 [[0], [1, 2]]

>From ecef5428c160cb72103e06a160c450440ce1f416 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Mon, 20 Nov 2023 16:27:53 +0100
Subject: [PATCH 2/6] Fix `insert_slice` cast inserter and refactor

---
 mlir/include/mlir/Dialect/Tensor/IR/Tensor.h   |  6 ------
 .../mlir/Dialect/Utils/StaticValueUtils.h      |  6 ++++++
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp       | 15 ++++-----------
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp       | 18 +++---------------
 mlir/lib/Dialect/Utils/StaticValueUtils.cpp    |  6 ++++++
 mlir/test/Dialect/Tensor/canonicalize.mlir     | 14 ++++++++++++++
 6 files changed, 33 insertions(+), 32 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 0d027057b3a95..06642adda42b3 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -150,12 +150,6 @@ LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op,
 /// Tests if types are the same when ignoring encoding on ranked tensors.
 bool isSameTypeWithoutEncoding(Type tp1, Type tp2);
 
-/// Helper function to check whether the dimensions are non-negative. This
-/// check also occurs in the verifier, but we need it at later stages too
-/// because the verifier ignores dynamic dimensions, but later stages might
-/// have constant folded those to (negative) constants.
-bool hasNegativeDimension(SmallVector<int64_t> shape);
-
 /// Function to control the folding of constant and extract slice.
 using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
 
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 23a366036b9dd..9e39d81e5c4f9 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -128,6 +128,12 @@ std::pair<ArrayAttr, SmallVector<Value>>
 decomposeMixedValues(Builder &b,
                      const SmallVectorImpl<OpFoldResult> &mixedValues);
 
+/// Helper function to check whether the dimensions are non-negative.
+///
+/// This is used to re-check whether dimensions are still non-negative after
+/// constant folding the dynamic dimensions.
+bool hasNegativeDimension(SmallVector<int64_t> values);
+
 /// Helper to sort `values` according to matching `keys`.
 SmallVector<Value>
 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a2fc954ad07fa..dd75ed2500306 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2621,17 +2621,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
-  // If one of the offsets or sizes is invalid, fail the canonicalization.
-  // These checks also occur in the verifier, but they are needed here
-  // because some dynamic dimensions may have been constant folded.
-  for (int64_t offset : staticOffsets)
-    if (offset < 0 && !ShapedType::isDynamic(offset))
-      return {};
-  for (int64_t size : staticSizes)
-    if (size < 0 && !ShapedType::isDynamic(size))
-      return {};
-
+  if (hasNegativeDimension(staticOffsets))
+    return {};
+  if (hasNegativeDimension(staticSizes))
+    return {};
   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                     staticSizes, staticStrides);
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 3297ef673ca2e..986e40a2e4eb3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -125,12 +125,6 @@ bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
   return tp1 == tp2; // default implementation
 }
 
-bool tensor::hasNegativeDimension(SmallVector<int64_t> shape) {
-  return llvm::any_of(shape, [](int64_t dim) {
-    return !ShapedType::isDynamic(dim) && dim < 0;
-  });
-}
-
 /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
 /// rank-extending tensor.insert_slice op.
 static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
@@ -1265,13 +1259,8 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
     SmallVector<int64_t> newShape;
     operandsAndShape(resultType, dynamicExtents, newOperands, newShape);
 
-    for (int64_t newdim : newShape) {
-      // This check also occurs in the verifier, but we need it here too
-      // since intermediate passes may have replaced some dynamic dimensions
-      // by constants.
-      if (newdim < 0 && !ShapedType::isDynamic(newdim))
+    if (hasNegativeDimension(newShape))
         return failure();
-    }
 
     if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
       return failure();
@@ -2512,8 +2501,8 @@ struct InsertSliceOpSourceCastInserter final
               getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
         newSrcShape[i] = *constInt;
     }
-    // if (hasNegativeDimension(newSrcShape))
-    //  return failure();
+    if (hasNegativeDimension(newSrcShape))
+      return failure();
 
     RankedTensorType newSrcType =
         RankedTensorType::get(newSrcShape, srcType.getElementType());
@@ -2535,7 +2524,6 @@ struct InsertSliceOpSourceCastInserter final
       rewriter.setInsertionPoint(insertSliceOp->getParentOp());
     Value cast = rewriter.create<tensor::CastOp>(
         insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
-
     rewriter.replaceOpWithNewOp<InsertOpTy>(
         insertSliceOp, cast, insertSliceOp.getDest(),
         insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 8a4ccc990331a..5d777ad74e9e8 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -200,6 +200,12 @@ decomposeMixedValues(Builder &b,
   return {b.getI64ArrayAttr(staticValues), dynamicValues};
 }
 
+bool hasNegativeDimension(SmallVector<int64_t> values) {
+  return llvm::any_of(values, [](int64_t value) {
+    return !ShapedType::isDynamic(value) && value < 0;
+  });
+}
+
 /// Helper to sort `values` according to matching `keys`.
 template <typename K, typename V>
 static SmallVector<V>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 88f27d3d36b04..1c0a2e868475f 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1112,6 +1112,20 @@ func.func @no_fold_extract_slice_negative_offset(%arg0: tensor<8xf32>) -> tensor
 
 // -----
 
+func.func @no_fold_insert_slice_cast_inserter_negative_offset() -> tensor<?xf32> {
+  %c = arith.constant 0 : index
+  %const = tensor.empty(%c) : tensor<?xf32>
+  %insert_val = tensor.empty(%c) : tensor<?xf32>
+  %c-1 = arith.constant -1 : index
+  %inserted = tensor.insert_slice %insert_val into %const[0][%c-1][1] : tensor<?xf32> into tensor<?xf32>
+  return %inserted : tensor<?xf32>
+}
+// CHECK-LABEL: func @no_fold_insert_slice_cast_inserter_negative_offset
+// CHECK: %[[CAST:.*]] = tensor.cast
+// CHECK: tensor.insert_slice %[[CAST:.+]]
+
+// -----
+
 func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> {
   %c0 = arith.constant dense<42> : tensor<2x8xi32>
   %0 = tensor.expand_shape %c0 [[0], [1, 2]]

>From 69637ad2b8915f352c6dae3cab838a04b84c3e10 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Mon, 20 Nov 2023 16:40:09 +0100
Subject: [PATCH 3/6] Apply `clang-format`

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 986e40a2e4eb3..04a8e43a639f4 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1260,7 +1260,7 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
     operandsAndShape(resultType, dynamicExtents, newOperands, newShape);
 
     if (hasNegativeDimension(newShape))
-        return failure();
+      return failure();
 
     if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
       return failure();

>From ecd074dc485485ebf6b7ae7aa5ee52cb397994ca Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Sat, 2 Dec 2023 18:02:31 +0100
Subject: [PATCH 4/6] Refactor

---
 .../mlir/Dialect/Utils/StaticValueUtils.h     | 36 ++++++++++++++-----
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  7 ++--
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 16 +++------
 mlir/lib/Dialect/Utils/StaticValueUtils.cpp   | 33 +++++++++++++----
 mlir/test/Dialect/MemRef/canonicalize.mlir    | 12 +++++++
 5 files changed, 75 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 768f0ac1abe56..a1853438ccf7f 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -128,12 +128,6 @@ std::pair<ArrayAttr, SmallVector<Value>>
 decomposeMixedValues(Builder &b,
                      const SmallVectorImpl<OpFoldResult> &mixedValues);
 
-/// Helper function to check whether the dimensions are non-negative.
-///
-/// This is used to re-check whether dimensions are still non-negative after
-/// constant folding the dynamic dimensions.
-bool hasNegativeDimension(SmallVector<int64_t> values);
-
 /// Helper to sort `values` according to matching `keys`.
 SmallVector<Value>
 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
@@ -145,12 +139,36 @@ SmallVector<int64_t>
 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
                      llvm::function_ref<bool(Attribute, Attribute)> compare);
 
+/// Helper function to check whether the passed in `sizes` or `values` are
+/// valid. This can be used to re-check whether dimensions are still valid
+/// after constant folding the dynamic dimensions.
+bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets);
+
+/// Helper function to check whether the passed in `strides` are valid, i.e.
+/// that none of the static (non-dynamic) strides is zero. Use this to re-check
+/// strides after dynamic values may have been constant folded.
+bool hasValidStrides(SmallVector<int64_t> strides);
+
 /// Returns "success" when any of the elements in `ofrs` is a constant value. In
 /// that case the value is replaced by an attribute. Returns "failure" when no
-/// folding happened. If `onlyNonNegative` is set, only non-negative constant
-/// values are folded.
+/// folding happened. If `onlyNonNegative` is set, negative constants are not
+/// folded; if `onlyNonZero` is set, zero constants are not folded.
 LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
-                                   bool onlyNonNegative = false);
+                                   bool onlyNonNegative = false,
+                                   bool onlyNonZero = false);
+
+/// Returns "success" when any of the elements in `OffsetsOrSizes` is a
+/// constant value. In that case the value is replaced by an attribute. Returns
+/// "failure" when no folding happened. Invalid values are not folded to avoid
+/// canonicalization crashes.
+LogicalResult
+foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes);
+
+/// Replaces each dynamic element of `strides` whose value is a known non-zero
+/// constant with an attribute. Returns "success" if at least one element was
+/// folded and "failure" otherwise. Constant zero strides are invalid and are
+/// left unfolded to avoid canonicalization crashes.
+LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);
 
 /// Return the number of iterations for a loop with a lower bound `lb`, upper
 /// bound `ub` and step `step`.
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f222011a2edf5..c6d947a2427db 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include <cstdint>
 
 using namespace mlir;
 using namespace mlir::memref;
@@ -2581,9 +2582,11 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-  if (hasNegativeDimension(staticOffsets))
+  if (!hasValidSizesOffsets(staticOffsets))
     return {};
-  if (hasNegativeDimension(staticSizes))
+  if (!hasValidSizesOffsets(staticSizes))
+    return {};
+  if (!hasValidStrides(staticStrides))
     return {};
   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                     staticSizes, staticStrides);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index eab1d261b1064..94b7b734f88fe 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1446,7 +1446,7 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
     SmallVector<int64_t> newShape;
     operandsAndShape(resultType, dynamicExtents, newOperands, newShape);
 
-    if (hasNegativeDimension(newShape))
+    if (!hasValidSizesOffsets(newShape))
       return failure();
 
     if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
@@ -1983,10 +1983,6 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-  if (hasNegativeDimension(staticOffsets))
-    return {};
-  if (hasNegativeDimension(staticSizes))
-    return {};
   return ExtractSliceOp::inferCanonicalRankReducedResultType(
       desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
       staticStrides);
@@ -2547,17 +2543,15 @@ class InsertSliceOpConstantArgumentFolder final
     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
 
     // No constant operands were folded, just return;
-    if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
-        failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
-        failed(foldDynamicIndexList(mixedStrides)))
+    if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
+        failed(foldDynamicOffsetSizeList(mixedSizes)) &&
+        failed(foldDynamicStrideList(mixedStrides)))
       return failure();
 
     // Create the new op in canonical form.
     auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
         insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
         mixedOffsets, mixedSizes, mixedStrides);
-    if (!sourceType)
-      return failure();
     Value toInsert = insertSliceOp.getSource();
     if (sourceType != insertSliceOp.getSourceType()) {
       OpBuilder::InsertionGuard g(rewriter);
@@ -2692,7 +2686,7 @@ struct InsertSliceOpSourceCastInserter final
         newSrcShape[i] = *constInt;
       }
     }
-    if (hasNegativeDimension(newSrcShape))
+    if (!hasValidSizesOffsets(newSrcShape))
       return failure();
 
     RankedTensorType newSrcType =
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 4f606e17a4d59..0c8a88da789e2 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -200,12 +200,6 @@ decomposeMixedValues(Builder &b,
   return {b.getI64ArrayAttr(staticValues), dynamicValues};
 }
 
-bool hasNegativeDimension(SmallVector<int64_t> values) {
-  return llvm::any_of(values, [](int64_t value) {
-    return !ShapedType::isDynamic(value) && value < 0;
-  });
-}
-
 /// Helper to sort `values` according to matching `keys`.
 template <typename K, typename V>
 static SmallVector<V>
@@ -262,8 +256,20 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
   return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
 }
 
+bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) {
+  return llvm::none_of(sizesOrOffsets, [](int64_t value) {
+    return !ShapedType::isDynamic(value) && value < 0;
+  });
+}
+
+bool hasValidStrides(SmallVector<int64_t> strides) {
+  return llvm::none_of(strides, [](int64_t value) {
+    return !ShapedType::isDynamic(value) && value == 0;
+  });
+}
+
 LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
-                                   bool onlyNonNegative) {
+                                   bool onlyNonNegative, bool onlyNonZero) {
   bool valuesChanged = false;
   for (OpFoldResult &ofr : ofrs) {
     if (ofr.is<Attribute>())
@@ -273,6 +279,8 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
       // Note: All ofrs have index type.
       if (onlyNonNegative && *getConstantIntValue(attr) < 0)
         continue;
+      if (onlyNonZero && *getConstantIntValue(attr) == 0)
+        continue;
       ofr = attr;
       valuesChanged = true;
     }
@@ -280,4 +288,15 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
   return success(valuesChanged);
 }
 
+LogicalResult
+foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes) {
+  return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true,
+                              /*onlyNonZero=*/false);
+}
+
+LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides) {
+  return foldDynamicIndexList(strides, /*onlyNonNegative=*/false,
+                              /*onlyNonZero=*/true);
+}
+
 } // namespace mlir
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a1f8673638ff8..d3406c630f6dd 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -191,6 +191,18 @@ func.func @no_fold_subview_negative_size(%input: memref<4x1024xf32>) -> memref<?
 
 // -----
 
+// CHECK-LABEL: func @no_fold_subview_zero_stride
+//  CHECK:        %[[SUBVIEW:.+]] = memref.subview
+//  CHECK:        return %[[SUBVIEW]]
+func.func @no_fold_subview_zero_stride(%arg0 : memref<10xf32>) -> memref<1xf32, strided<[?], offset: 1>> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %1 = memref.subview %arg0[1] [1] [%c0] : memref<10xf32> to memref<1xf32, strided<[?], offset: 1>>
+  return %1 : memref<1xf32, strided<[?], offset: 1>>
+}
+
+// -----
+
 // CHECK-LABEL: func @no_fold_of_store
 //  CHECK:   %[[cst:.+]] = memref.cast %arg
 //  CHECK:   memref.store %[[cst]]

>From 9a577af49dfc360587a4e45195a6a26b75eab083 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Sat, 2 Dec 2023 18:06:05 +0100
Subject: [PATCH 5/6] Cleanup

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp   |  1 -
 mlir/test/Dialect/Tensor/canonicalize.mlir | 24 ----------------------
 2 files changed, 25 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index c6d947a2427db..b2d52e400e52d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -22,7 +22,6 @@
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
-#include <cstdint>
 
 using namespace mlir;
 using namespace mlir::memref;
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 77978e0896a28..84c44a09aa3dd 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1179,30 +1179,6 @@ func.func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
 
 // -----
 
-func.func @no_fold_extract_slice_negative_offset(%arg0: tensor<8xf32>) -> tensor<?xf32> {
-  %c-1 = arith.constant -1 : index
-  %e = tensor.extract_slice %arg0[1] [%c-1] [1] : tensor<8xf32> to tensor<?xf32>
-  return %e : tensor<?xf32>
-}
-// CHECK-LABEL: func @no_fold_extract_slice_negative_offset
-// CHECK: tensor.extract_slice
-
-// -----
-
-func.func @no_fold_insert_slice_cast_inserter_negative_offset() -> tensor<?xf32> {
-  %c = arith.constant 0 : index
-  %const = tensor.empty(%c) : tensor<?xf32>
-  %insert_val = tensor.empty(%c) : tensor<?xf32>
-  %c-1 = arith.constant -1 : index
-  %inserted = tensor.insert_slice %insert_val into %const[0][%c-1][1] : tensor<?xf32> into tensor<?xf32>
-  return %inserted : tensor<?xf32>
-}
-// CHECK-LABEL: func @no_fold_insert_slice_cast_inserter_negative_offset
-// CHECK: %[[CAST:.*]] = tensor.cast
-// CHECK: tensor.insert_slice %[[CAST:.+]]
-
-// -----
-
 func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> {
   %c0 = arith.constant dense<42> : tensor<2x8xi32>
   %0 = tensor.expand_shape %c0 [[0], [1, 2]]

>From efcb3778a6c1b43ce7efdf97973e54e52f638a32 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Mon, 4 Dec 2023 08:18:03 +0100
Subject: [PATCH 6/6] Apply suggestions from code review

Co-authored-by: Kai Sasaki <lewuathe at gmail.com>
---
 mlir/include/mlir/Dialect/Utils/StaticValueUtils.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index a1853438ccf7f..1dc0398494dcc 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -139,7 +139,7 @@ SmallVector<int64_t>
 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
                      llvm::function_ref<bool(Attribute, Attribute)> compare);
 
-/// Helper function to check whether the passed in `sizes` or `values` are
+/// Helper function to check whether the passed in `sizes` or `offsets` are
 /// valid. This can be used to re-check whether dimensions are still valid
 /// after constant folding the dynamic dimensions.
 bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets);
@@ -157,7 +157,7 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
                                    bool onlyNonNegative = false,
                                    bool onlyNonZero = false);
 
-/// Returns "success" when any of the elements in `OffsetsOrSizes` is a
+/// Returns "success" when any of the elements in `offsetsOrSizes` is a
 /// constant value. In that case the value is replaced by an attribute. Returns
 /// "failure" when no folding happened. Invalid values are not folded to avoid
 /// canonicalization crashes.



More information about the cfe-commits mailing list