[Mlir-commits] [mlir] 178f9bd - [mlir][Linalg] Uniformize SplitReduction transforms and add option to use Bufferization::AllocTensor

Nicolas Vasilache llvmlistbot at llvm.org
Thu Jun 30 03:32:28 PDT 2022


Author: Nicolas Vasilache
Date: 2022-06-30T03:32:23-07:00
New Revision: 178f9bd63c9e0a207acc3ac2461ba53b99576e69

URL: https://github.com/llvm/llvm-project/commit/178f9bd63c9e0a207acc3ac2461ba53b99576e69
DIFF: https://github.com/llvm/llvm-project/commit/178f9bd63c9e0a207acc3ac2461ba53b99576e69.diff

LOG: [mlir][Linalg] Uniformize SplitReduction transforms and add option to use Bufferization::AllocTensor

This revision merges the 2 split_reduction transforms and adds extra control by using attributes.

SplitReduction is known to require a concrete additional buffer to store temporary information.
Add an option to introduce a `bufferization.alloc_tensor` instead of `linalg.init_tensor`.
This behaves better with subset-based tiling and bufferization.

Differential Revision: https://reviews.llvm.org/D128722

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
    mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f3e42cefb2d45..461388dd61af3 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -164,8 +164,24 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
     reduction into a parallel and reduction dimension. 
     A new `linalg.generic` op is created to perform the rest of the reduction. 
     
-    Example:
-    
+    The transformation supports different configuration attributes:
+      - split_factor: the factor by which to split (i.e. the size of the 
+        remaining reduction after splitting).
+      - insert_split_dimension: the dimension in the temporary tensor into 
+        which the new parallel dimension is inserted.
+      - use_scaling_algorithm: whether to use a scaling based formulation that 
+        does not create an ExpandShapeOp (default: do not use scaling)
+      - use_alloc: whether to use an alloc op to allocate the temporary 
+        tensor (default: do not use alloc op)
+
+    This op returns 4 handles to:
+      - the init op (or tensor_alloc op if use_alloc = true), 
+      - the fill op used to initialize the neutral element, 
+      - the split op and 
+      - the result-combining op.
+
+    Example (default: use_scaling_algorithm = false, use_alloc = false):
+    ====================================================================
     ```
       %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                             affine_map<(d0) -> ()>],
@@ -178,7 +194,7 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
       } -> tensor<f32>
     ```
     
-    To:
+    is split into:
     
     ```
       %cst = arith.constant 0.000000e+00 : f32
@@ -203,34 +219,8 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
       } -> tensor<f32>
     ```
 
-    This op returns handles to the fill op used to initialize the neutral 
-    element, the split op and the result-combining op.
-  }];
-
-  let arguments = (ins PDL_Operation:$target,
-                   DefaultValuedAttr<I64Attr, "{}">:$split_factor,
-                   DefaultValuedAttr<I64Attr, "{}">:$insert_split_dimension);
-  let results = (outs PDL_Operation:$fill_op,
-                      PDL_Operation:$split_linalg_op,
-                      PDL_Operation:$combining_linalg_op);
-
-  let assemblyFormat = "$target attr-dict";
-
-  let extraClassDeclaration = [{
-    ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
-        ::mlir::linalg::LinalgOp target, TransformState &state);
-  }];
-}
-
-def SplitReductionByScalingOp : 
-  Op<Transform_Dialect, "structured.split_reduction_by_scaling",
-       [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-        TransformEachOpTrait, TransformOpInterface]> {
-  let description = [{
-    Indicates that the given `target` op should be transformed with the 
-    `splitReductionByScaling` transformation and split factor provided as 
-    attribute.
-
+    Example (use_scaling_algorithm = true, use_alloc = true):
+    =========================================================
     Instead of introducing an ExpandShapeOp, this scaling-based implementation 
     rewrites a reduction dimension `k` into `k * split_factor + kk`.
     The dimension `kk` is added as an extra parallel dimension to the 
@@ -287,12 +277,13 @@ def SplitReductionByScalingOp :
 
      return %4 : tensor<16x32xf32>
     ```
-
   }];
 
   let arguments = (ins PDL_Operation:$target,
                    DefaultValuedAttr<I64Attr, "{}">:$split_factor,
-                   DefaultValuedAttr<I64Attr, "{}">:$insert_split_dimension);
+                   DefaultValuedAttr<I64Attr, "{}">:$insert_split_dimension,
+                   UnitAttr:$use_scaling_algorithm,
+                   UnitAttr:$use_alloc);
   let results = (outs PDL_Operation:$fill_op,
                       PDL_Operation:$split_linalg_op,
                       PDL_Operation:$combining_linalg_op);

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 78f17c1620ba9..6b3230ade0033 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1474,7 +1474,8 @@ using ControlSplitReductionFn =
 void populateSplitReductionPattern(
     RewritePatternSet &patterns,
     const ControlSplitReductionFn &controlSplitReductionFn,
-    const LinalgTransformationFilter &f = LinalgTransformationFilter());
+    const LinalgTransformationFilter &f = LinalgTransformationFilter(),
+    bool useAlloc = false);
 
 /// Apply transformation to split the single linalg op reduction into a parallel
 /// and reduction dimension. Then create a new linalg.generic op doing the rest
@@ -1518,19 +1519,21 @@ void populateSplitReductionPattern(
 FailureOr<LinalgOp>
 splitReduction(PatternRewriter &b, LinalgOp op,
                const ControlSplitReductionFn &controlSplitReductionFn,
-               const LinalgTransformationFilter &f);
+               const LinalgTransformationFilter &f, bool useAlloc = false);
 
 /// Filterless version of the above.
 /// Returns both the new linalg ops as well as the fillOp needed to initialize
 /// the temporary expanded tensor with the proper neutral element.
 struct SplitReductionResult {
+  Operation *initOrAlloc;
   FillOp fillOp;
   LinalgOp splitLinalgOp;
   LinalgOp resultCombiningLinalgOp;
 };
 FailureOr<SplitReductionResult>
 splitReduction(PatternRewriter &b, LinalgOp op,
-               const ControlSplitReductionFn &controlSplitReductionFn);
+               const ControlSplitReductionFn &controlSplitReductionFn,
+               bool useAlloc = false);
 
 /// Scaling-based implementation of the split reduction transformation.
 /// Instead of introducing an ExpandShapeOp, this rewrites a reduction dimension
@@ -1580,7 +1583,8 @@ splitReduction(PatternRewriter &b, LinalgOp op,
 /// ```
 FailureOr<SplitReductionResult>
 splitReductionByScaling(PatternRewriter &b, LinalgOp op,
-                        const ControlSplitReductionFn &controlSplitReductionFn);
+                        const ControlSplitReductionFn &controlSplitReductionFn,
+                        bool useAlloc = false);
 
 } // namespace linalg
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index e495a3ddfd483..b644848c53172 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -413,29 +413,9 @@ transform::SplitReductionOp::applyToOne(LinalgOp target,
   SimpleRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
   FailureOr<SplitReductionResult> splitResult =
-      splitReduction(rewriter, target, splitFn);
-  if (failed(splitResult))
-    return getOperation()->emitError("failed to apply");
-  return SmallVector<Operation *>{splitResult->fillOp,
-                                  splitResult->splitLinalgOp,
-                                  splitResult->resultCombiningLinalgOp};
-}
-
-//===----------------------------------------------------------------------===//
-// SplitReductionByScalingOp
-//===----------------------------------------------------------------------===//
-
-FailureOr<SmallVector<Operation *>>
-transform::SplitReductionByScalingOp::applyToOne(LinalgOp target,
-                                                 TransformState &state) {
-  ControlSplitReductionFn splitFn = [&](LinalgOp) {
-    return std::pair<int64_t, unsigned>(getSplitFactor(),
-                                        getInsertSplitDimension());
-  };
-  SimpleRewriter rewriter(getContext());
-  rewriter.setInsertionPoint(target);
-  FailureOr<SplitReductionResult> splitResult =
-      splitReductionByScaling(rewriter, target, splitFn);
+      (getUseScalingAlgorithm())
+          ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
+          : splitReduction(rewriter, target, splitFn, getUseAlloc());
   if (failed(splitResult))
     return getOperation()->emitError("failed to apply");
   return SmallVector<Operation *>{splitResult->fillOp,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index 61989f1eb7b29..6eb263a0bf156 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -15,6 +15,7 @@
 
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -60,14 +61,14 @@ static Attribute getNeutralElement(Operation *op) {
 FailureOr<LinalgOp> mlir::linalg::splitReduction(
     PatternRewriter &b, LinalgOp op,
     const ControlSplitReductionFn &controlSplitReductionFn,
-    const LinalgTransformationFilter &filter) {
+    const LinalgTransformationFilter &filter, bool useAlloc) {
   if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
       op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
       !op.hasOnlyProjectedPermutations())
     return b.notifyMatchFailure(op, "precondition not met");
 
   FailureOr<SplitReductionResult> res =
-      splitReduction(b, op, controlSplitReductionFn);
+      splitReduction(b, op, controlSplitReductionFn, useAlloc);
   if (failed(res))
     return failure();
 
@@ -79,7 +80,7 @@ FailureOr<LinalgOp> mlir::linalg::splitReduction(
 
 FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
     PatternRewriter &b, LinalgOp op,
-    const ControlSplitReductionFn &controlSplitReductionFn) {
+    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
   OpBuilder::InsertionGuard guard(b);
   b.setInsertionPoint(op);
 
@@ -171,11 +172,20 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
     outputExpr.push_back(
         b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
   }
-  Value initTensor = b.create<linalg::InitTensorOp>(
-      loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
+  Value initOrAllocTensor;
+  if (useAlloc) {
+    initOrAllocTensor = b.create<bufferization::AllocTensorOp>(
+        loc,
+        RankedTensorType::get(newOutputShape,
+                              op.getRegionOutputArgs()[0].getType()),
+        ValueRange{});
+  } else {
+    initOrAllocTensor = b.create<linalg::InitTensorOp>(
+        loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
+  }
   Value constantOp = b.create<arith::ConstantOp>(loc, identity);
   Value identityTensor =
-      b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor)
+      b.create<linalg::FillOp>(op->getLoc(), constantOp, initOrAllocTensor)
           .getResult(0);
 
   newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
@@ -189,7 +199,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
   // Create the new op matching the original op with an extra parallel
   // dimension.
   GenericOp genericOp = b.create<GenericOp>(
-      loc, TypeRange({initTensor.getType()}), newInputs,
+      loc, TypeRange({initOrAllocTensor.getType()}), newInputs,
       ValueRange({identityTensor}), newMaps, newIteratorTypes);
   b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
                        genericOp.region().begin());
@@ -223,9 +233,9 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
       });
   b.replaceOp(op, reduction.getResults());
 
-  return SplitReductionResult{identityTensor.getDefiningOp<FillOp>(),
-                              cast<LinalgOp>(genericOp.getOperation()),
-                              reduction};
+  return SplitReductionResult{
+      initOrAllocTensor.getDefiningOp(), identityTensor.getDefiningOp<FillOp>(),
+      cast<LinalgOp>(genericOp.getOperation()), reduction};
 }
 
 /// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...)
@@ -260,7 +270,7 @@ static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
 /// Core rewrite implementation.
 FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
     PatternRewriter &b, LinalgOp op,
-    const ControlSplitReductionFn &controlSplitReductionFn) {
+    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
   OpBuilder::InsertionGuard guard(b);
   b.setInsertionPoint(op);
 
@@ -297,7 +307,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
     return b.notifyMatchFailure(op, "unknown reduction neutral");
 
   // TODO: relax this when multi-reduction support is available.
-  if (op.getNumOutputs() != (int)neutralElements.size())
+  if (op.getNumOutputs() != static_cast<int64_t>(neutralElements.size()))
     return b.notifyMatchFailure(op, "expect one reduction per output");
 
   // Rewrite part.
@@ -318,6 +328,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
   // TODO: generalize when multi-reduction support is available.
   SmallVector<Value> newOutputs;
   newOutputs.reserve(op.getNumOutputs());
+  SmallVector<Operation *> initOrAllocTensorOps;
   SmallVector<linalg::FillOp> fillOps;
   fillOps.reserve(op.getNumOutputs());
   for (auto it : llvm::zip(op.outputs(), neutralElements)) {
@@ -327,12 +338,19 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
         reductionDimSize / splitFactor, insertSplitDimension);
     SmallVector<Value> dims =
         tensor::createDynamicDimValues(b, loc, rankedTensor);
-    Value initTensor = b.create<linalg::InitTensorOp>(
-        loc, dims, newT.getShape(), t.getElementType());
+    Value initOrAllocTensor;
+    if (useAlloc) {
+      initOrAllocTensor =
+          b.create<bufferization::AllocTensorOp>(loc, newT, dims);
+    } else {
+      initOrAllocTensor = b.create<linalg::InitTensorOp>(
+          loc, dims, newT.getShape(), t.getElementType());
+    }
     Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
     fillOps.push_back(
-        b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor));
+        b.create<linalg::FillOp>(op->getLoc(), constantOp, initOrAllocTensor));
     newOutputs.push_back(fillOps.back().getResult(0));
+    initOrAllocTensorOps.push_back(initOrAllocTensor.getDefiningOp());
   }
 
   // Step 2. Reindex / expand indexing maps.
@@ -423,7 +441,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
   // TODO: extend when multi-reduction support is available.
   assert(fillOps.size() == results.size() && results.size() == 1);
   b.replaceOp(op, results.front()->getResults());
-  return SplitReductionResult{fillOps.front(),
+  return SplitReductionResult{initOrAllocTensorOps.front(), fillOps.front(),
                               cast<LinalgOp>(genericOp.getOperation()),
                               results.front()};
 }
@@ -434,18 +452,21 @@ struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
   /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
   LinalgSplitReduction(MLIRContext *context,
                        ControlSplitReductionFn controlSplitReductionFn,
-                       LinalgTransformationFilter f, PatternBenefit benefit = 1)
+                       LinalgTransformationFilter f, bool useAlloc = false,
+                       PatternBenefit benefit = 1)
       : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
         controlSplitReductionFn(std::move(controlSplitReductionFn)),
-        filter(std::move(f)) {}
+        useAlloc(useAlloc), filter(std::move(f)) {}
 
   LogicalResult matchAndRewrite(LinalgOp op,
                                 PatternRewriter &rewriter) const override {
-    return splitReduction(rewriter, op, controlSplitReductionFn, filter);
+    return splitReduction(rewriter, op, controlSplitReductionFn, filter,
+                          useAlloc);
   }
 
 private:
   ControlSplitReductionFn controlSplitReductionFn;
+  bool useAlloc;
   LinalgTransformationFilter filter;
 };
 
@@ -454,7 +475,7 @@ struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
 void linalg::populateSplitReductionPattern(
     RewritePatternSet &patterns,
     const ControlSplitReductionFn &controlSplitReductionFn,
-    const LinalgTransformationFilter &f) {
+    const LinalgTransformationFilter &f, bool useAlloc) {
   patterns.add<LinalgSplitReduction>(patterns.getContext(),
-                                     controlSplitReductionFn, f);
+                                     controlSplitReductionFn, f, useAlloc);
 }

diff  --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir
index 85ab597d61c18..572c746d583e5 100644
--- a/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir
@@ -3,6 +3,7 @@
 // CHECK-LABEL: func.func @matmul_split
 func.func @matmul_split(%A : tensor<?x256xf32>, %B: tensor<256x32xf32>, %C: tensor<?x32xf32>) -> tensor<?x32xf32> {
 
+  //      CHECK: bufferization.alloc_tensor({{.*}}) : tensor<?x32x64xf32>
   //      CHECK: linalg.generic 
   // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
   // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}} : tensor<?x256xf32>, tensor<256x32xf32>, tensor<64x4xi1>)
@@ -30,6 +31,7 @@ transform.with_pdl_patterns {
   transform.sequence %arg0 {
   ^bb1(%arg1: !pdl.operation):
     %0 = pdl_match @pdl_target in %arg1
-    %1:3 = transform.structured.split_reduction_by_scaling %0 { split_factor = 4, insert_split_dimension = 2}
+    %1:3 = transform.structured.split_reduction %0 
+      { split_factor = 4, insert_split_dimension = 2, use_scaling_algorithm, use_alloc}
   }
 }

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 4c8a9484eb4ee..2c44d2ff83ee2 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -41,6 +42,7 @@ struct TestLinalgTransforms
   void getDependentDialects(DialectRegistry &registry) const override {
     // clang-format off
     registry.insert<AffineDialect,
+                    bufferization::BufferizationDialect,
                     memref::MemRefDialect,
                     scf::SCFDialect,
                     linalg::LinalgDialect,


        


More information about the Mlir-commits mailing list