[Mlir-commits] [mlir] 1ee0d60 - [mlir][tensor] Remove incorrect parallel_insert_slice folder

Thomas Raoux llvmlistbot at llvm.org
Fri Aug 26 08:28:21 PDT 2022


Author: Thomas Raoux
Date: 2022-08-26T15:27:54Z
New Revision: 1ee0d60a9be5dcbe3234b81a1c93e6a206a88154

URL: https://github.com/llvm/llvm-project/commit/1ee0d60a9be5dcbe3234b81a1c93e6a206a88154
DIFF: https://github.com/llvm/llvm-project/commit/1ee0d60a9be5dcbe3234b81a1c93e6a206a88154.diff

LOG: [mlir][tensor] Remove incorrect parallel_insert_slice folder

parallel_insert_slice doesn't return a value, therefore we shouldn't try
to fold its result. The insert_slice folders don't apply to this op.
The current folding would prevent the pattern rewriter from
converging.

Differential Revision: https://reviews.llvm.org/D132668

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 4095d4e036a5a..9a0bbf690cf02 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1207,7 +1207,6 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
   ];
 
   let hasCanonicalizer = 1;
-  let hasFolder = 1;
   let hasVerifier = 1;
 }
 

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 060e2fd9207aa..cd4c4b9f03a59 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1552,7 +1552,6 @@ LogicalResult InsertSliceOp::verify() {
 
 /// If we have two consecutive InsertSliceOp writing to the same slice, we
 /// can mutate the second InsertSliceOp's destination to the first one's.
-/// This works similarly when the second op is a ParallelInsertSliceOp.
 ///
 /// Example:
 ///
@@ -1568,9 +1567,8 @@ LogicalResult InsertSliceOp::verify() {
 /// ```
 ///
 /// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
-template <typename InsertOpTy>
-static LogicalResult foldInsertAfterInsertSlice(InsertOpTy insertOp) {
-  auto prevInsertOp = insertOp.getDest().template getDefiningOp<InsertOpTy>();
+static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
+  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
 
   auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
   if (!prevInsertOp ||
@@ -1582,32 +1580,14 @@ static LogicalResult foldInsertAfterInsertSlice(InsertOpTy insertOp) {
   return success();
 }
 
-/// Same logic for folding InsertSliceOp and ParallelInsertSliceOp, the return
-/// type varies though so we wrap it in a FailureOr.
-///
-/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
-template <typename InsertOpTy>
-FailureOr<OpFoldResult> foldInsertOp(InsertOpTy insertOp, ArrayRef<Attribute>) {
-  if (insertOp.getSourceType().hasStaticShape() &&
-      insertOp.getDestType().hasStaticShape() &&
-      insertOp.getSourceType() == insertOp.getDestType() &&
-      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(
-          insertOp, insertOp.getDestType())))
-    return static_cast<OpFoldResult>(insertOp.getSource());
-  if (succeeded(foldInsertAfterInsertSlice(insertOp))) {
-    // InsertSliceOp has 1 result but ParallelInsertSliceOp has none and should
-    // return OpFoldResult().
-    if (std::is_same<InsertOpTy, InsertSliceOp>::value)
-      return static_cast<OpFoldResult>(insertOp->getResult(0));
-    else
-      return OpFoldResult();
-  }
-  return failure();
-}
-
-OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute> operands) {
-  auto maybeOpFoldResult = foldInsertOp(*this, operands);
-  return failed(maybeOpFoldResult) ? OpFoldResult() : *maybeOpFoldResult;
+OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
+  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
+      getSourceType() == getType() &&
+      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
+    return this->getSource();
+  if (succeeded(foldInsertAfterInsertSlice(*this)))
+    return getResult();
+  return OpFoldResult();
 }
 
 LogicalResult InsertSliceOp::reifyResultShapes(
@@ -2319,58 +2299,6 @@ LogicalResult ParallelInsertSliceOp::verify() {
   return produceSliceErrorMsg(result, *this, expectedType);
 }
 
-namespace {
-/// Pattern to rewrite a parallel_insert_slice op with constant arguments.
-class ParallelInsertSliceOpConstantArgumentFolder final
-    : public OpRewritePattern<ParallelInsertSliceOp> {
-public:
-  using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
-                                PatternRewriter &rewriter) const override {
-    // No constant operand, just return.
-    if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
-          return matchPattern(operand, matchConstantIndex());
-        }))
-      return failure();
-
-    // At least one of offsets/sizes/strides is a new constant.
-    // Form the new list of operands and constant attributes from the
-    // existing.
-    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
-    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
-    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
-    canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
-    canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
-    canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
-
-    // Create the new op in canonical form.
-    auto sourceType =
-        tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
-            insertSliceOp.getSourceType().getRank(),
-            insertSliceOp.getDestType(), mixedOffsets, mixedSizes,
-            mixedStrides);
-    Value toInsert = insertSliceOp.getSource();
-    if (sourceType != insertSliceOp.getSourceType()) {
-      OpBuilder::InsertionGuard g(rewriter);
-      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
-      toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
-                                                 sourceType, toInsert);
-    }
-    rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
-        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
-        mixedSizes, mixedStrides);
-    return success();
-  }
-};
-} // namespace
-
-LogicalResult
-ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
-                            SmallVectorImpl<OpFoldResult> &results) {
-  return foldInsertOp(*this, operands);
-}
-
 void ParallelInsertSliceOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 1eb1a5d7beca7..ad50ecb40db2f 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1466,3 +1466,24 @@ func.func @canonicalize_parallel_insert_slice_indices(
   }
   return %2 : tensor<?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @dont_fold_parallel_insert_slice(
+//  CHECK-SAME:     %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>, 
+//  CHECK-SAME:     %[[arg1:[0-9a-z]*]]: tensor<1x5xf32>)
+func.func @dont_fold_parallel_insert_slice(
+    %arg0 : tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32>
+{
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  //      CHECK: scf.foreach_thread () in () -> (tensor<1x5xf32>) {
+  // CHECK-NEXT:   scf.foreach_thread.perform_concurrently {
+  // CHECK-NEXT:     tensor.parallel_insert_slice %[[arg0]] into %[[arg1]][0, 0] [1, 5] [1, 1] : tensor<1x5xf32> into tensor<1x5xf32>
+  %2 = scf.foreach_thread () in ()  -> (tensor<1x5xf32>) {
+    scf.foreach_thread.perform_concurrently {
+      tensor.parallel_insert_slice %arg0 into %arg1[%c0, %c0] [1, 5] [%c1, %c1] : tensor<1x5xf32> into tensor<1x5xf32>
+    }
+  }
+  return %2 : tensor<1x5xf32>
+}


        


More information about the Mlir-commits mailing list