[Mlir-commits] [mlir] [mlir][linalg] Block pack matmul pass (PR #89782)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 7 19:41:34 PDT 2024


================
@@ -0,0 +1,323 @@
+//===- BlockPackMatmul.cpp - Linalg matmul block packing ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Return constant range span or nullopt, otherwise.
+static std::optional<int64_t> getConstantRange(const Range &range) {
+  std::optional<int64_t> stride = getConstantIntValue(range.stride);
+  if (!stride || *stride != 1)
+    return std::nullopt;
+  std::optional<int64_t> offset = getConstantIntValue(range.offset);
+  if (!offset)
+    return std::nullopt;
+  std::optional<int64_t> size = getConstantIntValue(range.size);
+  if (!size)
+    return std::nullopt;
+  return (*size - *offset);
+}
+
+/// Return true if all dimensions are fully divisible by the respective tiles.
+static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp,
+                                    ArrayRef<OpFoldResult> tiles,
+                                    ArrayRef<int64_t> dims) {
+  if (dims.size() != tiles.size() || tiles.empty())
+    return false;
+
+  FailureOr<ContractionDimensions> contractDims =
+      inferContractionDims(linalgOp);
+  if (failed(contractDims))
+    return false;
+  unsigned batchDimsOffset = contractDims->batch.size();
+
+  // Skip the batch dimension if present.
+  // Offset all dimensions accordingly.
+  SmallVector<int64_t, 3> offsetDims{dims};
+  for (size_t i = 0; i < offsetDims.size(); i++)
+    offsetDims[i] += batchDimsOffset;
+
+  auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
+  OpBuilder builder(tileOp);
+  OpBuilder::InsertionGuard guard(builder);
+  SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);
+
+  for (auto dim : llvm::enumerate(offsetDims)) {
+    if (dim.value() >= static_cast<int64_t>(iterationDomain.size()))
+      return false;
+
+    std::optional<int64_t> tileSize = getConstantIntValue(tiles[dim.index()]);
+    std::optional<int64_t> rangeOnDim =
+        getConstantRange(iterationDomain[dim.value()]);
+
+    // If the tile factor or the range are non-constant, the tile size is
+    // considered to be invalid.
+    if (!tileSize || !rangeOnDim)
+      return false;
+
+    // The dimension must be fully divisible by the tile.
+    if (*rangeOnDim % *tileSize != 0)
+      return false;
+  }
+
+  return true;
+}
+
+/// Return failure or packed matmul with one of its operands transposed.
+static FailureOr<PackTransposeResult>
+transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
+                      tensor::PackOp packOp, AffineMap operandMap,
+                      ArrayRef<unsigned> blocksStartDimPos,
+                      bool transposeOuterBlocks, bool transposeInnerBlocks,
+                      unsigned outerDimsOffset = 0) {
+  assert(operandMap.getNumDims() >= 4 &&
+         "expected at least 4D prepacked matmul");
+  assert(blocksStartDimPos.size() == 2 &&
+         "expected starting outer and inner block positions");
+
+  // Base dimension positions in 4D packed matmul.
+  unsigned outerBlockPos = 0;
+  unsigned innerBlockPos = 2;
+
+  // Transpose control options define the desired block and element layout.
+  // Block transposition (outer dimensions) or element transposition (inner
+  // dimensions) may not be necessary depending on the original matmul data
+  // layout.
+  bool isOuterTransposed =
+      operandMap.getDimPosition(outerBlockPos + outerDimsOffset) !=
+      blocksStartDimPos.end()[-2];
+  bool isInnerTransposed =
+      operandMap.getDimPosition(innerBlockPos + outerDimsOffset) !=
+      blocksStartDimPos.back();
+
+  // Transpose only the dimensions that need that to conform to the provided
+  // transpotion settings.
+  SmallVector<int64_t> innerPerm{0, 1};
+  if (isInnerTransposed != transposeInnerBlocks)
+    innerPerm = {1, 0};
+  SmallVector<int64_t> outerPerm{0, 1};
+  if (isOuterTransposed != transposeOuterBlocks)
+    outerPerm = {1, 0};
+
+  // Leave the outer dimensions, like batch, unchanged by offsetting all
+  // outer dimensions permutations.
+  SmallVector<int64_t> offsetPerms(outerDimsOffset, 0);
+  for (auto perm : outerPerm)
+    offsetPerms.push_back(perm + outerDimsOffset);
+  outerPerm = offsetPerms;
+
+  FailureOr<PackTransposeResult> packTransposedMatmul =
+      packTranspose(rewriter, packOp, linalgOp,
+                    /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);
+
+  return packTransposedMatmul;
+}
+
+/// Pack a matmul operation into blocked 4D layout.
+FailureOr<PackResult>
+linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
+                        const ControlBlockPackMatmulFn &controlPackMatmul) {
+  if (linalgOp.hasPureBufferSemantics())
+    return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics");
+
+  std::optional<BlockPackMatmulOptions> options = controlPackMatmul(linalgOp);
+  if (!options)
+    return rewriter.notifyMatchFailure(linalgOp, "invalid packing options");
+
+  if (options->blockFactors.size() != 3)
+    return rewriter.notifyMatchFailure(linalgOp, "require 3 tile factors");
+
+  SmallVector<OpFoldResult> mnkTiles =
+      getAsOpFoldResult(rewriter.getI64ArrayAttr(options->blockFactors));
+
+  // If padding is disabled, make sure that dimensions can be packed cleanly.
+  if (!options->allowPadding &&
+      !validateFullTilesOnDims(linalgOp, mnkTiles, options->mnkOrder)) {
+    return rewriter.notifyMatchFailure(linalgOp,
+                                       "expect packing full tiles only");
+  }
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  // The op is replaced, we need to set the insertion point after it.
+  rewriter.setInsertionPointAfter(linalgOp);
+
+  // Pack the matmul operation into blocked layout with two levels of
+  // subdivision:
+  //   - major 2D blocks - outer dimensions, consist of minor blocks
+  //   - minor 2D blocks - inner dimensions, consist of scalar elements
+  FailureOr<PackResult> packedMatmul = packMatmulGreedily(
+      rewriter, linalgOp, mnkTiles, options->mnkPaddedSizesNextMultipleOf,
+      options->mnkOrder);
+  if (failed(packedMatmul))
+    return failure();
+
+  assert(packedMatmul->packOps.size() == 3 &&
+         "invalid number of pack ops after matmul packing");
+  assert(packedMatmul->unPackOps.size() == 1 &&
+         "invalid number of unpack ops after matmul packing");
+
+  FailureOr<ContractionDimensions> contractDims =
+      inferContractionDims(packedMatmul->packedLinalgOp);
+  if (failed(contractDims))
+    return failure();
+  unsigned batchDimsOffset = contractDims->batch.size();
+
+  auto genericOp =
+      dyn_cast<linalg::GenericOp>(packedMatmul->packedLinalgOp.getOperation());
+  SmallVector<AffineMap> maps = genericOp.getIndexingMapsArray();
+
+  // Transpose LHS matrix according to the options.
+  FailureOr<PackTransposeResult> packedLhs = transposePackedMatmul(
+      rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[0], maps[0],
+      contractDims->m, options->lhsTransposeOuterBlocks,
+      options->lhsTransposeInnerBlocks, batchDimsOffset);
+  if (failed(packedLhs))
+    return failure();
+
+  // Update results.
+  packedMatmul->packOps[0] = packedLhs->transposedPackOp;
+  packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;
+
+  // Transpose RHS matrix according to the options.
+  FailureOr<PackTransposeResult> packedRhs = transposePackedMatmul(
+      rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[1], maps[1],
+      contractDims->k, options->rhsTransposeOuterBlocks,
+      options->rhsTransposeInnerBlocks, batchDimsOffset);
+  if (failed(packedRhs))
+    return failure();
+
+  // Update results.
+  packedMatmul->packOps[1] = packedRhs->transposedPackOp;
+  packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;
+
+  return packedMatmul;
+}
+
+namespace {
+template <typename OpTy>
+struct BlockPackMatmul : public OpRewritePattern<OpTy> {
+  BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
+                  PatternBenefit benefit = 1)
+      : OpRewritePattern<OpTy>(context, benefit), controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(OpTy linalgOp,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<PackResult> packedMatmul =
+        blockPackMatmul(rewriter, linalgOp, controlFn);
+    if (failed(packedMatmul))
+      return failure();
+    return success();
+  }
+
+private:
+  ControlBlockPackMatmulFn controlFn;
+};
+
+template <>
+struct BlockPackMatmul<linalg::GenericOp>
+    : public OpRewritePattern<linalg::GenericOp> {
+  BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
+                  PatternBenefit benefit = 1)
+      : OpRewritePattern<linalg::GenericOp>(context, benefit),
+        controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
+                                PatternRewriter &rewriter) const override {
+    // Match suitable generics.
+    if (failed(linalg::detail::verifyContractionInterface(
+            linalgOp.getOperation()))) {
+      return rewriter.notifyMatchFailure(linalgOp, "not a contraction");
+    }
+
+    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+    auto infer = [&](MapList m) {
+      return AffineMap::inferFromExprList(m, linalgOp.getContext());
+    };
+
+    AffineExpr i, j, k;
+    bindDims(linalgOp->getContext(), i, j, k);
+    SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
+
+    // For now, only match simple matmuls.
+    if (!(maps == infer({{i, k}, {k, j}, {i, j}}) ||
+          maps == infer({{k, i}, {k, j}, {i, j}}) ||
+          maps == infer({{i, k}, {j, k}, {i, j}}))) {
+      return rewriter.notifyMatchFailure(linalgOp, "not a suitable matmul");
+    }
+
+    FailureOr<PackResult> packedMatmul =
+        blockPackMatmul(rewriter, linalgOp, controlFn);
+    if (failed(packedMatmul))
+      return failure();
+    return success();
+  }
+
+private:
+  ControlBlockPackMatmulFn controlFn;
+};
+
+/// Convert linalg matmul ops to block layout and back.
+struct LinalgBlockPackMatmul
+    : public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> {
+  using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(&getContext());
+
+    ControlBlockPackMatmulFn controlFn =
+        [&](linalg::LinalgOp op) -> BlockPackMatmulOptions {
+      BlockPackMatmulOptions options;
+      options.blockFactors = SmallVector<int64_t>{*blockFactors};
+      options.allowPadding = allowPadding;
+      options.mnkPaddedSizesNextMultipleOf =
+          SmallVector<int64_t>{*mnkPaddedSizesNextMultipleOf};
+      if (!mnkOrder.empty())
+        options.mnkOrder = SmallVector<int64_t>{*mnkOrder};
+      options.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks;
+      options.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks;
+      options.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks;
+      options.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks;
+      return options;
+    };
+
+    linalg::populateBlockPackMatmulPatterns(patterns, controlFn);
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+void linalg::populateBlockPackMatmulPatterns(
+    RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
+  patterns.add<BlockPackMatmul<linalg::GenericOp>,
----------------
yifeizh2 wrote:

I wonder whether there is better solution than keeping a list of Matmul ops here. Otherwise if we expand linalg ops, we might need to also update the op list here~

https://github.com/llvm/llvm-project/pull/89782


More information about the Mlir-commits mailing list