[Mlir-commits] [mlir] [MLIR][Linalg] Transform pass to optimize and lower vector contract operation to FMA (PR #121748)

Rolf Morel llvmlistbot at llvm.org
Mon Jan 6 09:19:47 PST 2025


================
@@ -0,0 +1,267 @@
+//===- HoistVectorTransfers.cpp ---------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+// Collects the vector.transfer_read ops that define the first three operands
+// of `contractOp`, in operand order (0, 1, 2). Returns failure as soon as one
+// of those operands is not produced directly by a vector.transfer_read.
+static FailureOr<SmallVector<vector::TransferReadOp>>
+getContractOperands(vector::ContractionOp contractOp) {
+  SmallVector<vector::TransferReadOp> readOps;
+  for (unsigned operandIdx = 0; operandIdx != 3; ++operandIdx) {
+    Value operand = contractOp.getOperand(operandIdx);
+    if (auto definingRead = operand.getDefiningOp<vector::TransferReadOp>()) {
+      readOps.push_back(definingRead);
+      continue;
+    }
+    // Operand comes from some other op (or is a block argument): no match.
+    return failure();
+  }
+  return readOps;
+}
+
+// Collects, for every transfer read in `readOps`, the memref.subview op that
+// defines its first operand. The result keeps the order of `readOps`. Returns
+// failure if any read's first operand is not produced by a memref.subview.
+static FailureOr<SmallVector<memref::SubViewOp>>
+getReadOperands(SmallVector<vector::TransferReadOp> readOps) {
+  SmallVector<memref::SubViewOp> subViews;
+  for (vector::TransferReadOp transferRead : readOps) {
+    Value readSource = transferRead.getOperand(0);
+    auto sourceSubView = readSource.getDefiningOp<memref::SubViewOp>();
+    if (!sourceSubView)
+      return failure();
+    subViews.push_back(sourceSubView);
+  }
+  return subViews;
+}
+
+// Collects the four nearest scf.for ops enclosing `contractOp`, walking
+// outwards, so the result is ordered innermost loop first. Returns failure
+// if fewer than four enclosing scf.for ops exist.
+// NOTE(review): callers appear to expect the order k, reduction, n, m
+// (innermost to outermost) - confirm against the tiling that produced the IR.
+static FailureOr<SmallVector<scf::ForOp>>
+getNestedLoop(vector::ContractionOp contractOp) {
+  constexpr unsigned kLoopDepth = 4;
+  SmallVector<scf::ForOp> loops;
+  Operation *cursor = contractOp;
+  while (loops.size() < kLoopDepth) {
+    auto enclosingFor = cursor->getParentOfType<scf::ForOp>();
+    if (!enclosingFor)
+      return failure();
+    loops.push_back(enclosingFor);
+    // Continue the search from the loop that was just found.
+    cursor = enclosingFor.getOperation();
+  }
+  return loops;
+}
+
+// Checks that the induction variables of the loop nest are the offsets of the
+// subviews in the expected positions.
+//
+// `loops` must hold at least four scf.for ops, read as [k, reduction, n, m].
+// `subviews` must hold at least three memref.subview ops, read as
+// [lhs, rhs, acc]. The required matches are:
+//   k iv         == lhs offset 2 and rhs offset 1
+//   reduction iv == lhs offset 0 and rhs offset 0
+//   n iv         == acc offset 1 and rhs offset 2
+//   m iv         == lhs offset 1 and acc offset 0
+//
+// Returns failure if any match does not hold, or if the inputs are too short
+// to be indexed (fewer loops/subviews than required, or a subview exposing
+// fewer offset values than the positions inspected here).
+// NOTE(review): getOffsets() is assumed to expose only the dynamic offset
+// operands, so a subview with some static offsets has shifted positions;
+// consider getMixedOffsets() if such subviews must be handled - confirm.
+static LogicalResult checkNestedLoop(SmallVector<scf::ForOp> loops,
+                                     SmallVector<memref::SubViewOp> subviews) {
+  // Guard the fixed-index accesses below against short inputs.
+  if (loops.size() < 4 || subviews.size() < 3)
+    return failure();
+
+  auto subviewOpLhsOffsets = subviews[0].getOffsets();
+  auto subviewOpRhsOffsets = subviews[1].getOffsets();
+  auto subviewOpAccOffsets = subviews[2].getOffsets();
+
+  // Lhs and rhs offsets are read at positions 0..2, acc at positions 0..1.
+  if (subviewOpLhsOffsets.size() < 3 || subviewOpRhsOffsets.size() < 3 ||
+      subviewOpAccOffsets.size() < 2)
+    return failure();
+
+  Value ivK = loops[0].getInductionVar();
+  if (ivK != subviewOpLhsOffsets[2] || ivK != subviewOpRhsOffsets[1])
+    return failure();
+
+  Value ivReduction = loops[1].getInductionVar();
+  if (ivReduction != subviewOpLhsOffsets[0] ||
+      ivReduction != subviewOpRhsOffsets[0])
+    return failure();
+
+  Value ivN = loops[2].getInductionVar();
+  if (ivN != subviewOpAccOffsets[1] || ivN != subviewOpRhsOffsets[2])
+    return failure();
+
+  Value ivM = loops[3].getInductionVar();
+  if (ivM != subviewOpLhsOffsets[1] || ivM != subviewOpAccOffsets[0])
+    return failure();
+
+  return success();
+}
+
+/// Hoist vector transfer read and write operations for the tiled batch reduce matmul operation
+/// outside the reduction and k-loop.
----------------
rolfmorel wrote:

Can you include the matching conditions in the docstring, e.g., loop nest of four scf.for + ...?

https://github.com/llvm/llvm-project/pull/121748


More information about the Mlir-commits mailing list