[Mlir-commits] [mlir] [mlir][linalg] Block pack matmul pass (PR #89782)
Adam Siemieniuk
llvmlistbot at llvm.org
Tue May 7 07:14:01 PDT 2024
================
@@ -0,0 +1,177 @@
+//===- PackMatmul.cpp - Linalg matmul packing -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGPACKMATMUL
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Returns the constant span `size - offset` of `range`, provided that the
+/// range's stride is the constant 1 and both its offset and size fold to
+/// compile-time constants. Returns std::nullopt in every other case.
+/// NOTE(review): for iteration domains the offset is presumably 0, making the
+/// result equal to `size` - confirm against the producing op.
+static std::optional<int64_t> getConstantRange(const Range &range) {
+  // Only constant unit-stride ranges are handled; a missing (non-constant)
+  // stride also compares unequal to 1 and is rejected here.
+  std::optional<int64_t> constStride = getConstantIntValue(range.stride);
+  if (constStride != 1)
+    return std::nullopt;
+
+  std::optional<int64_t> constOffset = getConstantIntValue(range.offset);
+  if (!constOffset.has_value())
+    return std::nullopt;
+
+  std::optional<int64_t> constSize = getConstantIntValue(range.size);
+  if (!constSize.has_value())
+    return std::nullopt;
+
+  return *constSize - *constOffset;
+}
+
+static bool validateFullTilesOnDims(TilingInterface tileOp,
+ ArrayRef<OpFoldResult> tiles,
+ ArrayRef<size_t> dims) {
+ if (dims.size() != tiles.size() || tiles.empty())
+ return false;
+
+ OpBuilder builder(tileOp);
+ OpBuilder::InsertionGuard guard(builder);
+ SmallVector<Range> iterationDomain =
+ cast<TilingInterface>(tileOp.getOperation()).getIterationDomain(builder);
+
+ for (auto dim : llvm::enumerate(dims)) {
+ if (dim.value() >= iterationDomain.size())
+ return false;
+
+ auto tileSize = getConstantIntValue(tiles[dim.index()]);
+ auto rangeOnDim = getConstantRange(iterationDomain[dim.value()]);
----------------
adam-smnk wrote:
Done
https://github.com/llvm/llvm-project/pull/89782
More information about the Mlir-commits
mailing list