[Mlir-commits] [mlir] [mlir][linalg] Pack matmul pass (PR #89782)

Adam Siemieniuk llvmlistbot at llvm.org
Tue Apr 23 10:07:19 PDT 2024


================
@@ -0,0 +1,177 @@
+//===- PackMatmul.cpp - Linalg matmul packing -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGPACKMATMUL
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Returns the static extent of `range`, computed as `size - offset`, when the
+/// range has a constant unit stride and both its offset and size are
+/// compile-time constants. Returns std::nullopt in every other case.
+static std::optional<int64_t> getConstantRange(const Range &range) {
+  std::optional<int64_t> constStride = getConstantIntValue(range.stride);
+  std::optional<int64_t> constOffset = getConstantIntValue(range.offset);
+  std::optional<int64_t> constSize = getConstantIntValue(range.size);
+
+  // An absent stride also compares unequal to 1, so this single comparison
+  // covers both "non-constant stride" and "non-unit stride".
+  if (constStride != 1 || !constOffset || !constSize)
+    return std::nullopt;
+
+  return *constSize - *constOffset;
+}
+
+/// Checks that tiling `tileOp` with `tiles` produces only full tiles.
+///
+/// For each pair (tiles[i], dims[i]), the iteration domain of `tileOp` along
+/// loop dimension dims[i] must have a static extent that is an exact multiple
+/// of the constant, strictly positive tile size tiles[i].
+///
+/// Returns false when `tiles` is empty, when `tiles` and `dims` differ in
+/// length, when a dimension index is outside the iteration domain, when a tile
+/// size or a domain extent is not a compile-time constant, or when a tile size
+/// is zero or negative.
+static bool validateFullTilesOnDims(TilingInterface tileOp,
+                                    ArrayRef<OpFoldResult> tiles,
+                                    ArrayRef<size_t> dims) {
+  if (dims.size() != tiles.size() || tiles.empty())
+    return false;
+
+  // The builder is local to this function, so no insertion-point restoration
+  // is needed.
+  // NOTE(review): getIterationDomain may materialize IR for dynamic dims;
+  // confirm callers only reach this with statically shaped ops.
+  OpBuilder builder(tileOp);
+  SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);
+
+  for (auto [tile, dim] : llvm::zip_equal(tiles, dims)) {
+    if (dim >= iterationDomain.size())
+      return false;
+
+    std::optional<int64_t> tileSize = getConstantIntValue(tile);
+    std::optional<int64_t> rangeOnDim = getConstantRange(iterationDomain[dim]);
+
+    // If the tile factor or the range are non-constant, the tile size is
+    // considered to be invalid.
+    if (!tileSize || !rangeOnDim)
+      return false;
+
+    // A non-positive tile size can never form full tiles. Rejecting it here
+    // also prevents the integer division by zero in the check below.
+    if (*tileSize <= 0)
+      return false;
+
+    // The dimension must be fully divisible by the tile.
+    if (*rangeOnDim % *tileSize != 0)
+      return false;
+  }
+
+  return true;
+}
+
+static FailureOr<linalg::LinalgOp>
+packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
+             ArrayRef<OpFoldResult> mnkTiles) {
+  if (!(isa<linalg::MatmulOp>(matmulOp) ||
+        isa<linalg::BatchMatmulOp>(matmulOp))) {
+    return rewriter.notifyMatchFailure(matmulOp, "not a matmul-like operation");
+  }
+
+  if (mnkTiles.size() != 3)
+    return rewriter.notifyMatchFailure(matmulOp, "require 3 tile factors");
+
+  if (matmulOp.hasDynamicShape())
----------------
adam-smnk wrote:

I haven't used or tested this logic on dynamic shapes, so this first iteration covers only the well-tested happy paths from our use cases.
I'm happy to revisit these restrictions and relax them in follow-up work.

https://github.com/llvm/llvm-project/pull/89782


More information about the Mlir-commits mailing list