[Mlir-commits] [mlir] 33468a5 - [mlir][Tensor] Add support for insert_slice in FoldTensorSubsetOps
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Apr 14 09:34:39 PDT 2023
Author: Nicolas Vasilache
Date: 2023-04-14T09:34:11-07:00
New Revision: 33468a51dbf4c046ee6de71e3e8746bb69d15e33
URL: https://github.com/llvm/llvm-project/commit/33468a51dbf4c046ee6de71e3e8746bb69d15e33
DIFF: https://github.com/llvm/llvm-project/commit/33468a51dbf4c046ee6de71e3e8746bb69d15e33.diff
LOG: [mlir][Tensor] Add support for insert_slice in FoldTensorSubsetOps
Differential Revision: https://reviews.llvm.org/D148334
Added:
Modified:
mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h b/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h
index 42156ac5de24d..7a55fe97c064e 100644
--- a/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_AFFINE_VIEWLIKEINTERFACEUTILS_H
#define MLIR_DIALECT_AFFINE_VIEWLIKEINTERFACEUTILS_H
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
@@ -22,7 +23,8 @@ class RewriterBase;
/// - Combined offsets = producer_offsets * consumer_strides + consumer_offsets
/// - Combined sizes = consumer_sizes
/// - Combined strides = producer_strides * consumer_strides
-// TODO: unify this API with resolveSourceIndicesOffsetsAndStrides or deprecate.
+// TODO: unify this API with resolveIndicesIntoOpWithOffsetsAndStrides or
+// deprecate.
LogicalResult
mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
ArrayRef<OpFoldResult> producerOffsets,
@@ -38,7 +40,8 @@ mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
/// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use
/// when combining a `producer` slice op **into** a `consumer` slice op.
-// TODO: unify this API with resolveSourceIndicesOffsetsAndStrides or deprecate.
+// TODO: unify this API with resolveIndicesIntoOpWithOffsetsAndStrides or
+// deprecate.
LogicalResult
mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
OffsetSizeAndStrideOpInterface producer,
@@ -48,8 +51,8 @@ mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
SmallVector<OpFoldResult> &combinedSizes,
SmallVector<OpFoldResult> &combinedStrides);
-/// Given the 'indicesVals' of a load/store operation operating on an op with
-/// offsets and strides, return the combined indices.
+/// Given the `consumerIndices` of a load/store operation operating on an op
+/// with offsets and strides, return the combined indices.
+/// Dimensions set in `rankReducedDims` contribute a zero consumer index.
+/// Any previous contents of `resolvedIndices` are discarded.
///
/// For example, using `memref.load` and `memref.subview` as an illustration:
///
@@ -64,13 +67,37 @@ mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
///
/// ```
/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
-/// memref<12x42xf32>
+/// memref<12x42xf32>
/// ```
-void resolveSourceIndicesOffsetsAndStrides(
- RewriterBase &rewriter, Location loc, ArrayRef<OpFoldResult> mixedOffsets,
- ArrayRef<OpFoldResult> mixedStrides,
- const llvm::SmallBitVector &rankReducedDims, ValueRange indicesVals,
- SmallVectorImpl<Value> &sourceIndices);
+void resolveIndicesIntoOpWithOffsetsAndStrides(
+ RewriterBase &rewriter, Location loc,
+ ArrayRef<OpFoldResult> mixedSourceOffsets,
+ ArrayRef<OpFoldResult> mixedSourceStrides,
+ const llvm::SmallBitVector &rankReducedDims,
+ ArrayRef<OpFoldResult> consumerIndices,
+ SmallVectorImpl<Value> &resolvedIndices);
+
+/// Convenience overload of the above that accepts `consumerIndices` as a
+/// `ValueRange` and forwards them as `OpFoldResult`s.
+inline void resolveIndicesIntoOpWithOffsetsAndStrides(
+ RewriterBase &rewriter, Location loc,
+ ArrayRef<OpFoldResult> mixedSourceOffsets,
+ ArrayRef<OpFoldResult> mixedSourceStrides,
+ const llvm::SmallBitVector &rankReducedDims, ValueRange consumerIndices,
+ SmallVectorImpl<Value> &resolvedIndices) {
+ return resolveIndicesIntoOpWithOffsetsAndStrides(
+ rewriter, loc, mixedSourceOffsets, mixedSourceStrides, rankReducedDims,
+ getAsOpFoldResult(consumerIndices), resolvedIndices);
+}
+
+/// Given `sourceSizes`, `destSizes` and information about which dimensions are
+/// dropped by the source: `rankReducedSourceDims`, compute the resolved sizes
+/// that correspond to dest_op(source_op).
+/// In practice, this amounts to filtering by `rankReducedSourceDims` and taking
+/// from `sourceSizes` if a dimension is dropped, otherwise taking from
+/// `destSizes`.
+/// `destSizes` must provide exactly one entry per dimension that is not set in
+/// `rankReducedSourceDims`. Results are appended to `resolvedSizes`.
+void resolveSizesIntoOpWithSizes(
+ ArrayRef<OpFoldResult> sourceSizes, ArrayRef<OpFoldResult> destSizes,
+ const llvm::SmallBitVector &rankReducedSourceDims,
+ SmallVectorImpl<OpFoldResult> &resolvedSizes);
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 82f5ed96bfb96..6a7f542887fd8 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1963,13 +1963,13 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
}];
let builders = [
- // Build a SubViewOp with mixed static and dynamic entries and custom
- // result type. If the type passed is nullptr, it is inferred.
+ // Build a SubViewOp with mixed static and dynamic entries and inferred
+ // result type.
OpBuilder<(ins "Value":$source, "ArrayRef<OpFoldResult>":$offsets,
"ArrayRef<OpFoldResult>":$sizes, "ArrayRef<OpFoldResult>":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
- // Build a SubViewOp with mixed static and dynamic entries and inferred
- // result type.
+ // Build a SubViewOp with mixed static and dynamic entries and custom
+ // result type. If the type passed is nullptr, it is inferred.
OpBuilder<(ins "MemRefType":$resultType, "Value":$source,
"ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
"ArrayRef<OpFoldResult>":$strides,
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index e628c9dfef647..8d0028d6d5343 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -823,17 +823,18 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
}];
let builders = [
- // Build a InsertSliceOp with mixed static and dynamic entries.
+ // Build an InsertSliceOp with mixed static and dynamic entries and inferred
+ // result type.
OpBuilder<(ins "Value":$source, "Value":$dest,
"ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
"ArrayRef<OpFoldResult>":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
- // Build a InsertSliceOp with dynamic entries.
+ // Build an InsertSliceOp with dynamic entries and inferred result type.
OpBuilder<(ins "Value":$source, "Value":$dest,
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
// Build an InsertSliceOp with mixed static and dynamic entries packed in
- // a Range vector.
+ // a Range vector and inferred result type.
OpBuilder<(ins "Value":$source, "Value":$dest,
"ArrayRef<Range>":$ranges,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
@@ -1450,6 +1451,10 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
/// Return the OpResult of the enclosing ForallOp that is
/// corresponding to this ParallelInsertSliceOp.
OpResult getTiedOpResult();
+
+ /// Return the dimensions of the dest that are omitted to insert a source
+ /// when the result is rank-extended.
+ llvm::SmallBitVector getDroppedDims();
}];
let builders = [
diff --git a/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp b/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
index f53edcefe3c79..f47149d68fbbf 100644
--- a/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
@@ -77,32 +77,49 @@ LogicalResult mlir::mergeOffsetsSizesAndStrides(
combinedOffsets, combinedSizes, combinedStrides);
}
-void mlir::resolveSourceIndicesOffsetsAndStrides(
- RewriterBase &rewriter, Location loc, ArrayRef<OpFoldResult> mixedOffsets,
- ArrayRef<OpFoldResult> mixedStrides,
- const llvm::SmallBitVector &rankReducedDims, ValueRange indicesVals,
- SmallVectorImpl<Value> &sourceIndices) {
+void mlir::resolveIndicesIntoOpWithOffsetsAndStrides(
+ RewriterBase &rewriter, Location loc,
+ ArrayRef<OpFoldResult> mixedSourceOffsets,
+ ArrayRef<OpFoldResult> mixedSourceStrides,
+ const llvm::SmallBitVector &rankReducedDims,
+ ArrayRef<OpFoldResult> consumerIndices,
+ SmallVectorImpl<Value> &resolvedIndices) {
OpFoldResult zero = rewriter.getIndexAttr(0);
// For each dimension that is rank-reduced, add a zero to the indices.
int64_t indicesDim = 0;
SmallVector<OpFoldResult> indices;
- for (auto dim : llvm::seq<int64_t>(0, mixedOffsets.size())) {
+ for (auto dim : llvm::seq<int64_t>(0, mixedSourceOffsets.size())) {
OpFoldResult ofr =
- (rankReducedDims.test(dim)) ? zero : indicesVals[indicesDim++];
+ (rankReducedDims.test(dim)) ? zero : consumerIndices[indicesDim++];
indices.push_back(ofr);
}
- sourceIndices.resize(indices.size());
- sourceIndices.clear();
+ // Discard previous contents and reserve one slot per source dimension
+ // (`resize` + `clear` would default-construct values only to drop them).
+ resolvedIndices.clear();
+ resolvedIndices.reserve(indices.size());
+ // Each resolved index is `offset + index * stride`, folded where possible.
for (auto [offset, index, stride] :
- llvm::zip_equal(mixedOffsets, indices, mixedStrides)) {
+ llvm::zip_equal(mixedSourceOffsets, indices, mixedSourceStrides)) {
AffineExpr off, idx, str;
bindSymbols(rewriter.getContext(), off, idx, str);
OpFoldResult ofr = makeComposedFoldedAffineApply(
rewriter, loc, AffineMap::get(0, 3, off + idx * str),
{offset, index, stride});
- sourceIndices.push_back(
+ resolvedIndices.push_back(
getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
}
}
+
+void mlir::resolveSizesIntoOpWithSizes(
+ ArrayRef<OpFoldResult> sourceSizes, ArrayRef<OpFoldResult> destSizes,
+ const llvm::SmallBitVector &rankReducedSourceDims,
+ SmallVectorImpl<OpFoldResult> &resolvedSizes) {
+ // Dropped source dims keep their own size; every other dim consumes the next
+ // entry of `destSizes`, in order. Results are appended to `resolvedSizes`.
+ int64_t dim = 0;
+ int64_t srcRank = sourceSizes.size();
+ resolvedSizes.reserve(resolvedSizes.size() + srcRank);
+ for (int64_t srcDim = 0; srcDim < srcRank; ++srcDim) {
+ if (rankReducedSourceDims[srcDim]) {
+ resolvedSizes.push_back(sourceSizes[srcDim]);
+ continue;
+ }
+ resolvedSizes.push_back(destSizes[dim++]);
+ }
+ // Extra dest sizes would otherwise be silently ignored.
+ assert(dim == static_cast<int64_t>(destSizes.size()) &&
+ "expected one dest size per non-dropped source dim");
+}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index ecf7bcbf997a4..a43184eb2dba9 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -248,48 +248,38 @@ class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
+ /// Folds a memref.subview whose source is itself produced by a
+ /// memref.subview into a single subview of the inner op's source. Both
+ /// subviews must have unit strides.
LogicalResult matchAndRewrite(memref::SubViewOp subView,
PatternRewriter &rewriter) const override {
- Location loc = subView.getLoc();
auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
if (!srcSubView)
return failure();
- int64_t srcRank = srcSubView.getSourceType().getRank();
-
- // TODO: Only stride 1 is supported.
- for (auto s : {subView.getMixedStrides(), srcSubView.getMixedStrides()})
- if (!llvm::all_of(
- s, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); }))
- return failure();
-
- // Get original offsets and sizes.
- SmallVector<OpFoldResult> offsets = subView.getMixedOffsets();
- SmallVector<OpFoldResult> srcOffsets = srcSubView.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = subView.getMixedSizes();
- SmallVector<OpFoldResult> srcSizes = srcSubView.getMixedSizes();
-
- // Compute new offsets and sizes.
- llvm::SmallBitVector srcReducedDims = srcSubView.getDroppedDims();
- SmallVector<OpFoldResult> newOffsets, newSizes;
- int64_t dim = 0;
- for (int64_t srcDim = 0; srcDim < srcRank; ++srcDim) {
- if (srcReducedDims[srcDim]) {
- // Dim is reduced in srcSubView.
- assert(isConstantIntValue(srcSizes[srcDim], 1) && "expected size 1");
- newOffsets.push_back(srcOffsets[srcDim]);
- newSizes.push_back(srcSizes[srcDim]);
- continue;
- }
- AffineExpr sym0, sym1;
- bindSymbols(subView.getContext(), sym0, sym1);
- newOffsets.push_back(makeComposedFoldedAffineApply(
- rewriter, loc, sym0 + sym1, {srcOffsets[srcDim], offsets[dim]}));
- newSizes.push_back(sizes[dim]);
- ++dim;
+
+ // TODO: relax unit stride assumption.
+ if (!subView.hasUnitStride()) {
+ return rewriter.notifyMatchFailure(subView, "requires unit strides");
+ }
+ if (!srcSubView.hasUnitStride()) {
+ return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
}
+ // Resolve sizes: dims dropped by `srcSubView` keep their `srcSubView` size,
+ // all other dims take the corresponding `subView` size.
+ SmallVector<OpFoldResult> resolvedSizes;
+ llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
+ resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
+ subView.getMixedSizes(), srcDroppedDims,
+ resolvedSizes);
+
+ // Resolve offsets as `srcOffset + subViewOffset * srcStride`, using a zero
+ // `subView` offset for dims dropped by `srcSubView`.
+ SmallVector<Value> resolvedOffsets;
+ resolveIndicesIntoOpWithOffsetsAndStrides(
+ rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
+ srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
+ resolvedOffsets);
+
// Replace original op.
rewriter.replaceOpWithNewOp<memref::SubViewOp>(
- subView, subView.getType(), srcSubView.getSource(), newOffsets,
- newSizes, srcSubView.getMixedStrides());
+ subView, subView.getType(), srcSubView.getSource(),
+ getAsOpFoldResult(resolvedOffsets), resolvedSizes,
+ srcSubView.getMixedStrides());
+
return success();
}
};
@@ -372,7 +362,7 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
SmallVector<Value> sourceIndices;
- resolveSourceIndicesOffsetsAndStrides(
+ resolveIndicesIntoOpWithOffsetsAndStrides(
rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
sourceIndices);
@@ -492,7 +482,7 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
SmallVector<Value> sourceIndices;
- resolveSourceIndicesOffsetsAndStrides(
+ resolveIndicesIntoOpWithOffsetsAndStrides(
rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
sourceIndices);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fd59afbc44447..e70a2794701f8 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3086,6 +3086,10 @@ void ParallelInsertSliceOp::getCanonicalizationPatterns(
InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}
+// Returns the dims of the dest that are omitted when inserting the source into
+// a higher-rank dest, computed by the `::getDroppedDims` helper from the
+// source shape and this op's mixed sizes.
+llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
+ return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
+}
+
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index 80ecb868dff6a..46bff2bb55cd5 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
@@ -21,6 +22,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <type_traits>
namespace mlir {
namespace tensor {
@@ -98,7 +100,7 @@ LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
SmallVector<Value> indices(readOp.getIndices().begin(),
readOp.getIndices().end());
SmallVector<Value> sourceIndices;
- resolveSourceIndicesOffsetsAndStrides(
+ resolveIndicesIntoOpWithOffsetsAndStrides(
rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
indices, sourceIndices);
@@ -130,7 +132,7 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
SmallVector<Value> indices(writeOp.getIndices().begin(),
writeOp.getIndices().end());
SmallVector<Value> sourceIndices;
- resolveSourceIndicesOffsetsAndStrides(
+ resolveIndicesIntoOpWithOffsetsAndStrides(
rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
sourceIndices);
@@ -145,9 +147,86 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
return success();
}
+/// Folds an insertion whose source is produced by a `tensor.insert_slice`:
+///   %0 = tensor.insert_slice %t into %r0 ...
+///   %1 = <OpTy> %0 into %r1 ...
+/// into a single `<OpTy> %t into %r1` with composed offsets. `OpTy` is either
+/// `tensor.insert_slice` or `tensor.parallel_insert_slice`. Requires unit
+/// strides on both ops and matching sizes (otherwise a copy is needed).
+template <typename OpTy>
+struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy insertSliceOp,
+ PatternRewriter &rewriter) const override {
+ auto sourceInsertSliceOp =
+ insertSliceOp.getSource()
+ .template getDefiningOp<tensor::InsertSliceOp>();
+ if (!sourceInsertSliceOp)
+ return failure();
+
+ // TODO: relax unit stride assumption where possible.
+ if (!insertSliceOp.hasUnitStride()) {
+ return rewriter.notifyMatchFailure(insertSliceOp,
+ "requires unit strides");
+ }
+ if (!sourceInsertSliceOp.hasUnitStride()) {
+ return rewriter.notifyMatchFailure(sourceInsertSliceOp,
+ "requires unit strides");
+ }
+
+ // Materialize the mixed sizes once: `getMixedSizes()` returns a freshly
+ // built vector on every call.
+ SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> sourceMixedSizes =
+ sourceInsertSliceOp.getMixedSizes();
+
+ // Every non-dropped size of `insertSliceOp` must match the corresponding
+ // size of `sourceInsertSliceOp`.
+ int64_t srcDim = 0;
+ llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
+ for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
+ if (droppedDims[d])
+ continue;
+ if (mixedSizes[d] != sourceMixedSizes[srcDim++]) {
+ return rewriter.notifyMatchFailure(
+ sourceInsertSliceOp,
+ "requires matching sizes to fold, otherwise a copy is needed");
+ }
+ }
+
+ // Resolve sizes according to dropped dims.
+ SmallVector<OpFoldResult> resolvedSizes;
+ // Note: the "insertSlice" case is symmetrical to the extract/subview case:
+ // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
+ // passed as the destination to the helper function.
+ resolveSizesIntoOpWithSizes(mixedSizes, sourceMixedSizes, droppedDims,
+ resolvedSizes);
+
+ // If we are inside an InParallel region, temporarily set the insertion
+ // point outside: only tensor.parallel_insert_slice ops are allowed in
+ // there. Whether this applies is a compile-time property of `OpTy`.
+ if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
+ rewriter.setInsertionPoint(
+ insertSliceOp->template getParentOfType<scf::InParallelOp>());
+ }
+
+ // Resolve offsets according to source offsets and strides.
+ SmallVector<Value> resolvedOffsets;
+ // Note: the "insertSlice" case is symmetrical to the extract/subview case:
+ // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
+ // passed as the destination to the helper function.
+ resolveIndicesIntoOpWithOffsetsAndStrides(
+ rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
+ insertSliceOp.getMixedStrides(), droppedDims,
+ sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);
+
+ // Reset the insertion point.
+ rewriter.setInsertionPoint(insertSliceOp);
+ // Replace original op.
+ rewriter.replaceOpWithNewOp<OpTy>(
+ insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
+ getAsOpFoldResult(resolvedOffsets), resolvedSizes,
+ insertSliceOp.getMixedStrides());
+
+ return success();
+ }
+};
+
void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
+ // The insert_slice-of-insert_slice folder is instantiated twice so that the
+ // outer op may be either a tensor.insert_slice or a
+ // tensor.parallel_insert_slice.
patterns.add<TransferReadOfExtractSliceOpFolder,
- InsertSliceOfTransferWriteOpFolder>(patterns.getContext());
+ InsertSliceOfTransferWriteOpFolder,
+ InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
+ InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
+ patterns.getContext());
}
//===----------------------------------------------------------------------===//
// Pass registration
diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
index 93a0d77bc698f..f2e529b4cac95 100644
--- a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
+++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -fold-tensor-subset-ops -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -fold-tensor-subset-ops -split-input-file --allow-unregistered-dialect %s | FileCheck %s
func.func @fold_vector_transfer_read_with_rank_reduced_extract_slice(
%arg0 : tensor<?x?x?xf32>,
@@ -260,3 +260,133 @@ func.func @insert_slice_of_transfer_write_rank_extending(%t1 : tensor<?x?x12xf32
%1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x12xf32>
return %1 : tensor<?x?x12xf32>
}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 2)>
+// CHECK-LABEL: func @insert_slice_of_insert_slice(
+// CHECK-SAME: %[[t:[0-9a-z]*]]: tensor<f32>
+// CHECK-SAME: %[[r1:[0-9a-z]*]]: tensor<1x14xf32>
+// CHECK-SAME: %[[pos:[0-9a-z]*]]: index
+// CHECK: %[[add:.*]] = affine.apply #[[$map]]()[%[[pos]]]
+// CHECK: tensor.insert_slice %[[t]] into %[[r1]][4, %[[add]]] [1, 1] [1, 1] : tensor<f32> into tensor<1x14xf32>
+func.func @insert_slice_of_insert_slice(%t: tensor<f32>, %r0: tensor<1x1xf32>, %r1: tensor<1x14xf32>, %pos: index)
+ -> tensor<1x14xf32>
+{
+ %0 = tensor.insert_slice %t into %r0[1, 2] [1, 1] [1, 1]
+ : tensor<f32> into tensor<1x1xf32>
+ %1 = tensor.insert_slice %0 into %r1[3, %pos] [1, 1] [1, 1]
+ : tensor<1x1xf32> into tensor<1x14xf32>
+ return %1 : tensor<1x14xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_of_insert_slice(
+// CHECK-SAME: %[[t:[0-9a-z]*]]: tensor<f32>
+// CHECK-SAME: %[[r1:[0-9a-z]*]]: tensor<1x14xf32>
+// CHECK-SAME: %[[pos:[0-9a-z]*]]: index
+// CHECK: tensor.insert_slice %[[t]] into %[[r1]][5, %[[pos]]] [1, 1] [1, 1] : tensor<f32> into tensor<1x14xf32>
+func.func @insert_slice_of_insert_slice(%t: tensor<f32>, %r0: tensor<1xf32>, %r1: tensor<1x14xf32>, %pos: index)
+ -> tensor<1x14xf32>
+{
+ %0 = tensor.insert_slice %t into %r0[2] [1] [1]
+ : tensor<f32> into tensor<1xf32>
+ %1 = tensor.insert_slice %0 into %r1[3, %pos] [1, 1] [1, 1]
+ : tensor<1xf32> into tensor<1x14xf32>
+ return %1 : tensor<1x14xf32>
+}
+
+// -----
+
+// This test fails to fold because the inner size `4` and the outer size `%pos`
+// do not match: folding would require a copy.
+// CHECK-LABEL: func @fail_insert_slice_of_insert_slice(
+// CHECK: tensor.insert_slice
+// CHECK: tensor.insert_slice
+func.func @fail_insert_slice_of_insert_slice(
+ %t: tensor<4xf32>, %r0: tensor<?xf32>, %r1: tensor<?x?xf32>, %pos: index)
+ -> tensor<?x?xf32>
+{
+ %0 = tensor.insert_slice %t into %r0[%pos] [4] [1]
+ : tensor<4xf32> into tensor<?xf32>
+ %1 = tensor.insert_slice %0 into %r1[%pos, 423] [%pos, 1] [1, 1]
+ : tensor<?xf32> into tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+// Here the sizes are the same and the folding occurs properly.
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-LABEL: func @insert_slice_of_insert_slice_dynamic(
+// CHECK-SAME: %[[t:[0-9a-z]*]]: tensor<?xf32>
+// CHECK-SAME: %[[r0:[0-9a-z]*]]: tensor<?xf32>
+// CHECK-SAME: %[[r1:[0-9a-z]*]]: tensor<?x?xf32>
+// CHECK-SAME: %[[pos:[0-9a-z]*]]: index
+// CHECK: %[[add:.*]] = affine.apply #[[$map]]()[%[[pos]]]
+// CHECK: tensor.insert_slice %[[t]] into %[[r1]][%[[add]], 423] [%[[pos]], 1] [1, 1] : tensor<?xf32> into tensor<?x?xf32>
+func.func @insert_slice_of_insert_slice_dynamic(
+ %t: tensor<?xf32>, %r0: tensor<?xf32>, %r1: tensor<?x?xf32>, %pos: index)
+ -> tensor<?x?xf32>
+{
+ %0 = tensor.insert_slice %t into %r0[%pos] [%pos] [1]
+ : tensor<?xf32> into tensor<?xf32>
+ %1 = tensor.insert_slice %0 into %r1[%pos, 423] [%pos, 1] [1, 1]
+ : tensor<?xf32> into tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @parallel_insert_slice_of_insert_slice_dynamic(
+// CHECK-SAME: %[[t:[0-9a-z]*]]: tensor<12x34xf32>
+// CHECK-SAME: %[[o0:[0-9a-z]*]]: index
+// CHECK-SAME: %[[o1:[0-9a-z]*]]: index
+// CHECK-SAME: %[[sz0:[0-9a-z]*]]: index
+// CHECK-SAME: %[[sz1:[0-9a-z]*]]: index
+func.func @parallel_insert_slice_of_insert_slice_dynamic(
+ %t: tensor<12x34xf32>, %o0: index, %o1: index, %sz0: index, %sz1: index)
+ -> tensor<12x34xf32>{
+
+ // CHECK: scf.forall {{.*}} shared_outs(%[[out:.*]] = %[[t]]
+ %0 = scf.forall (%arg0, %arg1) in (27, 8) shared_outs(%arg2 = %t) -> (tensor<12x34xf32>) {
+ // CHECK: %[[tt:.*]] = "make_me_a_tensor"() : () -> tensor<?x?xf32>
+ %tt = "make_me_a_tensor"() : () -> tensor<?x?xf32>
+ %tt2 = "make_me_another_tensor"() : () -> tensor<?x?xf32>
+ %inserted_slice = tensor.insert_slice %tt into %tt2[%o1, 0] [%sz0, %sz1] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+
+ // CHECK: %[[add:.*]] = affine.apply #[[$map]]()[%[[o0]], %[[o1]]]
+ // CHECK: scf.forall.in_parallel
+ // CHECK: tensor.parallel_insert_slice %[[tt]] into %[[out]][%[[add]], %[[o1]]] [%[[sz0]], %[[sz1]]] [1, 1]
+ // CHECK-SAME: : tensor<?x?xf32> into tensor<12x34xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %inserted_slice into %arg2[%o0, %o1] [%sz0, %sz1] [1, 1]
+ : tensor<?x?xf32> into tensor<12x34xf32>
+ }
+ }
+ return %0: tensor<12x34xf32>
+}
More information about the Mlir-commits
mailing list