[Mlir-commits] [mlir] 83c65fb - [mlir][linalg] Expose pattern to collapse generic op dimensions
Thomas Raoux
llvmlistbot at llvm.org
Mon Oct 10 09:44:35 PDT 2022
Author: Thomas Raoux
Date: 2022-10-10T16:44:01Z
New Revision: 83c65fbc2842909444bfe0a74ed083d164381078
URL: https://github.com/llvm/llvm-project/commit/83c65fbc2842909444bfe0a74ed083d164381078
DIFF: https://github.com/llvm/llvm-project/commit/83c65fbc2842909444bfe0a74ed083d164381078.diff
LOG: [mlir][linalg] Expose pattern to collapse generic op dimensions
Add a pattern to be able to collapse dimensions in a linalg generic op.
Differential Revision: https://reviews.llvm.org/D135503
Added:
mlir/test/Dialect/Linalg/collapse-dim.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 62dcc8e786877..fb37c6f227728 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -76,6 +76,18 @@ void populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns,
const ControlFusionFn &controlElementwiseOpFusion);
+/// Function type to control generic op dimension collapsing. It is expected
+/// to return an array of `ReassociationIndices` representing dimensions that
+/// should be merged.
+/// Each returned entry is one group of loop dimensions of the given
+/// `linalg::GenericOp` to be merged into a single loop (e.g. `{2, 3}`).
+/// Returning an empty array means the op is left unchanged.
+/// NOTE(review): the exact validity rules for a grouping (e.g. whether groups
+/// must be contiguous) are enforced by the collapsing implementation; confirm
+/// there before relying on non-adjacent groups.
+using GetCollapsableDimensionsFn =
+ std::function<SmallVector<ReassociationIndices>(linalg::GenericOp)>;
+
+/// Pattern to collapse dimensions in a linalg.generic op. This will collapse
+/// tensor operands when needed and expand back the result tensors.
+/// Adds that pattern to `patterns`. `controlCollapseDimensions` is invoked on
+/// every candidate `linalg.generic` op to choose which loop dimensions to
+/// merge; operands are reshaped with `tensor.collapse_shape` and results are
+/// restored with `tensor.expand_shape`.
+void populateCollapseDimensions(
+ RewritePatternSet &patterns,
+ const GetCollapsableDimensionsFn &controlCollapseDimensions);
+
/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
/// producer (consumer) generic operation by expanding the dimensionality of the
/// loop in the generic op.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 45bc4a8a9bd23..05dce4c40272b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1367,7 +1367,7 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
/// Implementation of fusion with reshape operation by collapsing dimensions.
static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
- OpOperand *fusableOpOperand, PatternRewriter &rewriter) {
+ PatternRewriter &rewriter) {
// Bail on trivial no-op cases.
if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
@@ -1510,7 +1510,7 @@ class FoldWithProducerReshapeOpByCollapsing
Optional<SmallVector<Value>> replacements =
collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
- opOperand, rewriter);
+ rewriter);
if (!replacements) {
return rewriter.notifyMatchFailure(
genericOp, "failed to do the fusion by collapsing transformation");
@@ -1525,6 +1525,37 @@ class FoldWithProducerReshapeOpByCollapsing
private:
ControlFusionFn controlFoldingReshapes;
};
+
+/// Pattern that collapses the loop dimensions of a `linalg.generic` op.
+///
+/// The user-supplied control callback chooses, per op, which groups of loop
+/// dimensions to merge. An empty selection means the op is left untouched.
+/// On success the op is replaced by the values produced by
+/// `collapseGenericOpIterationDims`, which reshapes tensor operands as needed
+/// and expands the results back to the op's original result types.
+class CollapseLinalgDimensions : public OpRewritePattern<GenericOp> {
+public:
+  CollapseLinalgDimensions(MLIRContext *context,
+                           GetCollapsableDimensionsFn collapseDimensions,
+                           PatternBenefit benefit = 1)
+      : OpRewritePattern<GenericOp>(context, benefit),
+        controlCollapseDimension(std::move(collapseDimensions)) {}
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<ReassociationIndices> collapsableIterationDims =
+        controlCollapseDimension(genericOp);
+    // The control function opted out for this op; report why for debugging.
+    if (collapsableIterationDims.empty())
+      return rewriter.notifyMatchFailure(genericOp,
+                                         "no dimensions selected to collapse");
+
+    // Keep the callee's FailureOr result type instead of converting it to
+    // Optional, and test it with the idiomatic `failed()`.
+    FailureOr<SmallVector<Value>> replacements = collapseGenericOpIterationDims(
+        genericOp, collapsableIterationDims, rewriter);
+    if (failed(replacements)) {
+      return rewriter.notifyMatchFailure(genericOp,
+                                         "failed to collapse dimensions");
+    }
+    rewriter.replaceOp(genericOp, *replacements);
+    return success();
+  }
+
+private:
+  /// Callback selecting which loop dimensions of a given op to merge.
+  GetCollapsableDimensionsFn controlCollapseDimension;
+};
+
} // namespace
//===---------------------------------------------------------------------===//
@@ -1743,6 +1774,13 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
RemoveOutsDependency>(context);
}
+void mlir::linalg::populateCollapseDimensions(
+    RewritePatternSet &patterns,
+    const GetCollapsableDimensionsFn &controlCollapseDimensions) {
+  // Register a single CollapseLinalgDimensions pattern that shares the
+  // pattern set's context and forwards the caller's control function.
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<CollapseLinalgDimensions>(ctx, controlCollapseDimensions);
+}
+
//===---------------------------------------------------------------------===//
// Passes
//===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
new file mode 100644
index 0000000000000..3587557a9ece0
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=collapse-dimensions-control=2,3 -split-input-file | FileCheck %s
+
+// Two parallel loops followed by two reduction loops. Merging d2 and d3 (as
+// requested by the RUN line) fuses the two reduction loops into one; the
+// output map only uses d0 and d1, so only the input operand is reshaped.
+func.func @collapse_reduction(
+ %arg0: tensor<2x32x10x4096xf32>, %arg1: tensor<2x32xf32>) -> tensor<2x32xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ ins(%arg0 : tensor<2x32x10x4096xf32>) outs(%arg1 : tensor<2x32xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32):
+ %1 = arith.addf %arg3, %arg4 : f32
+ linalg.yield %1 : f32
+ } -> tensor<2x32xf32>
+ return %0 : tensor<2x32xf32>
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @collapse_reduction
+// CHECK: %[[T:.*]] = tensor.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : tensor<2x32x10x4096xf32> into tensor<2x32x40960xf32>
+// CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%[[T]] : tensor<2x32x40960xf32>) outs(%{{.*}} : tensor<2x32xf32>) {
+// CHECK: } -> tensor<2x32xf32>
+
+// -----
+
+// All-parallel op whose input swaps d0 and d1. Merging d2 and d3 reshapes
+// both the input and the init tensor, keeps the d0/d1 transpose in the
+// collapsed indexing map, and expands the result back to the 4-D shape.
+func.func @collapse_parallel(
+ %arg0: tensor<32x2x10x4096xf32>, %arg1: tensor<2x32x10x4096xf32>) -> tensor<2x32x10x4096xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%arg0 : tensor<32x2x10x4096xf32>) outs(%arg1 : tensor<2x32x10x4096xf32>) {
+ ^bb0(%arg3: f32, %arg4: f32):
+ %1 = arith.addf %arg3, %arg4 : f32
+ linalg.yield %1 : f32
+ } -> tensor<2x32x10x4096xf32>
+ return %0 : tensor<2x32x10x4096xf32>
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func @collapse_parallel
+// CHECK-DAG: %[[S:.*]] = tensor.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : tensor<32x2x10x4096xf32> into tensor<32x2x40960xf32>
+// CHECK-DAG: %[[D:.*]] = tensor.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : tensor<2x32x10x4096xf32> into tensor<2x32x40960xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[S]] : tensor<32x2x40960xf32>) outs(%[[D]] : tensor<2x32x40960xf32>) {
+// CHECK: } -> tensor<2x32x40960xf32>
+// CHECK: tensor.expand_shape %[[R]] {{\[}}[0], [1], [2, 3]] : tensor<2x32x40960xf32> into tensor<2x32x10x4096xf32>
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index 41e46d0d69206..0119516f272c0 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -99,6 +99,9 @@ struct TestLinalgElementwiseFusion
"fusion patterns that "
"collapse the iteration space of the consumer"),
llvm::cl::init(false)};
+ /// Loop dimensions to merge into one with the collapse-dimensions pattern,
+ /// e.g. `collapse-dimensions-control=2,3`. When the list is empty the
+ /// pattern is not run.
+ ListOption<int64_t> collapseDimensions{
+ *this, "collapse-dimensions-control",
+ llvm::cl::desc("Test controlling dimension collapse pattern")};
void runOnOperation() override {
MLIRContext *context = &this->getContext();
@@ -179,6 +182,20 @@ struct TestLinalgElementwiseFusion
linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
}
+
+    if (!collapseDimensions.empty()) {
+      // Treat the option's values as a single group of loop dimensions and
+      // request that same grouping for every linalg.generic op.
+      ReassociationIndices group(collapseDimensions.begin(),
+                                 collapseDimensions.end());
+      linalg::GetCollapsableDimensionsFn collapseFn =
+          [&group](linalg::GenericOp /*op*/) {
+            SmallVector<ReassociationIndices> groups;
+            groups.push_back(group);
+            return groups;
+          };
+      RewritePatternSet collapsePatterns(context);
+      linalg::populateCollapseDimensions(collapsePatterns, collapseFn);
+      (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                         std::move(collapsePatterns));
+    }
}
};
More information about the Mlir-commits
mailing list