[Mlir-commits] [mlir] 33468a5 - [mlir][Tensor] Add support for insert_slice in FoldTensorSubsetOps

Nicolas Vasilache llvmlistbot at llvm.org
Fri Apr 14 09:34:39 PDT 2023


Author: Nicolas Vasilache
Date: 2023-04-14T09:34:11-07:00
New Revision: 33468a51dbf4c046ee6de71e3e8746bb69d15e33

URL: https://github.com/llvm/llvm-project/commit/33468a51dbf4c046ee6de71e3e8746bb69d15e33
DIFF: https://github.com/llvm/llvm-project/commit/33468a51dbf4c046ee6de71e3e8746bb69d15e33.diff

LOG: [mlir][Tensor] Add support for insert_slice in FoldTensorSubsetOps

Differential Revision: https://reviews.llvm.org/D148334

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
    mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
    mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h b/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h
index 42156ac5de24d..7a55fe97c064e 100644
--- a/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_AFFINE_VIEWLIKEINTERFACEUTILS_H
 #define MLIR_DIALECT_AFFINE_VIEWLIKEINTERFACEUTILS_H
 
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 
@@ -22,7 +23,8 @@ class RewriterBase;
 /// - Combined offsets = producer_offsets * consumer_strides + consumer_offsets
 /// - Combined sizes = consumer_sizes
 /// - Combined strides = producer_strides * consumer_strides
-// TODO: unify this API with resolveSourceIndicesOffsetsAndStrides or deprecate.
+// TODO: unify this API with resolveIndicesIntoOpWithOffsetsAndStrides or
+// deprecate.
 LogicalResult
 mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
                             ArrayRef<OpFoldResult> producerOffsets,
@@ -38,7 +40,8 @@ mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
 
 /// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use
 /// when combining a `producer` slice op **into** a `consumer` slice op.
-// TODO: unify this API with resolveSourceIndicesOffsetsAndStrides or deprecate.
+// TODO: unify this API with resolveIndicesIntoOpWithOffsetsAndStrides or
+// deprecate.
 LogicalResult
 mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
                             OffsetSizeAndStrideOpInterface producer,
@@ -48,8 +51,8 @@ mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
                             SmallVector<OpFoldResult> &combinedSizes,
                             SmallVector<OpFoldResult> &combinedStrides);
 
-/// Given the 'indicesVals' of a load/store operation operating on an op with
-/// offsets and strides, return the combined indices.
+/// Given the 'consumerIndices' of a load/store operation operating on an op
+/// with offsets and strides, return the combined indices.
 ///
 /// For example, using `memref.load` and `memref.subview` as an illustration:
 ///
@@ -64,13 +67,40 @@ mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
 ///
 /// ```
 ///    %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
 ///         memref<12x42xf32>
 /// ```
-void resolveSourceIndicesOffsetsAndStrides(
-    RewriterBase &rewriter, Location loc, ArrayRef<OpFoldResult> mixedOffsets,
-    ArrayRef<OpFoldResult> mixedStrides,
-    const llvm::SmallBitVector &rankReducedDims, ValueRange indicesVals,
-    SmallVectorImpl<Value> &sourceIndices);
+void resolveIndicesIntoOpWithOffsetsAndStrides(
+    RewriterBase &rewriter, Location loc,
+    ArrayRef<OpFoldResult> mixedSourceOffsets,
+    ArrayRef<OpFoldResult> mixedSourceStrides,
+    const llvm::SmallBitVector &rankReducedDims,
+    ArrayRef<OpFoldResult> consumerIndices,
+    SmallVectorImpl<Value> &resolvedIndices);
+
+/// Convenience overload of the above taking `consumerIndices` as a ValueRange;
+/// the indices are forwarded as OpFoldResults.
+inline void resolveIndicesIntoOpWithOffsetsAndStrides(
+    RewriterBase &rewriter, Location loc,
+    ArrayRef<OpFoldResult> mixedSourceOffsets,
+    ArrayRef<OpFoldResult> mixedSourceStrides,
+    const llvm::SmallBitVector &rankReducedDims, ValueRange consumerIndices,
+    SmallVectorImpl<Value> &resolvedIndices) {
+  resolveIndicesIntoOpWithOffsetsAndStrides(
+      rewriter, loc, mixedSourceOffsets, mixedSourceStrides, rankReducedDims,
+      getAsOpFoldResult(consumerIndices), resolvedIndices);
+}
+
+/// Given `sourceSizes`, `destSizes` and information about which dimensions are
+/// dropped by the source: `rankReducedSourceDims`, compute the resolved sizes
+/// that correspond to dest_op(source_op).
+/// In practice, this amounts to filtering by `rankReducedSourceDims` and taking
+/// from `sourceSizes` if a dimension is dropped, otherwise taking from
+/// `destSizes`. `destSizes` is consumed in order, one entry per dimension that
+/// is not set in `rankReducedSourceDims`.
+void resolveSizesIntoOpWithSizes(
+    ArrayRef<OpFoldResult> sourceSizes, ArrayRef<OpFoldResult> destSizes,
+    const llvm::SmallBitVector &rankReducedSourceDims,
+    SmallVectorImpl<OpFoldResult> &resolvedSizes);
 
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 82f5ed96bfb96..6a7f542887fd8 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1963,13 +1963,13 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
   }];
 
   let builders = [
-    // Build a SubViewOp with mixed static and dynamic entries and custom
-    // result type. If the type passed is nullptr, it is inferred.
+    // Build a SubViewOp with mixed static and dynamic entries and inferred
+    // result type.
     OpBuilder<(ins "Value":$source, "ArrayRef<OpFoldResult>":$offsets,
       "ArrayRef<OpFoldResult>":$sizes, "ArrayRef<OpFoldResult>":$strides,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
-    // Build a SubViewOp with mixed static and dynamic entries and inferred
-    // result type.
+    // Build a SubViewOp with mixed static and dynamic entries and custom
+    // result type. If the type passed is nullptr, it is inferred.
     OpBuilder<(ins "MemRefType":$resultType, "Value":$source,
       "ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
       "ArrayRef<OpFoldResult>":$strides,

diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index e628c9dfef647..8d0028d6d5343 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -823,17 +823,18 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
   }];
 
   let builders = [
-    // Build a InsertSliceOp with mixed static and dynamic entries.
+    // Build an InsertSliceOp with mixed static and dynamic entries and
+    // inferred result type.
     OpBuilder<(ins "Value":$source, "Value":$dest,
       "ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
       "ArrayRef<OpFoldResult>":$strides,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
-    // Build a InsertSliceOp with dynamic entries.
+    // Build an InsertSliceOp with dynamic entries and inferred result type.
     OpBuilder<(ins "Value":$source, "Value":$dest,
       "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     // Build an InsertSliceOp with mixed static and dynamic entries packed in
-    // a Range vector.
+    // a Range vector and inferred result type.
     OpBuilder<(ins "Value":$source, "Value":$dest,
       "ArrayRef<Range>":$ranges,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
@@ -1450,6 +1451,10 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
     /// Return the OpResult of the enclosing ForallOp that is
     /// corresponding to this ParallelInsertSliceOp.
     OpResult getTiedOpResult();
+
+    /// Return the dimensions of the dest that are omitted to insert a source
+    /// when the result is rank-extended.
+    llvm::SmallBitVector getDroppedDims();
   }];
 
   let builders = [

diff  --git a/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp b/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
index f53edcefe3c79..f47149d68fbbf 100644
--- a/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
@@ -77,32 +77,51 @@ LogicalResult mlir::mergeOffsetsSizesAndStrides(
       combinedOffsets, combinedSizes, combinedStrides);
 }
 
-void mlir::resolveSourceIndicesOffsetsAndStrides(
-    RewriterBase &rewriter, Location loc, ArrayRef<OpFoldResult> mixedOffsets,
-    ArrayRef<OpFoldResult> mixedStrides,
-    const llvm::SmallBitVector &rankReducedDims, ValueRange indicesVals,
-    SmallVectorImpl<Value> &sourceIndices) {
+void mlir::resolveIndicesIntoOpWithOffsetsAndStrides(
+    RewriterBase &rewriter, Location loc,
+    ArrayRef<OpFoldResult> mixedSourceOffsets,
+    ArrayRef<OpFoldResult> mixedSourceStrides,
+    const llvm::SmallBitVector &rankReducedDims,
+    ArrayRef<OpFoldResult> consumerIndices,
+    SmallVectorImpl<Value> &resolvedIndices) {
   OpFoldResult zero = rewriter.getIndexAttr(0);
 
   // For each dimension that is rank-reduced, add a zero to the indices.
   int64_t indicesDim = 0;
   SmallVector<OpFoldResult> indices;
-  for (auto dim : llvm::seq<int64_t>(0, mixedOffsets.size())) {
+  for (auto dim : llvm::seq<int64_t>(0, mixedSourceOffsets.size())) {
     OpFoldResult ofr =
-        (rankReducedDims.test(dim)) ? zero : indicesVals[indicesDim++];
+        (rankReducedDims.test(dim)) ? zero : consumerIndices[indicesDim++];
     indices.push_back(ofr);
   }
 
-  sourceIndices.resize(indices.size());
-  sourceIndices.clear();
+  // Start from an empty output and pre-allocate one slot per resolved index.
+  resolvedIndices.clear();
+  resolvedIndices.reserve(indices.size());
+  // Build the `off + idx * str` map once; it is the same for every dimension.
+  AffineExpr off, idx, str;
+  bindSymbols(rewriter.getContext(), off, idx, str);
+  AffineMap offsetPlusScaledIndex = AffineMap::get(0, 3, off + idx * str);
   for (auto [offset, index, stride] :
-       llvm::zip_equal(mixedOffsets, indices, mixedStrides)) {
-    AffineExpr off, idx, str;
-    bindSymbols(rewriter.getContext(), off, idx, str);
+       llvm::zip_equal(mixedSourceOffsets, indices, mixedSourceStrides)) {
     OpFoldResult ofr = makeComposedFoldedAffineApply(
-        rewriter, loc, AffineMap::get(0, 3, off + idx * str),
-        {offset, index, stride});
-    sourceIndices.push_back(
+        rewriter, loc, offsetPlusScaledIndex, {offset, index, stride});
+    resolvedIndices.push_back(
         getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
   }
 }
+
+void mlir::resolveSizesIntoOpWithSizes(
+    ArrayRef<OpFoldResult> sourceSizes, ArrayRef<OpFoldResult> destSizes,
+    const llvm::SmallBitVector &rankReducedSourceDims,
+    SmallVectorImpl<OpFoldResult> &resolvedSizes) {
+  int64_t dim = 0;
+  int64_t srcRank = sourceSizes.size();
+  for (int64_t srcDim = 0; srcDim < srcRank; ++srcDim) {
+    if (rankReducedSourceDims[srcDim]) {
+      resolvedSizes.push_back(sourceSizes[srcDim]);
+      continue;
+    }
+    resolvedSizes.push_back(destSizes[dim++]);
+  }
+}

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index ecf7bcbf997a4..a43184eb2dba9 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -248,48 +248,38 @@ class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
 
   LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                 PatternRewriter &rewriter) const override {
-    Location loc = subView.getLoc();
     auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
     if (!srcSubView)
       return failure();
-    int64_t srcRank = srcSubView.getSourceType().getRank();
-
-    // TODO: Only stride 1 is supported.
-    for (auto s : {subView.getMixedStrides(), srcSubView.getMixedStrides()})
-      if (!llvm::all_of(
-              s, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); }))
-        return failure();
-
-    // Get original offsets and sizes.
-    SmallVector<OpFoldResult> offsets = subView.getMixedOffsets();
-    SmallVector<OpFoldResult> srcOffsets = srcSubView.getMixedOffsets();
-    SmallVector<OpFoldResult> sizes = subView.getMixedSizes();
-    SmallVector<OpFoldResult> srcSizes = srcSubView.getMixedSizes();
-
-    // Compute new offsets and sizes.
-    llvm::SmallBitVector srcReducedDims = srcSubView.getDroppedDims();
-    SmallVector<OpFoldResult> newOffsets, newSizes;
-    int64_t dim = 0;
-    for (int64_t srcDim = 0; srcDim < srcRank; ++srcDim) {
-      if (srcReducedDims[srcDim]) {
-        // Dim is reduced in srcSubView.
-        assert(isConstantIntValue(srcSizes[srcDim], 1) && "expected size 1");
-        newOffsets.push_back(srcOffsets[srcDim]);
-        newSizes.push_back(srcSizes[srcDim]);
-        continue;
-      }
-      AffineExpr sym0, sym1;
-      bindSymbols(subView.getContext(), sym0, sym1);
-      newOffsets.push_back(makeComposedFoldedAffineApply(
-          rewriter, loc, sym0 + sym1, {srcOffsets[srcDim], offsets[dim]}));
-      newSizes.push_back(sizes[dim]);
-      ++dim;
+
+    // TODO: relax unit stride assumption.
+    if (!subView.hasUnitStride()) {
+      return rewriter.notifyMatchFailure(subView, "requires unit strides");
+    }
+    if (!srcSubView.hasUnitStride()) {
+      return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
     }
 
+    // Resolve sizes according to dropped dims.
+    SmallVector<OpFoldResult> resolvedSizes;
+    llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
+    resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
+                                subView.getMixedSizes(), srcDroppedDims,
+                                resolvedSizes);
+
+    // Resolve offsets according to source offsets and strides.
+    SmallVector<Value> resolvedOffsets;
+    resolveIndicesIntoOpWithOffsetsAndStrides(
+        rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
+        srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
+        resolvedOffsets);
+
     // Replace original op.
     rewriter.replaceOpWithNewOp<memref::SubViewOp>(
-        subView, subView.getType(), srcSubView.getSource(), newOffsets,
-        newSizes, srcSubView.getMixedStrides());
+        subView, subView.getType(), srcSubView.getSource(),
+        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
+        srcSubView.getMixedStrides());
+
     return success();
   }
 };
@@ -372,7 +362,7 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
-  resolveSourceIndicesOffsetsAndStrides(
+  resolveIndicesIntoOpWithOffsetsAndStrides(
       rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
       subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
       sourceIndices);
@@ -492,7 +482,7 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
-  resolveSourceIndicesOffsetsAndStrides(
+  resolveIndicesIntoOpWithOffsetsAndStrides(
       rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
       subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
       sourceIndices);

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fd59afbc44447..e70a2794701f8 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3086,6 +3086,10 @@ void ParallelInsertSliceOp::getCanonicalizationPatterns(
               InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
 }
 
+llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
+  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index 80ecb868dff6a..46bff2bb55cd5 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
@@ -21,6 +22,7 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <type_traits>
 
 namespace mlir {
 namespace tensor {
@@ -98,7 +100,7 @@ LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
   SmallVector<Value> indices(readOp.getIndices().begin(),
                              readOp.getIndices().end());
   SmallVector<Value> sourceIndices;
-  resolveSourceIndicesOffsetsAndStrides(
+  resolveIndicesIntoOpWithOffsetsAndStrides(
       rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
       extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
       indices, sourceIndices);
@@ -130,7 +132,7 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
   SmallVector<Value> indices(writeOp.getIndices().begin(),
                              writeOp.getIndices().end());
   SmallVector<Value> sourceIndices;
-  resolveSourceIndicesOffsetsAndStrides(
+  resolveIndicesIntoOpWithOffsetsAndStrides(
       rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
       insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
       sourceIndices);
@@ -145,9 +147,101 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
   return success();
 }
 
+/// Folds an insert_slice-like op (`OpTy`: tensor.insert_slice or
+/// tensor.parallel_insert_slice) whose source is produced by a
+/// tensor.insert_slice, so that the inner source is inserted directly into
+/// the outer destination. Requires unit strides on both ops and matching
+/// sizes on the non-dropped dimensions of the outer op.
+template <typename OpTy>
+struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy insertSliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto sourceInsertSliceOp =
+        insertSliceOp.getSource()
+            .template getDefiningOp<tensor::InsertSliceOp>();
+    if (!sourceInsertSliceOp)
+      return failure();
+
+    // TODO: relax unit stride assumption where possible.
+    if (!insertSliceOp.hasUnitStride()) {
+      return rewriter.notifyMatchFailure(insertSliceOp,
+                                         "requires unit strides");
+    }
+    if (!sourceInsertSliceOp.hasUnitStride()) {
+      return rewriter.notifyMatchFailure(sourceInsertSliceOp,
+                                         "requires unit strides");
+    }
+
+    // Each size of the outer op on a non-dropped dimension must match the
+    // corresponding size of the inner op, otherwise a copy would be needed.
+    // The mixed sizes are materialized once instead of on every iteration.
+    SmallVector<OpFoldResult> outerSizes = insertSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> innerSizes = sourceInsertSliceOp.getMixedSizes();
+    int64_t srcDim = 0;
+    llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
+    for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
+      if (droppedDims[d])
+        continue;
+      if (outerSizes[d] != innerSizes[srcDim++]) {
+        return rewriter.notifyMatchFailure(
+            sourceInsertSliceOp,
+            "requires matching sizes to fold, otherwise a copy is needed");
+      }
+    }
+
+    // Resolve sizes according to dropped dims.
+    SmallVector<OpFoldResult> resolvedSizes;
+    // Note: the "insertSlice" case is symmetrical to the extract/subview case:
+    // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
+    // passed as the destination to the helper function.
+    resolveSizesIntoOpWithSizes(outerSizes, innerSizes, droppedDims,
+                                resolvedSizes);
+
+    // If we are inside an InParallel region, temporarily set the insertion
+    // point outside: only tensor.parallel_insert_slice ops are allowed in
+    // there.
+    if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
+      auto inParallelOp =
+          insertSliceOp->template getParentOfType<scf::InParallelOp>();
+      // getParentOfType returns a null op when there is no such ancestor; bail
+      // out before creating any IR instead of using a null insertion point.
+      if (!inParallelOp) {
+        return rewriter.notifyMatchFailure(
+            insertSliceOp, "requires an scf.forall.in_parallel parent");
+      }
+      rewriter.setInsertionPoint(inParallelOp);
+    }
+
+    // Resolve offsets according to source offsets and strides.
+    SmallVector<Value> resolvedOffsets;
+    // Note: the "insertSlice" case is symmetrical to the extract/subview case:
+    // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
+    // passed as the destination to the helper function.
+    resolveIndicesIntoOpWithOffsetsAndStrides(
+        rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
+        insertSliceOp.getMixedStrides(), droppedDims,
+        sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);
+
+    // Reset the insertion point.
+    rewriter.setInsertionPoint(insertSliceOp);
+    // Replace original op.
+    rewriter.replaceOpWithNewOp<OpTy>(
+        insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
+        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
+        insertSliceOp.getMixedStrides());
+
+    return success();
+  }
+};
+
 void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
   patterns.add<TransferReadOfExtractSliceOpFolder,
-               InsertSliceOfTransferWriteOpFolder>(patterns.getContext());
+               InsertSliceOfTransferWriteOpFolder,
+               InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
+               InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
+      patterns.getContext());
 }
 //===----------------------------------------------------------------------===//
 // Pass registration

diff  --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
index 93a0d77bc698f..f2e529b4cac95 100644
--- a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
+++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -fold-tensor-subset-ops -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -fold-tensor-subset-ops -split-input-file --allow-unregistered-dialect %s | FileCheck %s
 
 func.func @fold_vector_transfer_read_with_rank_reduced_extract_slice(
     %arg0 : tensor<?x?x?xf32>,
@@ -260,3 +260,111 @@ func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32
   %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
   return %1 : tensor<?x?x12xf32>
 }
+
+// -----
+
+//       CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 2)>
+// CHECK-LABEL: func @insert_slice_of_insert_slice(
+//  CHECK-SAME:     %[[t:[0-9a-z]*]]: tensor<f32>
+//  CHECK-SAME:     %[[r1:[0-9a-z]*]]: tensor<1x14xf32>
+//  CHECK-SAME:     %[[pos:[0-9a-z]*]]: index
+//       CHECK:   %[[add:.*]] = affine.apply #[[$map]]()[%[[pos]]]
+//       CHECK:   tensor.insert_slice %[[t]] into %[[r1]][4, %[[add]]] [1, 1] [1, 1] : tensor<f32> into tensor<1x14xf32>
+func.func @insert_slice_of_insert_slice(%t: tensor<f32>, %r0: tensor<1x1xf32>, %r1: tensor<1x14xf32>, %pos: index)
+    -> tensor<1x14xf32>
+{
+  %0 = tensor.insert_slice %t into %r0[1, 2] [1, 1] [1, 1]
+    : tensor<f32> into tensor<1x1xf32>
+  %1 = tensor.insert_slice %0 into %r1[3, %pos] [1, 1] [1, 1]
+    : tensor<1x1xf32> into tensor<1x14xf32>
+  return %1 : tensor<1x14xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_insert_slice(
+//  CHECK-SAME:     %[[t:[0-9a-z]*]]: tensor<f32>
+//  CHECK-SAME:     %[[r1:[0-9a-z]*]]: tensor<1x14xf32>
+//  CHECK-SAME:     %[[pos:[0-9a-z]*]]: index
+//       CHECK:   tensor.insert_slice %[[t]] into %[[r1]][5, %[[pos]]] [1, 1] [1, 1] : tensor<f32> into tensor<1x14xf32>
+func.func @insert_slice_of_insert_slice(%t: tensor<f32>, %r0: tensor<1xf32>, %r1: tensor<1x14xf32>, %pos: index)
+    -> tensor<1x14xf32>
+{
+  %0 = tensor.insert_slice %t into %r0[2] [1] [1]
+    : tensor<f32> into tensor<1xf32>
+  %1 = tensor.insert_slice %0 into %r1[3, %pos] [1, 1] [1, 1]
+    : tensor<1xf32> into tensor<1x14xf32>
+  return %1 : tensor<1x14xf32>
+}
+
+// -----
+
+// This test does not fold because the sizes `4` and `%pos` do not match:
+// folding would require a copy.
+// CHECK-LABEL: func @fail_insert_slice_of_insert_slice(
+//       CHECK:   tensor.insert_slice
+//       CHECK:   tensor.insert_slice
+func.func @fail_insert_slice_of_insert_slice(
+  %t: tensor<4xf32>, %r0: tensor<?xf32>, %r1: tensor<?x?xf32>, %pos: index)
+    -> tensor<?x?xf32>
+{
+  %0 = tensor.insert_slice %t into %r0[%pos] [4] [1]
+    : tensor<4xf32> into tensor<?xf32>
+  %1 = tensor.insert_slice %0 into %r1[%pos, 423] [%pos, 1] [1, 1]
+    : tensor<?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+// Here the sizes are the same and the folding occurs properly.
+//       CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-LABEL: func @insert_slice_of_insert_slice_dynamic(
+//  CHECK-SAME:     %[[t:[0-9a-z]*]]: tensor<?xf32>
+//  CHECK-SAME:     %[[r0:[0-9a-z]*]]: tensor<?xf32>
+//  CHECK-SAME:     %[[r1:[0-9a-z]*]]: tensor<?x?xf32>
+//  CHECK-SAME:     %[[pos:[0-9a-z]*]]: index
+//       CHECK:   %[[add:.*]] = affine.apply #[[$map]]()[%[[pos]]]
+//       CHECK:   tensor.insert_slice %[[t]] into %[[r1]][%[[add]], 423] [%[[pos]], 1] [1, 1] : tensor<?xf32> into tensor<?x?xf32>
+func.func @insert_slice_of_insert_slice_dynamic(
+  %t: tensor<?xf32>, %r0: tensor<?xf32>, %r1: tensor<?x?xf32>, %pos: index)
+    -> tensor<?x?xf32>
+{
+  %0 = tensor.insert_slice %t into %r0[%pos] [%pos] [1]
+    : tensor<?xf32> into tensor<?xf32>
+  %1 = tensor.insert_slice %0 into %r1[%pos, 423] [%pos, 1] [1, 1]
+    : tensor<?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+//       CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @parallel_insert_slice_of_insert_slice_dynamic(
+//  CHECK-SAME:   %[[t:[0-9a-z]*]]: tensor<12x34xf32>
+//  CHECK-SAME:   %[[o0:[0-9a-z]*]]: index
+//  CHECK-SAME:   %[[o1:[0-9a-z]*]]: index
+//  CHECK-SAME:   %[[sz0:[0-9a-z]*]]: index
+//  CHECK-SAME:   %[[sz1:[0-9a-z]*]]: index
+func.func @parallel_insert_slice_of_insert_slice_dynamic(
+    %t: tensor<12x34xf32>, %o0: index, %o1: index, %sz0: index, %sz1: index)
+      -> tensor<12x34xf32>{
+
+  // CHECK: scf.forall {{.*}} shared_outs(%[[out:.*]] = %[[t]]
+  %0 = scf.forall (%arg0, %arg1) in (27, 8) shared_outs(%arg2 = %t) -> (tensor<12x34xf32>) {
+    // CHECK: %[[tt:.*]] = "make_me_a_tensor"() : () -> tensor<?x?xf32>
+    %tt = "make_me_a_tensor"() : () -> tensor<?x?xf32>
+    %tt2 = "make_me_another_tensor"() : () -> tensor<?x?xf32>
+    %inserted_slice = tensor.insert_slice %tt into %tt2[%o1, 0] [%sz0, %sz1] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+
+    //      CHECK: %[[add:.*]] = affine.apply #[[$map]]()[%[[o0]], %[[o1]]]
+    //      CHECK: scf.forall.in_parallel
+    //      CHECK:   tensor.parallel_insert_slice %[[tt]] into %[[out]][%[[add]], %[[o1]]] [%[[sz0]], %[[sz1]]] [1, 1]
+    // CHECK-SAME:     : tensor<?x?xf32> into tensor<12x34xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %inserted_slice into %arg2[%o0, %o1] [%sz0, %sz1] [1, 1]
+        : tensor<?x?xf32> into tensor<12x34xf32>
+    }
+  }
+  return %0: tensor<12x34xf32>
+}


        


More information about the Mlir-commits mailing list