[Mlir-commits] [mlir] df5c981 - [mlir][Linalg] Add DropUnitDims support for tensor::ParallelInsertSliceOp.
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Jul 5 01:36:18 PDT 2022
Author: Nicolas Vasilache
Date: 2022-07-05T01:36:13-07:00
New Revision: df5c981be35a3267a50bf382d179e48bb2242e0f
URL: https://github.com/llvm/llvm-project/commit/df5c981be35a3267a50bf382d179e48bb2242e0f
DIFF: https://github.com/llvm/llvm-project/commit/df5c981be35a3267a50bf382d179e48bb2242e0f.diff
LOG: [mlir][Linalg] Add DropUnitDims support for tensor::ParallelInsertSliceOp.
ParallelInsertSlice behaves similarly to tensor::InsertSliceOp in its
rank-reducing properties.
This revision extends rank-reducing rewrite behavior and reuses most of the
existing implementation.
Differential Revision: https://reviews.llvm.org/D129091
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 2ef7e074127da..b35558e8a0a6b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -457,7 +457,7 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
namespace {
/// Convert `extract_slice` operations to rank-reduced versions.
-struct UseRankReducedExtractSliceOp
+struct RankReducedExtractSliceOp
: public OpRewritePattern<tensor::ExtractSliceOp> {
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
@@ -487,27 +487,37 @@ struct UseRankReducedExtractSliceOp
};
/// Convert `insert_slice` operations to rank-reduced versions.
-struct UseRankReducedInsertSliceOp
- : public OpRewritePattern<tensor::InsertSliceOp> {
- using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
+template <typename InsertOpTy>
+struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
+ using OpRewritePattern<InsertOpTy>::OpRewritePattern;
- LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
+ LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
PatternRewriter &rewriter) const override {
- RankedTensorType sourceType = insertOp.getSourceType();
- SmallVector<OpFoldResult> offsets = insertOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = insertOp.getMixedSizes();
- SmallVector<OpFoldResult> strides = insertOp.getMixedStrides();
+ RankedTensorType sourceType = insertSliceOp.getSourceType();
+ SmallVector<OpFoldResult> offsets = insertSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = insertSliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> strides = insertSliceOp.getMixedStrides();
auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
if (!reassociation ||
reassociation->size() == static_cast<size_t>(sourceType.getRank()))
return failure();
- Location loc = insertOp.getLoc();
- auto reshapedSource = rewriter.create<tensor::CollapseShapeOp>(
- loc, insertOp.getSource(), *reassociation);
- rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
- insertOp, reshapedSource, insertOp.getDest(),
- insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
- insertOp.getMixedStrides());
+ Location loc = insertSliceOp.getLoc();
+ tensor::CollapseShapeOp reshapedSource;
+ {
+ OpBuilder::InsertionGuard g(rewriter);
+ // The only difference between InsertSliceOp and ParallelInsertSliceOp is
+ // that the insertion point is just before the ParallelCombiningOp in the
+ // parallel case.
+ if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
+ rewriter.setInsertionPoint(insertSliceOp->getParentOp());
+ reshapedSource = rewriter.create<tensor::CollapseShapeOp>(
+ loc, insertSliceOp.getSource(), *reassociation);
+ }
+ rewriter.replaceOpWithNewOp<InsertOpTy>(
+ insertSliceOp, reshapedSource, insertSliceOp.getDest(),
+ insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+ insertSliceOp.getMixedStrides());
return success();
}
};
@@ -518,8 +528,9 @@ struct UseRankReducedInsertSliceOp
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
RewritePatternSet &patterns) {
auto *context = patterns.getContext();
- patterns.add<FoldUnitDimLoops, ReplaceUnitExtents,
- UseRankReducedExtractSliceOp, UseRankReducedInsertSliceOp>(
+ patterns.add<FoldUnitDimLoops, ReplaceUnitExtents, RankReducedExtractSliceOp,
+ RankReducedInsertSliceOp<tensor::InsertSliceOp>,
+ RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
context);
linalg::FillOp::getCanonicalizationPatterns(patterns, context);
linalg::InitTensorOp::getCanonicalizationPatterns(patterns, context);
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 8c424ffbea268..a272dea3fa232 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -827,3 +827,23 @@ func.func @sparse_case(%arg0: tensor<8x8xf32, #CSR>, %arg1: tensor<8xf32>) -> te
// CHECK-LABEL: func @sparse_case
// CHECK-NEXT: linalg.init_tensor
// CHECK-NEXT: linalg.generic
+
+// -----
+
+func.func @reduce_dispatch_0() -> tensor<4x2xf32> {
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = linalg.init_tensor [4, 2] : tensor<4x2xf32>
+ %res = scf.foreach_thread (%arg0, %arg1) in (%c4, %c2) -> (tensor<4x2xf32>) {
+ %1 = linalg.init_tensor [1, 1] : tensor<1x1xf32>
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32>
+ scf.foreach_thread.perform_concurrently {
+ // CHECK: tensor.parallel_insert_slice %{{[0-9a-z]*}} into %{{[0-9a-z]*}}
+ // CHECK-SAME: [%{{.*}}, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<4x2xf32>
+ tensor.parallel_insert_slice %2 into %0[%arg0, %arg1] [1, 1] [1, 1] :
+ tensor<1x1xf32> into tensor<4x2xf32>
+ }
+ }
+ return %res: tensor<4x2xf32>
+}
More information about the Mlir-commits
mailing list