[Mlir-commits] [mlir] [mlir][tensor] Fix insert and extract slice canonicalization (PR #72885)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 20 07:34:08 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: Rik Huijzer (rikhuijzer)

<details>
<summary>Changes</summary>

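Fixes #71150 by rejecting negative static offsets and sizes in the `InsertSliceOpSourceCastInserter`, `InsertSliceOpConstantArgumentFolder`, and `ExtractSliceOp` canonicalizations; such values can appear once constant folding replaces dynamic operands with constants. Also refactored the duplicated check (previously written out in `SubViewOp::inferResultType` and the `StaticTensorGenerate` pattern) into a single helper, `hasNegativeDimension`, so the explanatory comment lives in one place instead of being repeated at every call site.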

---
Full diff: https://github.com/llvm/llvm-project/pull/72885.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+6) 
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+4-11) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+9-6) 
- (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+6) 
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+24) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 23a366036b9dd6f..9e39d81e5c4f96a 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -128,6 +128,12 @@ std::pair<ArrayAttr, SmallVector<Value>>
 decomposeMixedValues(Builder &b,
                      const SmallVectorImpl<OpFoldResult> &mixedValues);
 
+/// Returns true if any of `values` is a static (non-dynamic) negative value.
+///
+/// Used to re-check offsets and sizes after constant folding, which may have
+/// replaced dynamic dimensions with invalid (negative) constants.
+bool hasNegativeDimension(ArrayRef<int64_t> values);
+
 /// Helper to sort `values` according to matching `keys`.
 SmallVector<Value>
 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a2fc954ad07fae8..dd75ed2500306b2 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2621,17 +2621,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
-  // If one of the offsets or sizes is invalid, fail the canonicalization.
-  // These checks also occur in the verifier, but they are needed here
-  // because some dynamic dimensions may have been constant folded.
-  for (int64_t offset : staticOffsets)
-    if (offset < 0 && !ShapedType::isDynamic(offset))
-      return {};
-  for (int64_t size : staticSizes)
-    if (size < 0 && !ShapedType::isDynamic(size))
-      return {};
-
+  if (hasNegativeDimension(staticOffsets))
+    return {};
+  if (hasNegativeDimension(staticSizes))
+    return {};
   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                     staticSizes, staticStrides);
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e469815496e1832..986e40a2e4eb34f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1259,13 +1259,8 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
     SmallVector<int64_t> newShape;
     operandsAndShape(resultType, dynamicExtents, newOperands, newShape);
 
-    for (int64_t newdim : newShape) {
-      // This check also occurs in the verifier, but we need it here too
-      // since intermediate passes may have replaced some dynamic dimensions
-      // by constants.
-      if (newdim < 0 && !ShapedType::isDynamic(newdim))
-        return failure();
-    }
+    if (hasNegativeDimension(newShape))
+      return failure();
 
     if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
       return failure();
@@ -1801,6 +1796,10 @@ RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+  if (hasNegativeDimension(staticOffsets))
+    return {};
+  if (hasNegativeDimension(staticSizes))
+    return {};
   return ExtractSliceOp::inferCanonicalRankReducedResultType(
       desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
       staticStrides);
@@ -2370,6 +2369,8 @@ class InsertSliceOpConstantArgumentFolder final
     auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
         insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
         mixedOffsets, mixedSizes, mixedStrides);
+    if (!sourceType)
+      return failure();
     Value toInsert = insertSliceOp.getSource();
     if (sourceType != insertSliceOp.getSourceType()) {
       OpBuilder::InsertionGuard g(rewriter);
@@ -2500,6 +2501,8 @@ struct InsertSliceOpSourceCastInserter final
               getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
         newSrcShape[i] = *constInt;
     }
+    if (hasNegativeDimension(newSrcShape))
+      return failure();
 
     RankedTensorType newSrcType =
         RankedTensorType::get(newSrcShape, srcType.getElementType());
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 8a4ccc990331a7f..5d777ad74e9e852 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -200,6 +200,12 @@ decomposeMixedValues(Builder &b,
   return {b.getI64ArrayAttr(staticValues), dynamicValues};
 }
 
+bool hasNegativeDimension(ArrayRef<int64_t> values) {
+  return llvm::any_of(values, [](int64_t value) {
+    return !ShapedType::isDynamic(value) && value < 0;
+  });
+}
+
 /// Helper to sort `values` according to matching `keys`.
 template <typename K, typename V>
 static SmallVector<V>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ea8c17640d7c143..1c0a2e868475f24 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1102,6 +1102,30 @@ func.func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
 
 // -----
 
+func.func @no_fold_extract_slice_negative_size(%arg0: tensor<8xf32>) -> tensor<?xf32> {
+  %c-1 = arith.constant -1 : index
+  %e = tensor.extract_slice %arg0[1] [%c-1] [1] : tensor<8xf32> to tensor<?xf32>
+  return %e : tensor<?xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_negative_size
+// CHECK: tensor.extract_slice
+
+// -----
+
+func.func @no_fold_insert_slice_cast_inserter_negative_size() -> tensor<?xf32> {
+  %c = arith.constant 0 : index
+  %const = tensor.empty(%c) : tensor<?xf32>
+  %insert_val = tensor.empty(%c) : tensor<?xf32>
+  %c-1 = arith.constant -1 : index
+  %inserted = tensor.insert_slice %insert_val into %const[0][%c-1][1] : tensor<?xf32> into tensor<?xf32>
+  return %inserted : tensor<?xf32>
+}
+// CHECK-LABEL: func @no_fold_insert_slice_cast_inserter_negative_size
+// CHECK: %[[CAST:.*]] = tensor.cast
+// CHECK: tensor.insert_slice %[[CAST:.+]]
+
+// -----
+
 func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> {
   %c0 = arith.constant dense<42> : tensor<2x8xi32>
   %0 = tensor.expand_shape %c0 [[0], [1, 2]]

``````````

</details>


https://github.com/llvm/llvm-project/pull/72885


More information about the Mlir-commits mailing list