[Mlir-commits] [mlir] 33d2a78 - [mlir][linalg] Add pattern to split reduction dimension in a linalg op

Thomas Raoux llvmlistbot at llvm.org
Thu Mar 24 16:30:34 PDT 2022


Author: Thomas Raoux
Date: 2022-03-24T23:22:53Z
New Revision: 33d2a780a1397cb0203692019fc95f0784eff0fe

URL: https://github.com/llvm/llvm-project/commit/33d2a780a1397cb0203692019fc95f0784eff0fe
DIFF: https://github.com/llvm/llvm-project/commit/33d2a780a1397cb0203692019fc95f0784eff0fe.diff

LOG: [mlir][linalg] Add pattern to split reduction dimension in a linalg op

This transformation allows breaking up a reduction dimension into a
parallel and a reduction dimension. This is followed by a separate
reduction op. This allows generating tree reductions, which is beneficial
on targets able to take advantage of parallelism.

Differential Revision: https://reviews.llvm.org/D122045

Added: 
    mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
    mlir/test/Dialect/Linalg/split_reduction.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index ca828abae43c0..a551f40141a00 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -1090,6 +1090,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*methodName=*/"getRegionBuilder",
       (ins),
       [{ return ConcreteOp::getRegionBuilder(); }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return true when every indexing map of this op is a projected
+        permutation, and false as soon as one of them is not.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasOnlyProjectedPermutations",
+      /*args=*/(ins),
+      /*methodBody=*/[{
+        return llvm::none_of($_op.getIndexingMaps(), [](AffineMap indexingMap) {
+          return !indexingMap.isProjectedPermutation();
+        });
+      }]
     >
   ];
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8059ec1320968..3fb84bd53148a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1447,6 +1447,64 @@ class TilingPatterns<OpTy, OpTypes...> {
   }
 };
 
+/// Function signature to control reduction splitting. This returns a pair
+/// containing a ratio and a dimension index. The ratio is used to split the
+/// reduction dimension. The dimension index controls where the extra
+/// dimension is inserted, both among the loops of the split op and in the
+/// shape of the intermediate tensor. If the ratio value is less than or equal
+/// to 1 then nothing will be done.
+using ControlSplitReductionFn =
+    std::function<std::pair<int64_t, unsigned>(LinalgOp op)>;
+
+/// Populate `patterns` with a pattern that applies `splitReduction` below to
+/// the LinalgOps accepted by the filter `f`, using `controlSplitReductionFn`
+/// to pick the split ratio and the position of the new dimension.
+void populateSplitReductionPattern(
+    RewritePatternSet &patterns,
+    ControlSplitReductionFn controlSplitReductionFn,
+    LinalgTransformationFilter f = LinalgTransformationFilter());
+
+/// Split the single reduction dimension of `op` into a parallel dimension of
+/// size `ratio` and a reduction dimension of size `reductionSize / ratio`,
+/// where the ratio and the position of the new dimension are given by
+/// `controlSplitReductionFn`. The split op accumulates partial results into an
+/// intermediate tensor filled with the identity value of the combiner, and a
+/// new linalg.generic op then reduces the extra dimension into the original
+/// output. Return the split linalg op (the one with the extra parallel
+/// dimension), or failure if the transformation didn't happen.
+///
+/// The transformation only applies to ops with tensor semantics, exactly one
+/// reduction loop, exactly one output, only projected-permutation indexing
+/// maps, a static reduction size divisible by the ratio, an in-range insert
+/// position, and a single combiner op with a known identity value.
+///
+/// Example (ratio 4, new dimension inserted at position 0):
+/// ```
+///  %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+///                                        affine_map<(d0) -> ()>],
+///       iterator_types = ["reduction"]}
+///  ins(%in : tensor<32xf32>)
+///  outs(%out : tensor<f32>) {
+///  ^bb0(%arg1: f32, %arg2: f32):
+///    %y = arith.addf %arg1, %arg2 : f32
+///    linalg.yield %y : f32
+///  } -> tensor<f32>
+/// ```
+/// To:
+/// ```
+///  %cst = arith.constant 0.000000e+00 : f32
+///  %0 = tensor.expand_shape %in [[0, 1]] : tensor<32xf32> into tensor<4x8xf32>
+///  %1 = linalg.init_tensor [4] : tensor<4xf32>
+///  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<4xf32>) -> tensor<4xf32>
+///  %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+///                                        affine_map<(d0, d1) -> (d0)>],
+///    iterator_types = ["parallel", "reduction"]}
+///    ins(%0 : tensor<4x8xf32>) outs(%2 : tensor<4xf32>) {
+///    ^bb0(%arg3: f32, %arg4: f32):
+///    %5 = arith.addf %arg3, %arg4 : f32
+///    linalg.yield %5 : f32
+///  } -> tensor<4xf32>
+///  %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+///                                        affine_map<(d0) -> ()>],
+///    iterator_types = ["reduction"]}
+///    ins(%3 : tensor<4xf32>) outs(%out : tensor<f32>) {
+///    ^bb0(%arg3: f32, %arg4: f32):
+///    %5 = arith.addf %arg3, %arg4 : f32
+///    linalg.yield %5 : f32
+///  } -> tensor<f32>
+/// ```
+FailureOr<LinalgOp>
+splitReduction(PatternRewriter &b, LinalgOp op,
+               ControlSplitReductionFn controlSplitReductionFn,
+               LinalgTransformationFilter f);
+
 } // namespace linalg
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7048a414aa829..77457dac3113e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   PadOpInterchange.cpp
   Promotion.cpp
   SparseTensorRewriting.cpp
+  SplitReduction.cpp
   Tiling.cpp
   Transforms.cpp
   Vectorization.cpp

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
new file mode 100644
index 0000000000000..edfe80ca337aa
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -0,0 +1,234 @@
+//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements linalg transformation to break a reduction dimension
+// between a parallel and a reduction dimension.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Return the identity (neutral) value of the combiner `op`, as an attribute
+/// of the op's result type. It is the value used to initialize the partial
+/// accumulators of the split reduction. Return llvm::None when `op` is not a
+/// recognized arith combiner or its result type is not a float, integer or
+/// index type.
+static Optional<Attribute> getIdentity(Operation *op) {
+  // Builder only used as helper for attribute creation.
+  OpBuilder b(op->getContext());
+  Type resultType = op->getResult(0).getType();
+  if (auto floatType = resultType.dyn_cast<FloatType>()) {
+    const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
+    if (isa<arith::AddFOp>(op))
+      return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
+    if (isa<arith::MulFOp>(op))
+      return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
+    // Max uses the most negative finite value and min the most positive one
+    // (finite values are used instead of infinities).
+    if (isa<arith::MaxFOp>(op))
+      return b.getFloatAttr(
+          resultType, llvm::APFloat::getLargest(semantic, /*Negative=*/true));
+    if (isa<arith::MinFOp>(op))
+      return b.getFloatAttr(
+          resultType, llvm::APFloat::getLargest(semantic, /*Negative=*/false));
+    return llvm::None;
+  }
+  // Builder::getIntegerAttr needs an integer or index type.
+  if (!resultType.isIntOrIndex())
+    return llvm::None;
+  if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp, arith::MaxUIOp>(op))
+    return b.getIntegerAttr(resultType, 0);
+  if (isa<arith::AndIOp>(op))
+    return b.getIntegerAttr(resultType, -1);
+  if (isa<arith::MulIOp>(op))
+    return b.getIntegerAttr(resultType, 1);
+  // The remaining identities depend on the bit width. Build them as APInt of
+  // the right width instead of truncating 64-bit limits (truncating INT64_MIN
+  // to i32 would yield 0, for instance). Index attributes use the internal
+  // 64-bit storage width.
+  unsigned bitWidth = resultType.isIndex()
+                          ? IndexType::kInternalStorageBitWidth
+                          : resultType.getIntOrFloatBitWidth();
+  if (isa<arith::MaxSIOp>(op))
+    return b.getIntegerAttr(resultType,
+                            llvm::APInt::getSignedMinValue(bitWidth));
+  if (isa<arith::MinSIOp>(op))
+    return b.getIntegerAttr(resultType,
+                            llvm::APInt::getSignedMaxValue(bitWidth));
+  if (isa<arith::MinUIOp>(op))
+    return b.getIntegerAttr(resultType, llvm::APInt::getMaxValue(bitWidth));
+  return llvm::None;
+}
+
+/// Split the only reduction loop of `op` into a parallel loop of size `ratio`
+/// and a reduction loop of size `reductionSize / ratio`, then reduce the new
+/// parallel dimension with a second linalg.generic. `ratio` and the position
+/// of the new dimension (`insertDimIndex`) come from `controlSplitReductionFn`.
+/// On success `op` is replaced by the results of the second op and the first
+/// (split) op is returned. All preconditions are checked before any IR is
+/// created, so a match failure leaves the IR untouched.
+FailureOr<LinalgOp>
+mlir::linalg::splitReduction(PatternRewriter &b, LinalgOp op,
+                             ControlSplitReductionFn controlSplitReductionFn,
+                             LinalgTransformationFilter filter) {
+  if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
+      op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
+      !op.hasOnlyProjectedPermutations())
+    return b.notifyMatchFailure(op, "precondition not met");
+  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
+  int64_t ratio = control.first;
+  unsigned insertDimIndex = control.second;
+  if (ratio <= 1)
+    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
+  SmallVector<unsigned> dims;
+  op.getReductionDims(dims);
+  assert(dims.size() == 1);
+  unsigned reductionDim = dims[0];
+  Optional<SmallVector<int64_t, 4>> loopRanges = op.getStaticLoopRanges();
+  if (!loopRanges)
+    return b.notifyMatchFailure(op, "Cannot analyze loops");
+  int64_t reductionDimSize = (*loopRanges)[reductionDim];
+  if (reductionDimSize == ShapedType::kDynamicSize ||
+      reductionDimSize % ratio != 0)
+    return b.notifyMatchFailure(
+        op, "Reduction dimension not divisible by split ratio");
+  // `insertDimIndex` is both the index of the new loop and the position of the
+  // new dimension in the intermediate (output) tensor, so it must be valid for
+  // both. Without the second bound the output loop below would read past the
+  // end of the output shape.
+  OpOperand *outputOperand = op.getOutputOperand(0);
+  AffineMap oldOutputMap = op.getTiedIndexingMap(outputOperand);
+  if (insertDimIndex >= loopRanges->size() ||
+      insertDimIndex > oldOutputMap.getNumResults())
+    return b.notifyMatchFailure(
+        op, "Insert dimension position too large compared to loop count or "
+            "intermediate tensor rank");
+  SmallVector<Operation *, 4> combinerOps;
+  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
+      combinerOps.size() != 1)
+    return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
+  Operation *reductionOp = combinerOps[0];
+  Optional<Attribute> identity = getIdentity(reductionOp);
+  if (!identity)
+    return b.notifyMatchFailure(op, "Unknown identity value for the reduction");
+
+  Location loc = op->getLoc();
+  // Map an original loop index to its index in the split op, which has one
+  // extra loop at position `insertDimIndex`.
+  auto shiftDim = [insertDimIndex](unsigned dim) -> unsigned {
+    return dim < insertDimIndex ? dim : dim + 1;
+  };
+  SmallVector<Value> newInputs;
+  SmallVector<AffineMap> newMaps;
+  // Calculate the new shapes and indexing maps of the input operands. An input
+  // dimension indexed by the reduction loop is expanded into
+  // [ratio, size / ratio]; the outer part is indexed by the new parallel loop.
+  for (OpOperand *operand : op.getInputOperands()) {
+    AffineMap map = op.getTiedIndexingMap(operand);
+    ArrayRef<int64_t> oldInputShape = op.getShape(operand);
+    SmallVector<int64_t> newShape;
+    SmallVector<AffineExpr> exprs;
+    SmallVector<ReassociationIndices> reassociation;
+    int64_t index = 0;
+    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
+      unsigned dim = map.getDimPosition(idx);
+      if (reductionDim == dim) {
+        newShape.push_back(ratio);
+        newShape.push_back(oldInputShape[idx] / ratio);
+        reassociation.push_back({index, index + 1});
+        index += 2;
+        exprs.push_back(b.getAffineDimExpr(insertDimIndex));
+        exprs.push_back(b.getAffineDimExpr(shiftDim(dim)));
+        continue;
+      }
+      newShape.push_back(oldInputShape[idx]);
+      exprs.push_back(b.getAffineDimExpr(shiftDim(dim)));
+      reassociation.push_back({index++});
+    }
+    newMaps.push_back(
+        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
+    // If the shape is unchanged the input doesn't change.
+    if (newShape == oldInputShape) {
+      newInputs.push_back(operand->get());
+      continue;
+    }
+    Type newType = RankedTensorType::get(
+        newShape,
+        operand->get().getType().cast<RankedTensorType>().getElementType());
+    Value newInput = b.create<tensor::ExpandShapeOp>(
+        loc, newType, operand->get(), reassociation);
+    newInputs.push_back(newInput);
+  }
+  // Calculate the new output map and shape, we insert the new dimension based
+  // on the index returned by `controlSplitReductionFn`.
+  SmallVector<int64_t> newOutputShape;
+  ArrayRef<int64_t> oldShape = op.getShape(outputOperand);
+  SmallVector<AffineExpr> outputExpr;
+  for (unsigned idx :
+       llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
+    if (idx == insertDimIndex) {
+      newOutputShape.push_back(ratio);
+      outputExpr.push_back(b.getAffineDimExpr(insertDimIndex));
+      continue;
+    }
+    unsigned oldDim = idx < insertDimIndex ? idx : idx - 1;
+    newOutputShape.push_back(oldShape[oldDim]);
+    unsigned dim = oldOutputMap.getDimPosition(oldDim);
+    outputExpr.push_back(b.getAffineDimExpr(shiftDim(dim)));
+  }
+  // The intermediate tensor accumulates partial results, so it is filled with
+  // the identity value of the combiner.
+  Value initTensor = b.create<linalg::InitTensorOp>(
+      loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
+  Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
+  Value identityTensor =
+      b.create<linalg::FillOp>(loc, constantOp, initTensor).getResult(0);
+
+  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
+                                   op.getContext()));
+  SmallVector<StringRef> newIteratorTypes;
+  for (auto &it : llvm::enumerate(op.iterator_types())) {
+    if (insertDimIndex == it.index())
+      newIteratorTypes.push_back(getParallelIteratorTypeName());
+    newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
+  }
+  // Create the new op matching the original op with an extra parallel
+  // dimension; the original region is moved into it.
+  GenericOp genericOp = b.create<GenericOp>(
+      loc, TypeRange({initTensor.getType()}), newInputs,
+      ValueRange({identityTensor}), newMaps, newIteratorTypes);
+  b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
+                       genericOp.region().begin());
+
+  // Then create a new reduction that only reduces the newly added dimension
+  // from the previous op.
+  unsigned intermRank = newOutputShape.size();
+  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
+  SmallVector<Value> outputOperands = op.getOutputOperands();
+  SmallVector<StringRef> reductionIteratorTypes;
+  SmallVector<AffineExpr> exprs;
+  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
+    if (insertDimIndex == i) {
+      reductionIteratorTypes.push_back(getReductionIteratorTypeName());
+    } else {
+      exprs.push_back(b.getAffineDimExpr(i));
+      reductionIteratorTypes.push_back(getParallelIteratorTypeName());
+    }
+  }
+  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
+  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};
+
+  // The body is a clone of the original combiner applied to the
+  // (intermediate element, accumulator) pair.
+  auto reduction = b.create<GenericOp>(
+      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
+      outputOperands, reductionMaps, reductionIteratorTypes,
+      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
+        Operation *clonedReductionOp = b.clone(*reductionOp);
+        clonedReductionOp->setOperand(0, inputs[0]);
+        clonedReductionOp->setOperand(1, inputs[1]);
+        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
+      });
+  b.replaceOp(op, reduction.getResults());
+  filter.replaceLinalgTransformationFilter(b, genericOp);
+  filter.replaceLinalgTransformationFilter(b, reduction);
+  return cast<LinalgOp>(genericOp.getOperation());
+}
+
+namespace {
+
+/// Rewrite pattern that runs `splitReduction` on any LinalgOp. The control
+/// function chooses the split ratio and the insert position; the filter
+/// selects candidate ops and marks the ops produced by the rewrite.
+struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
+  LinalgSplitReduction(MLIRContext *context,
+                       ControlSplitReductionFn controlSplitReductionFn,
+                       LinalgTransformationFilter f, PatternBenefit benefit = 1)
+      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+        controlSplitReductionFn(std::move(controlSplitReductionFn)),
+        filter(std::move(f)) {}
+
+  LogicalResult matchAndRewrite(LinalgOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(splitReduction(rewriter, op, controlSplitReductionFn, filter)))
+      return failure();
+    return success();
+  }
+
+private:
+  ControlSplitReductionFn controlSplitReductionFn;
+  LinalgTransformationFilter filter;
+};
+
+} // namespace
+
+/// Register a single LinalgSplitReduction pattern in `patterns`.
+void linalg::populateSplitReductionPattern(
+    RewritePatternSet &patterns,
+    ControlSplitReductionFn controlSplitReductionFn,
+    LinalgTransformationFilter f) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<LinalgSplitReduction>(
+      context, std::move(controlSplitReductionFn), std::move(f));
+}

diff  --git a/mlir/test/Dialect/Linalg/split_reduction.mlir b/mlir/test/Dialect/Linalg/split_reduction.mlir
new file mode 100644
index 0000000000000..c95510d43d12f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/split_reduction.mlir
@@ -0,0 +1,119 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-split-reduction  -split-input-file  | FileCheck %s
+
+// Split the K = 256 reduction of a matmul by a ratio of 4: the partial
+// results live in a 16x32x4 intermediate tensor whose last dimension is then
+// reduced by a second linalg.generic.
+func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
+                    outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>
+//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+//  CHECK-LABEL: @matmul_split
+//  CHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32
+//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<16x256xf32> into tensor<16x4x64xf32>
+//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<256x32xf32> into tensor<4x64x32xf32>
+//  CHECK-DAG: %[[INI:.*]] = linalg.init_tensor [16, 32, 4] : tensor<16x32x4xf32>
+//      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
+//      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME:   , iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+// CHECK-SAME:   ins(%[[I1]], %[[I2]] : tensor<16x4x64xf32>, tensor<4x64x32xf32>) outs(%[[F]] : tensor<16x32x4xf32>) {
+//      CHECK:   arith.mulf
+//      CHECK:   arith.addf
+//      CHECK:   linalg.yield
+//      CHECK: } -> tensor<16x32x4xf32>
+//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]],
+// CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[G]] : tensor<16x32x4xf32>) outs(%{{.*}} : tensor<16x32xf32>) {
+//      CHECK:   arith.addf
+//      CHECK:   linalg.yield %{{.*}} : f32
+//      CHECK: } -> tensor<16x32xf32>
+//      CHECK: return %[[R]] : tensor<16x32xf32>
+
+// -----
+
+// Split a 1-D reduction whose combiner is arith.mulf, so the intermediate
+// tensor is filled with the multiplicative identity 1.0.
+func @generic_split_1d(%arg0: tensor<32xf32>, %arg1: tensor<f32>, %out: tensor<f32>) -> tensor<f32> {
+  %red = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+                                          affine_map<(d0) -> ()>,
+                                          affine_map<(d0) -> ()>],
+   iterator_types = ["reduction"]}
+   ins(%arg0, %arg1 : tensor<32xf32>, tensor<f32>)
+   outs(%out : tensor<f32>) {
+    ^bb0(%arg7: f32, %arg8: f32, %arg9: f32):
+      %40 = arith.subf %arg7, %arg8 : f32
+      %41 = math.exp %40 : f32
+      %42 = arith.mulf %41, %arg9 : f32
+      linalg.yield %42 : f32
+    } -> tensor<f32>
+  return %red : tensor<f32>
+}
+
+//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()>
+//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
+//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0)>
+//  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
+//CHECK-LABEL: @generic_split_1d
+//      CHECK: %[[ID:.*]] = arith.constant 1.000000e+00 : f32
+//      CHECK: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] : tensor<32xf32> into tensor<4x8xf32>
+//      CHECK: %[[INI:.*]] = linalg.init_tensor [4] : tensor<4xf32>
+//      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
+//      CHECK: %[[G:.*]] = linalg.generic
+//      CHECK:   {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
+//      CHECK:   iterator_types = ["parallel", "reduction"]} ins(%[[I1]], %{{.*}} : tensor<4x8xf32>, tensor<f32>) outs(%[[F]] : tensor<4xf32>) {
+//      CHECK:   arith.subf
+//      CHECK:   math.exp
+//      CHECK:   arith.mulf
+//      CHECK:   linalg.yield
+//      CHECK: } -> tensor<4xf32>
+//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["reduction"]} ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
+//      CHECK:   arith.mulf
+//      CHECK:   linalg.yield
+//      CHECK: } -> tensor<f32>
+//      CHECK: return %[[R]] : tensor<f32>
+
+// -----
+
+// Split the reduction over d1 of a three-loop generic whose combiner is
+// arith.maxf, so the intermediate tensor is filled with the most negative
+// finite f32 value.
+func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>, %output: tensor<5x2xf32>)
+  -> tensor<5x2xf32>
+{
+  %0 = linalg.generic {
+      indexing_maps = [
+        affine_map<(d0, d1, d2) -> (d1, d0)>,
+        affine_map<(d0, d1, d2) -> (d2, d1)>,
+        affine_map<(d0, d1, d2) -> (d2, d0)>
+      ],
+      iterator_types = ["parallel", "reduction", "parallel"]
+    } ins(%input, %input_2 : tensor<32x2xf32>, tensor<5x32xf32>) outs(%output : tensor<5x2xf32>) {
+    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+      %3 = arith.addf %arg0, %arg1 : f32
+      %4 = arith.maxf %3, %arg2 : f32
+      linalg.yield %4 : f32
+    } -> tensor<5x2xf32>
+  return %0 : tensor<5x2xf32>
+}
+
+//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)>
+//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
+//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
+//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL:  func @generic_split_3d
+//      CHECK: %[[ID:.*]] = arith.constant -3.40282347E+38 : f32
+//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<4x8x2xf32>
+//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32>
+//      CHECK: %[[INI:.*]] = linalg.init_tensor [5, 2, 4] : tensor<5x2x4xf32>
+//      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
+//      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
+// CHECK-SAME:   ins(%[[I1]], %[[I2]] : tensor<4x8x2xf32>, tensor<5x4x8xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
+//      CHECK:   arith.addf
+//      CHECK:   arith.maxf
+//      CHECK:   linalg.yield
+//      CHECK: } -> tensor<5x2x4xf32>
+//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME:   ins(%[[G]] : tensor<5x2x4xf32>) outs(%{{.*}} : tensor<5x2xf32>) {
+//      CHECK:   arith.maxf
+//      CHECK:   linalg.yield
+//      CHECK:  } -> tensor<5x2xf32>
+//      CHECK: return %[[R]] : tensor<5x2xf32>

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 42068ab79ff9c..e6786b24f5939 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -111,6 +111,10 @@ struct TestLinalgTransforms
       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
                      "pad_tensor(subtensor)"),
       llvm::cl::init(false)};
+  /// When set, run only the split-reduction patterns (see applySplitReduction).
+  Option<bool> testSplitReduction{
+      *this, "test-split-reduction",
+      llvm::cl::desc("Test split reduction transformation"),
+      llvm::cl::init(false)};
   ListOption<int64_t> peeledLoops{
       *this, "peeled-loops",
       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
@@ -617,6 +621,20 @@ static void applyTilePattern(FuncOp funcOp, const std::string &loopType,
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
 }
 
+/// Greedily apply the split-reduction pattern to `funcOp` with a fixed ratio
+/// of 4, inserting the new dimension at position getNumLoops() - 1 of each
+/// op. Rewritten ops are tagged with the "SPLIT" marker (presumably so the
+/// pattern does not fire again on its own results).
+static void applySplitReduction(FuncOp funcOp) {
+  MLIRContext *ctx = funcOp.getContext();
+  auto controlFn = [](LinalgOp op) {
+    unsigned lastLoop = op.getNumLoops() - 1;
+    return std::make_pair(4, lastLoop);
+  };
+  LinalgTransformationFilter splitFilter(ArrayRef<StringAttr>{},
+                                         StringAttr::get(ctx, "SPLIT"));
+  RewritePatternSet patterns(ctx);
+  linalg::populateSplitReductionPattern(patterns, controlFn, splitFilter);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   auto lambda = [&](void *) {
@@ -666,6 +684,8 @@ void TestLinalgTransforms::runOnOperation() {
   if (testTileScalarizeDynamicDims)
     return applyTilePattern(getOperation(), loopType, tileSizes,
                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
+  if (testSplitReduction)
+    return applySplitReduction(getOperation());
 }
 
 namespace mlir {


        


More information about the Mlir-commits mailing list