[Mlir-commits] [mlir] 33d2a78 - [mlir][linalg] Add pattern to split reduction dimension in a linalg op
Thomas Raoux
llvmlistbot at llvm.org
Thu Mar 24 16:30:34 PDT 2022
Author: Thomas Raoux
Date: 2022-03-24T23:22:53Z
New Revision: 33d2a780a1397cb0203692019fc95f0784eff0fe
URL: https://github.com/llvm/llvm-project/commit/33d2a780a1397cb0203692019fc95f0784eff0fe
DIFF: https://github.com/llvm/llvm-project/commit/33d2a780a1397cb0203692019fc95f0784eff0fe.diff
LOG: [mlir][linalg] Add pattern to split reduction dimension in a linalg op
This transformation allows breaking up a reduction dimension into a
parallel and a reduction dimension. This is followed by a separate
reduction op. This allows generating a tree reduction, which is beneficial
on targets that can take advantage of the extra parallelism.
Differential Revision: https://reviews.llvm.org/D122045
Added:
mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
mlir/test/Dialect/Linalg/split_reduction.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index ca828abae43c0..a551f40141a00 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -1090,6 +1090,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*methodName=*/"getRegionBuilder",
(ins),
[{ return ConcreteOp::getRegionBuilder(); }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return true if every indexing map of the op is a projected permutation
+ (as defined by `AffineMap::isProjectedPermutation`), false otherwise.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"hasOnlyProjectedPermutations",
+ (ins),
+ [{
+ return llvm::all_of($_op.getIndexingMaps(),
+ [](AffineMap map) { return map.isProjectedPermutation(); });
+ }]
 >
];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8059ec1320968..3fb84bd53148a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1447,6 +1447,64 @@ class TilingPatterns<OpTy, OpTypes...> {
}
};
+/// Function signature to control reduction splitting. It returns a pair
+/// `(ratio, insertDimIndex)`:
+///   - `ratio` is the factor by which the reduction dimension is split. If it
+///     is less than or equal to 1, nothing is done.
+///   - `insertDimIndex` is the position of the extra parallel dimension, both
+///     in the iteration space of the split op and in the shape of the
+///     intermediate tensor. It must be smaller than the number of loops of the
+///     op, otherwise the transformation is not applied.
+ using ControlSplitReductionFn =
+ std::function<std::pair<int64_t, unsigned>(LinalgOp op)>;
+
+/// Populate `patterns` with a pattern that applies `splitReduction` (below) to
+/// every LinalgOp accepted by the filter `f`, using `controlSplitReductionFn`
+/// to choose the split ratio and the position of the new parallel dimension.
+ void populateSplitReductionPattern(
+ RewritePatternSet &patterns,
+ ControlSplitReductionFn controlSplitReductionFn,
+ LinalgTransformationFilter f = LinalgTransformationFilter());
+
+/// Split the single reduction dimension of `op` into a parallel dimension and
+/// a (smaller) reduction dimension, then create a new linalg.generic op that
+/// reduces the newly added parallel dimension. The split ratio and the
+/// position of the new dimension come from `controlSplitReductionFn`.
+///
+/// The transformation only applies to ops accepted by the filter `f` that have
+/// tensor semantics, exactly one reduction loop, a single output, only
+/// projected-permutation indexing maps, a static reduction size divisible by
+/// the ratio, and a single combiner op with a known identity value.
+///
+/// Return the new linalg op with the extra parallel dimension (the results of
+/// the original op are replaced by those of the final reduction op), or
+/// failure if the transformation didn't happen.
+/// Example:
+/// ```
+/// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+/// affine_map<(d0) -> ()>],
+/// iterator_types = ["reduction"]}
+/// ins(%in : tensor<32xf32>)
+/// outs(%out : tensor<f32>) {
+/// ^bb0(%arg1: f32, %arg2: f32):
+/// %y = arith.addf %arg1, %arg2 : f32
+/// linalg.yield %y : f32
+/// } -> tensor<f32>
+/// ```
+/// To:
+/// ```
+/// %cst = arith.constant 0.000000e+00 : f32
+/// %0 = tensor.expand_shape %in [[0, 1]] : tensor<32xf32> into tensor<4x8xf32>
+/// %1 = linalg.init_tensor [4] : tensor<4xf32>
+/// %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<4xf32>) -> tensor<4xf32>
+/// %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+/// affine_map<(d0, d1) -> (d0)>],
+/// iterator_types = ["parallel", "reduction"]}
+/// ins(%0 : tensor<4x8xf32>) outs(%2 : tensor<4xf32>) {
+/// ^bb0(%arg3: f32, %arg4: f32):
+/// %5 = arith.addf %arg3, %arg4 : f32
+/// linalg.yield %5 : f32
+/// } -> tensor<4xf32>
+/// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+/// affine_map<(d0) -> ()>],
+/// iterator_types = ["reduction"]}
+/// ins(%3 : tensor<4xf32>) outs(%out : tensor<f32>) {
+/// ^bb0(%arg3: f32, %arg4: f32):
+/// %5 = arith.addf %arg3, %arg4 : f32
+/// linalg.yield %5 : f32
+/// } -> tensor<f32>
+/// ```
+FailureOr<LinalgOp>
+splitReduction(PatternRewriter &b, LinalgOp op,
+ ControlSplitReductionFn controlSplitReductionFn,
+ LinalgTransformationFilter f);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7048a414aa829..77457dac3113e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
PadOpInterchange.cpp
Promotion.cpp
SparseTensorRewriting.cpp
+ SplitReduction.cpp
Tiling.cpp
Transforms.cpp
Vectorization.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
new file mode 100644
index 0000000000000..edfe80ca337aa
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -0,0 +1,234 @@
+//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a linalg transformation that splits a reduction
+// dimension into a parallel and a reduction dimension.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+ /// Return the identity (neutral) value of the binary combiner `op`, i.e. the
+ /// value `e` for which `op(e, x) == x`. It is used to initialize the
+ /// accumulator of the partial reduction. Return llvm::None if `op` is not a
+ /// supported combiner or its result type is not a float, integer or index.
+ static Optional<Attribute> getIdentity(Operation *op) {
+   // Builder only used as helper for attribute creation.
+   OpBuilder b(op->getContext());
+   Type resultType = op->getResult(0).getType();
+   if (auto floatType = resultType.dyn_cast<FloatType>()) {
+     const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
+     if (isa<arith::AddFOp>(op))
+       return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
+     if (isa<arith::MulFOp>(op))
+       return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
+     // max(-largest, x) == x and min(+largest, x) == x for every finite x.
+     if (isa<arith::MaxFOp>(op))
+       return b.getFloatAttr(
+           resultType, llvm::APFloat::getLargest(semantic, /*Negative=*/true));
+     if (isa<arith::MinFOp>(op))
+       return b.getFloatAttr(
+           resultType, llvm::APFloat::getLargest(semantic, /*Negative=*/false));
+     return llvm::None;
+   }
+   if (!resultType.isIntOrIndex())
+     return llvm::None;
+   // Build extreme values at the real bit width of the type. Going through an
+   // int64_t truncates them for narrower types (INT64_MIN truncated to i32 is
+   // 0, which is not the identity of maxsi) and under-fills wider types.
+   unsigned bitWidth = resultType.isIndex()
+                           ? IndexType::kInternalStorageBitWidth
+                           : resultType.getIntOrFloatBitWidth();
+   if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp, arith::MaxUIOp>(op))
+     return b.getIntegerAttr(resultType, 0);
+   if (isa<arith::AndIOp, arith::MinUIOp>(op))
+     return b.getIntegerAttr(resultType, llvm::APInt::getAllOnes(bitWidth));
+   if (isa<arith::MaxSIOp>(op))
+     return b.getIntegerAttr(resultType,
+                             llvm::APInt::getSignedMinValue(bitWidth));
+   if (isa<arith::MinSIOp>(op))
+     return b.getIntegerAttr(resultType,
+                             llvm::APInt::getSignedMaxValue(bitWidth));
+   if (isa<arith::MulIOp>(op))
+     return b.getIntegerAttr(resultType, 1);
+   return llvm::None;
+ }
+
+ /// Split the single reduction dimension of `op` into an extra parallel
+ /// dimension (of size `ratio`) and a smaller reduction dimension. Then reduce
+ /// the extra dimension with a second linalg.generic whose results replace
+ /// those of `op`. Return the first (partial reduction) generic op.
+ FailureOr<LinalgOp>
+ mlir::linalg::splitReduction(PatternRewriter &b, LinalgOp op,
+                              ControlSplitReductionFn controlSplitReductionFn,
+                              LinalgTransformationFilter filter) {
+   if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
+       op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
+       !op.hasOnlyProjectedPermutations())
+     return b.notifyMatchFailure(op, "precondition not met");
+
+   std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
+   int64_t ratio = control.first;
+   unsigned insertDimIndex = control.second;
+   if (ratio <= 1)
+     return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
+
+   SmallVector<unsigned> dims;
+   op.getReductionDims(dims);
+   assert(dims.size() == 1);
+   unsigned reductionDim = dims[0];
+   Optional<SmallVector<int64_t, 4>> loopRanges = op.getStaticLoopRanges();
+   if (!loopRanges)
+     return b.notifyMatchFailure(op, "Cannot analyze loops");
+   int64_t reductionDimSize = (*loopRanges)[reductionDim];
+   if (reductionDimSize == ShapedType::kDynamicSize ||
+       reductionDimSize % ratio != 0)
+     return b.notifyMatchFailure(
+         op, "Reduction dimension not divisible by split ratio");
+
+   // The new parallel dimension is inserted at `insertDimIndex` both in the
+   // iteration space and in the intermediate output, so the position must be
+   // valid for both. Otherwise the output shape/map below would be built
+   // without the new dimension, reading past the end of the old output shape.
+   AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0));
+   if (insertDimIndex >= loopRanges->size() ||
+       insertDimIndex > oldOutputMap.getNumResults())
+     return b.notifyMatchFailure(op, "Insert dimension position too large");
+
+   SmallVector<Operation *, 4> combinerOps;
+   if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
+       combinerOps.size() != 1)
+     return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
+   Operation *reductionOp = combinerOps[0];
+   Optional<Attribute> identity = getIdentity(reductionOp);
+   if (!identity)
+     return b.notifyMatchFailure(op, "Unknown identity value for the reduction");
+
+   Location loc = op->getLoc();
+   // Map a loop dimension of `op` to its position in the new iteration space,
+   // where a parallel dimension has been inserted at `insertDimIndex`.
+   auto getShiftedDimExpr = [&](unsigned dim) {
+     return b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1);
+   };
+
+   SmallVector<Value> newInputs;
+   SmallVector<AffineMap> newMaps;
+   // Calculate the new shapes and indexing maps of the input operands.
+   for (OpOperand *operand : op.getInputOperands()) {
+     AffineMap map = op.getTiedIndexingMap(operand);
+     ArrayRef<int64_t> oldShape = op.getShape(operand);
+     SmallVector<int64_t> newShape;
+     SmallVector<AffineExpr> exprs;
+     SmallVector<ReassociationIndices> reassociation;
+     int64_t index = 0;
+     for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
+       unsigned dim = map.getDimPosition(idx);
+       if (dim == reductionDim) {
+         // Expand the reduction dimension into [ratio, size / ratio]: the
+         // outer part is indexed by the new parallel dimension and the inner
+         // part by the (now smaller) reduction dimension.
+         newShape.push_back(ratio);
+         newShape.push_back(oldShape[idx] / ratio);
+         reassociation.push_back({index, index + 1});
+         index += 2;
+         exprs.push_back(b.getAffineDimExpr(insertDimIndex));
+         exprs.push_back(getShiftedDimExpr(dim));
+         continue;
+       }
+       newShape.push_back(oldShape[idx]);
+       exprs.push_back(getShiftedDimExpr(dim));
+       reassociation.push_back({index++});
+     }
+     newMaps.push_back(
+         AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
+     // If the shape is unchanged the input doesn't change.
+     if (newShape == oldShape) {
+       newInputs.push_back(operand->get());
+       continue;
+     }
+     Type newType = RankedTensorType::get(
+         newShape,
+         operand->get().getType().cast<RankedTensorType>().getElementType());
+     Value newInput = b.create<tensor::ExpandShapeOp>(
+         loc, newType, operand->get(), reassociation);
+     newInputs.push_back(newInput);
+   }
+
+   // Calculate the new output map and shape, we insert the new dimension based
+   // on the index returned by `controlSplitReductionFn`.
+   SmallVector<int64_t> newOutputShape;
+   ArrayRef<int64_t> oldOutputShape = op.getShape(op.getOutputOperand(0));
+   SmallVector<AffineExpr> outputExpr;
+   for (unsigned idx :
+        llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
+     if (idx == insertDimIndex) {
+       newOutputShape.push_back(ratio);
+       outputExpr.push_back(b.getAffineDimExpr(insertDimIndex));
+       continue;
+     }
+     unsigned oldPos = idx < insertDimIndex ? idx : idx - 1;
+     newOutputShape.push_back(oldOutputShape[oldPos]);
+     outputExpr.push_back(
+         getShiftedDimExpr(oldOutputMap.getDimPosition(oldPos)));
+   }
+   Value initTensor = b.create<linalg::InitTensorOp>(
+       loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
+   Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
+   Value identityTensor =
+       b.create<linalg::FillOp>(loc, constantOp, initTensor).getResult(0);
+
+   newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
+                                    op.getContext()));
+   SmallVector<StringRef> newIteratorTypes;
+   for (auto &it : llvm::enumerate(op.iterator_types())) {
+     if (insertDimIndex == it.index())
+       newIteratorTypes.push_back(getParallelIteratorTypeName());
+     newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
+   }
+   // Create the new op matching the original op with an extra parallel
+   // dimension.
+   GenericOp genericOp = b.create<GenericOp>(
+       loc, TypeRange({initTensor.getType()}), newInputs,
+       ValueRange({identityTensor}), newMaps, newIteratorTypes);
+   b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
+                        genericOp.region().begin());
+
+   // Then create a new reduction that only reduces the newly added dimension
+   // from the previous op.
+   unsigned intermRank = newOutputShape.size();
+   AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
+   SmallVector<Value> outputOperands = op.getOutputOperands();
+   SmallVector<StringRef> reductionIteratorTypes;
+   SmallVector<AffineExpr> reductionOutputExprs;
+   for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
+     if (insertDimIndex == i) {
+       reductionIteratorTypes.push_back(getReductionIteratorTypeName());
+     } else {
+       reductionOutputExprs.push_back(b.getAffineDimExpr(i));
+       reductionIteratorTypes.push_back(getParallelIteratorTypeName());
+     }
+   }
+   AffineMap outputMap =
+       AffineMap::get(intermRank, 0, reductionOutputExprs, op.getContext());
+   SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};
+
+   auto reduction = b.create<GenericOp>(
+       loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
+       outputOperands, reductionMaps, reductionIteratorTypes,
+       [reductionOp](OpBuilder &nestedBuilder, Location nestedLoc,
+                     ValueRange inputs) {
+         // Reuse the original combiner on (partial result, accumulator).
+         Operation *clonedReductionOp = nestedBuilder.clone(*reductionOp);
+         clonedReductionOp->setOperand(0, inputs[0]);
+         clonedReductionOp->setOperand(1, inputs[1]);
+         nestedBuilder.create<linalg::YieldOp>(nestedLoc,
+                                               clonedReductionOp->getResult(0));
+       });
+   b.replaceOp(op, reduction.getResults());
+   filter.replaceLinalgTransformationFilter(b, genericOp);
+   filter.replaceLinalgTransformationFilter(b, reduction);
+   return cast<LinalgOp>(genericOp.getOperation());
+ }
+
+ namespace {
+
+ /// Rewrite pattern that applies `splitReduction` to every LinalgOp accepted by
+ /// the transformation filter. A user-provided control function chooses the
+ /// split ratio and the insertion position of the new dimension.
+ struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
+   LinalgSplitReduction(MLIRContext *context,
+                        ControlSplitReductionFn controlSplitReductionFn,
+                        LinalgTransformationFilter f, PatternBenefit benefit = 1)
+       : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+         controlFn(controlSplitReductionFn), filter(std::move(f)) {}
+
+   LogicalResult matchAndRewrite(LinalgOp op,
+                                 PatternRewriter &rewriter) const override {
+     FailureOr<LinalgOp> splitOp =
+         splitReduction(rewriter, op, controlFn, filter);
+     return failed(splitOp) ? failure() : success();
+   }
+
+ private:
+   /// Chooses the split ratio and the position of the new parallel dimension.
+   ControlSplitReductionFn controlFn;
+   /// Restricts which ops are rewritten; passed through to `splitReduction`.
+   LinalgTransformationFilter filter;
+ };
+
+ } // namespace
+
+ /// Register a `LinalgSplitReduction` pattern in `patterns`.
+ void linalg::populateSplitReductionPattern(
+     RewritePatternSet &patterns,
+     ControlSplitReductionFn controlSplitReductionFn,
+     LinalgTransformationFilter f) {
+   // Both arguments are taken by value and only forwarded, so move them into
+   // the pattern instead of copying them a second time.
+   patterns.add<LinalgSplitReduction>(patterns.getContext(),
+                                      std::move(controlSplitReductionFn),
+                                      std::move(f));
+ }
diff --git a/mlir/test/Dialect/Linalg/split_reduction.mlir b/mlir/test/Dialect/Linalg/split_reduction.mlir
new file mode 100644
index 0000000000000..c95510d43d12f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/split_reduction.mlir
@@ -0,0 +1,119 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-split-reduction -split-input-file | FileCheck %s
+
+// Split the 256-wide reduction dimension of a matmul with the ratio 4 used by
+// the test pass; the extra parallel dimension becomes the last dimension of
+// the 16x32x4 intermediate tensor.
+func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+ %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
+ outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+ return %0: tensor<16x32xf32>
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: @matmul_split
+// CHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<16x256xf32> into tensor<16x4x64xf32>
+// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<256x32xf32> into tensor<4x64x32xf32>
+// CHECK-DAG: %[[INI:.*]] = linalg.init_tensor [16, 32, 4] : tensor<16x32x4xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
+// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME: , iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<16x4x64xf32>, tensor<4x64x32xf32>) outs(%[[F]] : tensor<16x32x4xf32>) {
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<16x32x4xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[G]] : tensor<16x32x4xf32>) outs(%{{.*}} : tensor<16x32xf32>) {
+// CHECK: arith.addf
+// CHECK: linalg.yield %{{.*}} : f32
+// CHECK: } -> tensor<16x32xf32>
+// CHECK: return %[[R]] : tensor<16x32xf32>
+
+// -----
+
+// 1-D reduction whose combiner is arith.mulf (identity 1.0). The tensor<f32>
+// operand does not use the reduction dimension, so it is passed to the split
+// op without being reshaped.
+func @generic_split_1d(%arg0: tensor<32xf32>, %arg1: tensor<f32>, %out: tensor<f32>) -> tensor<f32> {
+ %red = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> ()>,
+ affine_map<(d0) -> ()>],
+ iterator_types = ["reduction"]}
+ ins(%arg0, %arg1 : tensor<32xf32>, tensor<f32>)
+ outs(%out : tensor<f32>) {
+ ^bb0(%arg7: f32, %arg8: f32, %arg9: f32):
+ %40 = arith.subf %arg7, %arg8 : f32
+ %41 = math.exp %40 : f32
+ %42 = arith.mulf %41, %arg9 : f32
+ linalg.yield %42 : f32
+ } -> tensor<f32>
+ return %red : tensor<f32>
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
+//CHECK-LABEL: @generic_split_1d
+// CHECK: %[[ID:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] : tensor<32xf32> into tensor<4x8xf32>
+// CHECK: %[[INI:.*]] = linalg.init_tensor [4] : tensor<4xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
+// CHECK: %[[G:.*]] = linalg.generic
+// CHECK: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
+// CHECK: iterator_types = ["parallel", "reduction"]} ins(%[[I1]], %{{.*}} : tensor<4x8xf32>, tensor<f32>) outs(%[[F]] : tensor<4xf32>) {
+// CHECK: arith.subf
+// CHECK: math.exp
+// CHECK: arith.mulf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<4xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["reduction"]} ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
+// CHECK: arith.mulf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<f32>
+// CHECK: return %[[R]] : tensor<f32>
+
+// -----
+
+// Reduction over d1, which is neither the first nor the last loop and is
+// indexed at a different position in each input. The combiner is arith.maxf,
+// whose identity is the most negative finite f32.
+func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>, %output: tensor<5x2xf32>)
+ -> tensor<5x2xf32>
+{
+ %0 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d1, d0)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d2, d0)>
+ ],
+ iterator_types = ["parallel", "reduction", "parallel"]
+ } ins(%input, %input_2 : tensor<32x2xf32>, tensor<5x32xf32>) outs(%output : tensor<5x2xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+ %3 = arith.addf %arg0, %arg1 : f32
+ %4 = arith.maxf %3, %arg2 : f32
+ linalg.yield %4 : f32
+ } -> tensor<5x2xf32>
+ return %0 : tensor<5x2xf32>
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func @generic_split_3d
+// CHECK: %[[ID:.*]] = arith.constant -3.40282347E+38 : f32
+// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<4x8x2xf32>
+// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32>
+// CHECK: %[[INI:.*]] = linalg.init_tensor [5, 2, 4] : tensor<5x2x4xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
+// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<4x8x2xf32>, tensor<5x4x8xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
+// CHECK: arith.addf
+// CHECK: arith.maxf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<5x2x4xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%[[G]] : tensor<5x2x4xf32>) outs(%{{.*}} : tensor<5x2xf32>) {
+// CHECK: arith.maxf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<5x2xf32>
+// CHECK: return %[[R]] : tensor<5x2xf32>
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 42068ab79ff9c..e6786b24f5939 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -111,6 +111,10 @@ struct TestLinalgTransforms
llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
"pad_tensor(subtensor)"),
llvm::cl::init(false)};
+ Option<bool> testSplitReduction{
+ *this, "test-split-reduction",
+ llvm::cl::desc("Test split reduction transformation"),
+ llvm::cl::init(false)};
ListOption<int64_t> peeledLoops{
*this, "peeled-loops",
llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
@@ -617,6 +621,20 @@ static void applyTilePattern(FuncOp funcOp, const std::string &loopType,
(void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
}
+/// Greedily apply the split-reduction pattern to `funcOp`, splitting by a fixed
+/// ratio of 4 and inserting the new parallel dimension at position
+/// `numLoops - 1`. Rewritten ops are tagged with the "SPLIT" marker; assuming
+/// the usual LinalgTransformationFilter semantics, this keeps them from being
+/// matched again.
+static void applySplitReduction(FuncOp funcOp) {
+  MLIRContext *ctx = funcOp.getContext();
+  auto control = [](LinalgOp op) -> std::pair<int64_t, unsigned> {
+    return {4, op.getNumLoops() - 1};
+  };
+  RewritePatternSet patterns(ctx);
+  linalg::populateSplitReductionPattern(
+      patterns, control,
+      LinalgTransformationFilter(ArrayRef<StringAttr>{},
+                                 StringAttr::get(ctx, "SPLIT")));
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnOperation() {
auto lambda = [&](void *) {
@@ -666,6 +684,8 @@ void TestLinalgTransforms::runOnOperation() {
if (testTileScalarizeDynamicDims)
return applyTilePattern(getOperation(), loopType, tileSizes,
/*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
+ if (testSplitReduction)
+ return applySplitReduction(getOperation());
}
namespace mlir {
More information about the Mlir-commits
mailing list