[Mlir-commits] [mlir] d571639 - [mlir][Linalg] SplitReduction implementation without tensor::ExpandShapeOp

Nicolas Vasilache llvmlistbot at llvm.org
Wed Jun 22 12:08:40 PDT 2022


Author: Nicolas Vasilache
Date: 2022-06-22T12:06:58-07:00
New Revision: d5716395792696f2b56a0d4debadd040ee385143

URL: https://github.com/llvm/llvm-project/commit/d5716395792696f2b56a0d4debadd040ee385143
DIFF: https://github.com/llvm/llvm-project/commit/d5716395792696f2b56a0d4debadd040ee385143.diff

LOG: [mlir][Linalg] SplitReduction implementation without tensor::ExpandShapeOp

This revision proposes a different implementation of the SplitReduction transformation that does
not rely on tensor::ExpandShapeOp.

Previously, a dimension `[k]` would be split into `[k][kk]` via an ExpandShapeOp.
Instead, this revision proposes to rewrite `[k]` into `[factor * k + kk]`.

There are different tradeoffs involved, but the proposed implementation is more general because
the affine rewrite is well-defined. In particular, it works naturally with `?` parallel dimensions and
non-trivial indexing maps.

A further rewrite of `[factor * k + kk]` into `[k][kk]` with an ExpandShapeOp is possible as a follow-up.

Differential Revision: https://reviews.llvm.org/D128266

Added: 
    mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/IR/AffineMap.h
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 2d8a4986e09d6..8f0dc16d35ab7 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -222,6 +222,89 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
   }];
 }
 
+def SplitReductionByScalingOp :
+  Op<Transform_Dialect, "structured.split_reduction_by_scaling",
+       [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+        TransformEachOpTrait, TransformOpInterface]> {
+  let description = [{
+    Indicates that the given `target` op should be transformed with the
+    `splitReductionByScaling` transformation and split factor provided as
+    attribute.
+
+    Instead of introducing an ExpandShapeOp, this scaling-based implementation
+    rewrites a reduction dimension `k` into `k * split_factor + kk`. The outer
+    `k` becomes an extra parallel dimension of the intermediate output tensor,
+    at position `insert_split_dimension`; `kk` remains reduced.
+
+    Consider a minimal example where `k` is reduced:
+        O(i, j) += I(i, j, k)
+    Assume i=3, j=5, k=128, split_factor=16 and insert_split_dimension=0.
+    The compute is rewritten as:
+      a. O_i(k, i, j) += I(i, j, 16 * k + kk), with k in [0, 8), kk in [0, 16)
+      b. O(i, j) += O_i(k, i, j)
+    The intermediate tensor O_i is of shape (128/16)x3x5 == 8x3x5.
+
+    Example:
+
+    ```
+     %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
+       outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+    ```
+
+    Is transformed to:
+
+    ```
+     #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2 * 4 + d3)>
+     #map1 = affine_map<(d0, d1, d2, d3) -> (d2 * 4 + d3, d1)>
+     #map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+     #map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+     #map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+     #map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
+     %0 = linalg.init_tensor [16, 32, 64] : tensor<16x32x64xf32>
+     %cst = arith.constant 0.000000e+00 : f32
+     %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x32x64xf32>) ->
+        tensor<16x32x64xf32>
+     %2 = linalg.init_tensor [64, 4] : tensor<64x4xi1>
+
+     %3 = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3],
+       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+       ins(%A, %B, %2 : tensor<16x256xf32>, tensor<256x32xf32>, tensor<64x4xi1>)
+       outs(%1 : tensor<16x32x64xf32>) {
+         ^bb0(%arg3: f32, %arg4: f32, %arg5: i1, %arg6: f32):
+           %5 = arith.mulf %arg3, %arg4 : f32
+           %6 = arith.addf %arg6, %5 : f32
+           linalg.yield %6 : f32
+     } -> tensor<16x32x64xf32>
+
+     %4 = linalg.generic {indexing_maps = [#map4, #map5],
+       iterator_types = ["parallel", "parallel", "reduction"]}
+       ins(%3 : tensor<16x32x64xf32>)
+       outs(%C : tensor<16x32xf32>) {
+         ^bb0(%arg3: f32, %arg4: f32):
+           %5 = arith.addf %arg3, %arg4 : f32
+           linalg.yield %5 : f32
+     } -> tensor<16x32xf32>
+
+     return %4 : tensor<16x32xf32>
+    ```
+
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                   DefaultValuedAttr<I64Attr, "{}">:$split_factor,
+                   DefaultValuedAttr<I64Attr, "{}">:$insert_split_dimension);
+  let results = (outs PDL_Operation:$fill_op,
+                      PDL_Operation:$split_linalg_op,
+                      PDL_Operation:$combining_linalg_op);
+
+  let assemblyFormat = "$target attr-dict";
+
+  let extraClassDeclaration = [{
+    ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
+        ::mlir::linalg::LinalgOp target, TransformState &state);
+  }];
+}
+
 def TileOp : Op<Transform_Dialect, "structured.tile",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
         FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 7e2d58939da1c..78f17c1620ba9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1532,6 +1532,56 @@ FailureOr<SplitReductionResult>
 splitReduction(PatternRewriter &b, LinalgOp op,
                const ControlSplitReductionFn &controlSplitReductionFn);
 
+/// Scaling-based implementation of the split reduction transformation.
+/// Instead of introducing an ExpandShapeOp, this rewrites a reduction dimension
+/// `k` into `k * splitFactor + kk`; only `kk` remains a reduction dimension.
+///
+/// Example:
+/// ```
+///  %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
+///    outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+/// ```
+///
+/// Is transformed to:
+///
+/// ```
+///  #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2 * 4 + d3)>
+///  #map1 = affine_map<(d0, d1, d2, d3) -> (d2 * 4 + d3, d1)>
+///  #map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+///  #map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+///  #map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+///  #map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
+///  %0 = linalg.init_tensor [16, 32, 64] : tensor<16x32x64xf32>
+///  %cst = arith.constant 0.000000e+00 : f32
+///  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x32x64xf32>) ->
+///     tensor<16x32x64xf32>
+///  %2 = linalg.init_tensor [64, 4] : tensor<64x4xi1>
+///
+///  %3 = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3],
+///    iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+///    ins(%A, %B, %2 : tensor<16x256xf32>, tensor<256x32xf32>, tensor<64x4xi1>)
+///    outs(%1 : tensor<16x32x64xf32>) {
+///      ^bb0(%arg3: f32, %arg4: f32, %arg5: i1, %arg6: f32):
+///        %5 = arith.mulf %arg3, %arg4 : f32
+///        %6 = arith.addf %arg6, %5 : f32
+///        linalg.yield %6 : f32
+///  } -> tensor<16x32x64xf32>
+///
+///  %4 = linalg.generic {indexing_maps = [#map4, #map5],
+///    iterator_types = ["parallel", "parallel", "reduction"]}
+///    ins(%3 : tensor<16x32x64xf32>)
+///    outs(%C : tensor<16x32xf32>) {
+///      ^bb0(%arg3: f32, %arg4: f32):
+///        %5 = arith.addf %arg3, %arg4 : f32
+///        linalg.yield %5 : f32
+///  } -> tensor<16x32xf32>
+///
+///  return %4 : tensor<16x32xf32>
+/// ```
+FailureOr<SplitReductionResult>
+splitReductionByScaling(PatternRewriter &b, LinalgOp op,
+                        const ControlSplitReductionFn &controlSplitReductionFn);
+
 } // namespace linalg
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 87ac693492113..de94f43708fad 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -240,6 +240,22 @@ class AffineMap {
                           getContext());
   }
 
+  /// Returns this map with result `pos` removed; dims and symbols are kept.
+  AffineMap dropResult(unsigned pos) const {
+    assert(pos < getNumResults() && "result position out of bounds");
+    auto exprs = llvm::to_vector<4>(getResults());
+    exprs.erase(exprs.begin() + pos);
+    return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
+  }
+
+  /// Returns this map (same dims/symbols) with `expr` inserted as result `pos`.
+  AffineMap insertResult(AffineExpr expr, unsigned pos) const {
+    assert(pos <= getNumResults() && "result position out of bounds");
+    auto exprs = llvm::to_vector<4>(getResults());
+    exprs.insert(exprs.begin() + pos, expr);
+    return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
+  }
+
   /// Folds the results of the application of an affine map on the provided
   /// operands to a constant if possible.
   LogicalResult constantFold(ArrayRef<Attribute> operandConstants,

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 60c61cdd56a76..4bdc10a25023f 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -249,6 +249,16 @@ class RankedTensorType::Builder {
     return *this;
   }
 
+  /// Inserts a dimension of extent `val` at index `pos` of the shape.
+  Builder &insertDim(int64_t val, unsigned pos) {
+    assert(pos <= shape.size() && "overflow");
+    if (storage.empty()) // Copy the current shape into `storage` first.
+      storage.assign(shape.begin(), shape.end());
+    storage.insert(storage.begin() + pos, val);
+    shape = ArrayRef<int64_t>(storage);
+    return *this;
+  }
+
   operator RankedTensorType() {
     return RankedTensorType::get(shape, elementType, encoding);
   }

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f8ce4701ab74d..e495a3ddfd483 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -421,6 +421,28 @@ transform::SplitReductionOp::applyToOne(LinalgOp target,
                                   splitResult->resultCombiningLinalgOp};
 }
 
+//===----------------------------------------------------------------------===//
+// SplitReductionByScalingOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<SmallVector<Operation *>>
+transform::SplitReductionByScalingOp::applyToOne(LinalgOp target,
+                                                 TransformState &state) {
+  int64_t splitFactor = getSplitFactor();
+  unsigned insertSplitDimension = getInsertSplitDimension();
+  ControlSplitReductionFn splitFn = [=](LinalgOp) {
+    return std::make_pair(splitFactor, insertSplitDimension);
+  };
+  SimpleRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(target);
+  auto splitResult = splitReductionByScaling(rewriter, target, splitFn);
+  if (failed(splitResult))
+    return getOperation()->emitError("failed to apply");
+  return SmallVector<Operation *>{splitResult->fillOp,
+                                  splitResult->splitLinalgOp,
+                                  splitResult->resultCombiningLinalgOp};
+}
+
 //===----------------------------------------------------------------------===//
 // TileOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index 226b35d4495ce..8834000edd69b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -19,13 +19,14 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/IR/PatternMatch.h"
 
 using namespace mlir;
 using namespace mlir::linalg;
 
 /// Return the identity numeric value associated to the give op.
-static Optional<Attribute> getIdentity(Operation *op) {
+static Attribute getNeutralElement(Operation *op) {
   // Builder only used as helper for attribute creation.
   OpBuilder b(op->getContext());
   Type resultType = op->getResult(0).getType();
@@ -41,7 +42,7 @@ static Optional<Attribute> getIdentity(Operation *op) {
     if (isa<arith::MinFOp>(op))
       return b.getFloatAttr(resultType,
                             llvm::APFloat::getLargest(semantic, true));
-    return llvm::None;
+    return Attribute();
   }
   if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
     return b.getIntegerAttr(resultType, 0);
@@ -53,7 +54,7 @@ static Optional<Attribute> getIdentity(Operation *op) {
     return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
   if (isa<arith::MulIOp>(op))
     return b.getIntegerAttr(resultType, 1);
-  return llvm::None;
+  return Attribute();
 }
 
 FailureOr<LinalgOp> mlir::linalg::splitReduction(
@@ -84,7 +85,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
 
   std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
   int64_t ratio = control.first;
-  unsigned insertDimIndex = control.second;
+  unsigned insertSplitDimension = control.second;
   if (ratio <= 1)
     return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
 
@@ -95,7 +96,8 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
   SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
   int64_t reductionDimSize = loopRanges[reductionDim];
   if (reductionDimSize == ShapedType::kDynamicSize ||
-      reductionDimSize % ratio != 0 || insertDimIndex >= loopRanges.size())
+      reductionDimSize % ratio != 0 ||
+      insertSplitDimension >= loopRanges.size())
     return b.notifyMatchFailure(
         op, "Reduction dimension not divisible by split ratio");
 
@@ -105,7 +107,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
     return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
 
   Operation *reductionOp = combinerOps[0];
-  Optional<Attribute> identity = getIdentity(reductionOp);
+  Attribute identity = getNeutralElement(reductionOp);
   if (!identity)
     return b.notifyMatchFailure(op, "Unknown identity value for the reduction");
 
@@ -125,13 +127,14 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
         newShape.push_back(ratio);
         newShape.push_back(op.getShape(operand)[idx] / ratio);
         reassociation.push_back({index++, index++});
-        exprs.push_back(b.getAffineDimExpr(insertDimIndex));
+        exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
         exprs.push_back(
-            b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
+            b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
         continue;
       }
       newShape.push_back(op.getShape(operand)[idx]);
-      exprs.push_back(b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
+      exprs.push_back(
+          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
       reassociation.push_back({index++});
     }
     newMaps.push_back(
@@ -157,20 +160,20 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
   SmallVector<AffineExpr> outputExpr;
   for (unsigned idx :
        llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
-    if (idx == insertDimIndex) {
+    if (idx == insertSplitDimension) {
       newOutputShape.push_back(ratio);
-      outputExpr.push_back(b.getAffineDimExpr(insertDimIndex));
+      outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
       continue;
     }
-    unsigned oldDim = idx < insertDimIndex ? idx : idx - 1;
+    unsigned oldDim = idx < insertSplitDimension ? idx : idx - 1;
     newOutputShape.push_back(oldShape[oldDim]);
     unsigned dim = oldOutputMap.getDimPosition(oldDim);
     outputExpr.push_back(
-        b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
+        b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
   }
   Value initTensor = b.create<linalg::InitTensorOp>(
       loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
-  Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
+  Value constantOp = b.create<arith::ConstantOp>(loc, identity);
   Value identityTensor =
       b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor)
           .getResult(0);
@@ -179,7 +182,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
                                    op.getContext()));
   SmallVector<StringRef> newIteratorTypes;
   for (auto &it : llvm::enumerate(op.iterator_types())) {
-    if (insertDimIndex == it.index())
+    if (insertSplitDimension == it.index())
       newIteratorTypes.push_back(getParallelIteratorTypeName());
     newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
   }
@@ -199,7 +202,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
   SmallVector<StringRef> reductionIteratorTypes;
   SmallVector<AffineExpr> exprs;
   for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
-    if (insertDimIndex == i) {
+    if (insertSplitDimension == i) {
       reductionIteratorTypes.push_back(getReductionIteratorTypeName());
     } else {
       exprs.push_back(b.getAffineDimExpr(i));
@@ -225,6 +228,206 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
                               reduction};
 }
 
+/// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...)
+///
+/// In the new loop nest, loop `reductionDimPos` is the outer `k` and loop
+/// `reductionDimPos + 1` is the inner `kk`; every later loop is shifted by
+/// one. Returns the indexing map of `opOperand` expressed in the new loops.
+/// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into
+/// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better
+/// done as a transform to enable better vectorization.
+static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
+                                   unsigned reductionDimPos,
+                                   int64_t reductionRatio) {
+  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
+  auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext());
+  AffineMap map = op.getTiedIndexingMap(&opOperand);
+  AffineMap idMap =
+      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
+  // Make room for `kk`: loops after `reductionDimPos` move up by one.
+  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
+  // Express the old reduction loop as `k * reductionRatio + kk`.
+  AffineMap composeMap = shiftedIdMap.replace(
+      reductionDim, reductionDim * reductionRatio + reductionDimP1,
+      shiftedIdMap.getNumDims(), /*numSymbols=*/0);
+  return map.compose(composeMap);
+}
+
+/// Returns the indexing map of output `opOperand` in the loop nest produced by
+/// `scaleReductionDim`, with the new outer loop `k` (loop `reductionDimPos`)
+/// added as an extra result at position `insertSplitDimension`. This position
+/// must agree with where the extra dimension is inserted in the intermediate
+/// tensor type and with the result dropped by the combining op.
+static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
+                                   unsigned reductionDimPos,
+                                   unsigned insertSplitDimension) {
+  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
+  AffineMap map = op.getTiedIndexingMap(&opOperand);
+  AffineMap idMap =
+      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
+  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
+  return map.compose(shiftedIdMap)
+      .insertResult(reductionDim, insertSplitDimension);
+}
+
+/// Core rewrite implementation: splits the first reduction dimension `k` of
+/// `op` into `k * splitFactor + kk`, computes partial reductions over `kk` into
+/// intermediate tensors that carry `k` as an extra parallel dimension at
+/// position `insertSplitDimension`, then reduces that dimension away with one
+/// combining GenericOp per output.
+FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
+    PatternRewriter &b, LinalgOp op,
+    const ControlSplitReductionFn &controlSplitReductionFn) {
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPoint(op);
+
+  // Matcher part, enforce preconditions.
+  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
+  int64_t splitFactor = control.first;
+  unsigned insertSplitDimension = control.second;
+  if (splitFactor <= 1)
+    return b.notifyMatchFailure(op, "split factor needs to be greater than 1");
+
+  // The outputs are cast to RankedTensorType and fed to tensor ops below.
+  if (!op.hasTensorSemantics())
+    return b.notifyMatchFailure(op, "expected tensor semantics");
+
+  SmallVector<unsigned> dims;
+  op.getReductionDims(dims);
+  if (dims.empty())
+    return b.notifyMatchFailure(op, "needs at least 1 reduction dimension");
+
+  unsigned reductionDimPos = dims[0];
+  SmallVector<int64_t> loopRanges = op.getStaticLoopRanges();
+  int64_t reductionDimSize = loopRanges[reductionDimPos];
+  if (reductionDimSize == ShapedType::kDynamicSize ||
+      reductionDimSize % splitFactor != 0)
+    return b.notifyMatchFailure(
+        op, "first reduction dimension not divisible by split factor");
+
+  // The extra dimension is inserted into every output, so `insertSplitDimension`
+  // must be a valid insertion position (0 .. rank, inclusive) for each of them.
+  for (OpOperand *output : op.getOutputOperands())
+    if (insertSplitDimension > op.getRank(output))
+      return b.notifyMatchFailure(op, "insert split dimension out of bounds");
+
+  SmallVector<Operation *> combinerOps;
+  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
+    return b.notifyMatchFailure(op, "cannot match a reduction pattern");
+
+  SmallVector<Attribute> neutralElements = llvm::to_vector<4>(
+      llvm::map_range(combinerOps, [&](Operation *reductionOp) {
+        return getNeutralElement(reductionOp);
+      }));
+  if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
+    return b.notifyMatchFailure(op, "unknown reduction neutral");
+
+  // TODO: relax this when multi-reduction support is available.
+  if (op.getNumOutputs() != neutralElements.size())
+    return b.notifyMatchFailure(op, "expect one reduction per output");
+
+  // Rewrite part.
+  // Step 1. Build the intermediate outputs filled with the proper
+  // neutralElements. Such outputs are of the same shape with an extra dimension
+  // inserted at `insertSplitDimension`.
+  //
+  // Consider a minimal example where `k` is reduced:
+  //     O(i, j) += I(i, j, k)
+  // Assume i=3, j=5, k=128, splitFactor=16 and insertSplitDimension=0.
+  // The compute is rewritten as:
+  //   a. O_i(k, i, j) += I(i, j, 16 * k + kk), with k in [0, 8), kk in [0, 16)
+  //   b. O(i, j) += O_i(k, i, j)
+  // The intermediate tensor O_i is of shape (128/16)x3x5 == 8x3x5.
+  Location loc = op->getLoc();
+  MLIRContext *context = op.getContext();
+  // For now assume outputs are 1-1 with reduction neutralElements.
+  // TODO: generalize when multi-reduction support is available.
+  SmallVector<Value> newOutputs;
+  newOutputs.reserve(op.getNumOutputs());
+  SmallVector<linalg::FillOp> fillOps;
+  fillOps.reserve(op.getNumOutputs());
+  for (auto it : llvm::zip(op.outputs(), neutralElements)) {
+    Value rankedTensor = std::get<0>(it);
+    auto t = rankedTensor.getType().cast<RankedTensorType>();
+    RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
+        reductionDimSize / splitFactor, insertSplitDimension);
+    SmallVector<Value> dynamicDims =
+        tensor::createDynamicDimValues(b, loc, rankedTensor);
+    Value initTensor = b.create<linalg::InitTensorOp>(
+        loc, dynamicDims, newT.getShape(), t.getElementType());
+    Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
+    fillOps.push_back(
+        b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor));
+    newOutputs.push_back(fillOps.back().getResult(0));
+  }
+
+  // Step 2. Reindex / expand indexing maps.
+  // Reindex existing input indexings: k -> k * splitFactor + kk.
+  SmallVector<AffineMap> newMaps;
+  newMaps.reserve(op.getNumInputsAndOutputs() + 1);
+  for (OpOperand *o : op.getInputOperands())
+    newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
+  // Provision a new indexing for the shape-only tensor.
+  auto nDims = op.getNumLoops() + 1;
+  auto redDim = getAffineDimExpr(reductionDimPos, context);
+  auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context);
+  newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
+  // Expand existing output indexings: the new outer loop `k` becomes an extra
+  // result at `insertSplitDimension`, matching the intermediate tensor shape.
+  // TODO: a subset of these may not reduce along reducePos and should be
+  // reindexed: k -> k * splitFactor + k', when multi-reduction support is
+  // available.
+  for (OpOperand *o : op.getOutputOperands())
+    newMaps.push_back(
+        insertParallelDim(op, *o, reductionDimPos, insertSplitDimension));
+
+  // Step 3. Handle operands.
+  // Compute the new input tensors.
+  auto newInputs = llvm::to_vector<4>(op.inputs());
+  // Add a single shape-only tensor to carry the dimensions without resorting to
+  // more complex inversions.
+  newInputs.push_back(b.create<linalg::InitTensorOp>(
+      loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
+      b.getIntegerType(1)));
+  // Output tensors are already good to go.
+
+  // Step 4. Create the new op matching the original op with an extra parallel
+  // dimension.
+  SmallVector<StringRef> iteratorTypes =
+      llvm::to_vector<4>(op.getIteratorTypes().getAsValueRange<StringAttr>());
+  iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
+                       getParallelIteratorTypeName());
+  GenericOp genericOp =
+      b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
+                          newOutputs, newMaps, iteratorTypes);
+  b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
+                       genericOp.region().begin());
+  // Block arguments mirror the operands (inputs, then outputs): the argument
+  // of the shape-only tensor sits right after the original inputs, i.e. at
+  // the index of the last entry of `newInputs` (not at `reductionDimPos`).
+  unsigned shapeOnlyArgPos = newInputs.size() - 1;
+  genericOp.region().front().insertArgument(shapeOnlyArgPos,
+                                            b.getIntegerType(1), loc);
+
+  // Step 5. Create new reduction ops that only reduce the newly added
+  // dimensions from the previous op.
+  // For now assume outputs are 1-1 with reduction ops.
+  // TODO: a subset of these may not reduce in the first place and do not
+  // require a new op, when multi-reduction support is available.
+  // TODO: all results can be handled in a single GenericOp, when
+  // multi-reduction support is available.
+  SmallVector<LinalgOp> results;
+  for (auto it :
+       llvm::zip(genericOp->getResults(), op.outputs(), combinerOps)) {
+    Value reindexedOutput = std::get<0>(it);
+    Value originalOutput = std::get<1>(it);
+    auto originalOutputType = originalOutput.getType().cast<RankedTensorType>();
+    Operation *combinerOp = std::get<2>(it);
+
+    AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
+    SmallVector<AffineMap> indexingMaps = {
+        map, map.dropResult(insertSplitDimension)};
+    SmallVector<StringRef> reductionIteratorTypes(
+        originalOutputType.getRank() + 1, getParallelIteratorTypeName());
+    reductionIteratorTypes[insertSplitDimension] =
+        getReductionIteratorTypeName();
+
+    // clang-format off
+    auto reductionOp = b.create<GenericOp>(
+        loc,
+        originalOutputType,
+        reindexedOutput,
+        originalOutput,
+        indexingMaps,
+        reductionIteratorTypes,
+        [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) {
+          Operation *clonedReductionOp = b.clone(*combinerOp);
+          clonedReductionOp->setOperand(0, bbArgs[0]);
+          clonedReductionOp->setOperand(1, bbArgs[1]);
+          b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
+        });
+    // clang-format on
+
+    results.push_back(reductionOp);
+  }
+
+  // TODO: extend when multi-reduction support is available.
+  assert(fillOps.size() == results.size() && results.size() == 1);
+  b.replaceOp(op, results.front()->getResults());
+  return SplitReductionResult{fillOps.front(),
+                              cast<LinalgOp>(genericOp.getOperation()),
+                              results.front()};
+}
+
 namespace {
 
 struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {

diff  --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir
new file mode 100644
index 0000000000000..85ab597d61c18
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter %s | FileCheck %s
+
+// CHECK-LABEL: func.func @matmul_split
+func.func @matmul_split(%A : tensor<?x256xf32>, %B: tensor<256x32xf32>, %C: tensor<?x32xf32>) -> tensor<?x32xf32> {
+  // Split op: k (256) is rewritten as 4 * k' + kk; k' (size 64) stays parallel.
+  //      CHECK: linalg.generic 
+  // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+  // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}} : tensor<?x256xf32>, tensor<256x32xf32>, tensor<64x4xi1>)
+  // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<?x32x64xf32>) {
+  // Combining op: reduces the extra 64-wide dimension back into %C's shape.
+  //      CHECK: linalg.generic 
+  // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+  // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<?x32x64xf32>)
+  // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<?x32xf32>) {
+  %0 = linalg.matmul ins(%A, %B: tensor<?x256xf32>, tensor<256x32xf32>)
+                    outs(%C: tensor<?x32xf32>) -> tensor<?x32xf32>
+  return %0: tensor<?x32xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) { // Matches linalg.matmul ops.
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+  // Split each match; %1 holds the fill op, the split op and the combining op.
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    %1:3 = transform.structured.split_reduction_by_scaling %0 { split_factor = 4, insert_split_dimension = 2}
+  }
+}


        


More information about the Mlir-commits mailing list