[Mlir-commits] [mlir] [mlir][x86vector] Lower vector.contract to sequence of FMAs (PR #163382)

Adam Siemieniuk llvmlistbot at llvm.org
Mon Oct 27 07:25:31 PDT 2025


================
@@ -0,0 +1,659 @@
+//===- NanoKernels.cpp - Lower matmul to Nanokernels ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements matmul rewrites as target-specific nanokernels for
+// FP32 (for select batch or batch-reduce matmul patterns) and BF16 (TODO)
+// types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+// Enum to represent the type of matmul operation
+enum class MatMulType { Batch, BatchReduce, Others };
+
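+// Collects the loop nest that surrounds the contraction, walking from the
+// innermost (K) loop outwards. An illustrative sketch of the expected
+// register-tiled structure for a batch-reduce matmul (bounds, steps, and
+// names are placeholders only):
+//   scf.for %m ...            // M tile
+//     scf.for %n ...          // N tile
+//       scf.for %r ...        // batch-reduce
+//         scf.for %k ...      // K tile
+//           vector.contract ...
+// For a batch matmul the %r loop is absent.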
+static FailureOr<SmallVector<scf::ForOp>>
+getTiledMatmulLoopNest(vector::ContractionOp contractOp,
+                       MatMulType matmulType) {
+  SmallVector<scf::ForOp> list;
+  Operation *current = contractOp;
+  unsigned int dimCount = matmulType == MatMulType::BatchReduce ? 4 : 3;
+
+  // Walk the register-tiled loop structure of the batch (or batch-reduce)
+  // matmul (M -> N -> (reduce) -> K), collecting loops from the innermost
+  // outwards.
+  for (unsigned int i = 0; i < dimCount; i++) {
+    Operation *parent = current->getParentOfType<scf::ForOp>();
+    if (!parent)
+      return failure();
+    list.push_back(dyn_cast<scf::ForOp>(parent));
+    current = parent;
+  }
+  return list;
+}
+
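+// Checks that the induction variables of the loop nest (ordered innermost to
+// outermost, as returned by getTiledMatmulLoopNest) line up with the subview
+// offsets of the LHS, RHS, and accumulator buffers. With %m, %n, %r, and %k
+// denoting the M, N, reduction, and K induction variables, the offsets are
+// expected to be:
+//   batch-reduce: LHS (%r, %m, %k), RHS (%r, %k, %n), ACC (%m, %n)
+//   batch:        LHS (%m, %k),     RHS (%k, %n),     ACC (%m, %n)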
+static LogicalResult checkMatmulLoopAndSubviewOffsetsMatching(
+    SmallVector<scf::ForOp> loops, SmallVector<memref::SubViewOp> subviews,
+    MatMulType matmulType) {
+  auto subviewOpLhsOffsets = subviews[0].getOffsets();
+  auto subviewOpRhsOffsets = subviews[1].getOffsets();
+  auto subviewOpAccOffsets = subviews[2].getOffsets();
+
+  if (matmulType == MatMulType::BatchReduce) {
+    Value ivK = loops[0].getInductionVar();
+    if (ivK != subviewOpLhsOffsets[2] || ivK != subviewOpRhsOffsets[1])
+      return failure();
+
+    Value ivReduction = loops[1].getInductionVar();
+    if (ivReduction != subviewOpLhsOffsets[0] ||
+        ivReduction != subviewOpRhsOffsets[0])
+      return failure();
+
+    Value ivN = loops[2].getInductionVar();
+    if (ivN != subviewOpAccOffsets[1] || ivN != subviewOpRhsOffsets[2])
+      return failure();
+
+    Value ivM = loops[3].getInductionVar();
+    if (ivM != subviewOpLhsOffsets[1] || ivM != subviewOpAccOffsets[0])
+      return failure();
+  }
+
+  if (matmulType == MatMulType::Batch) {
+    Value ivK = loops[0].getInductionVar();
+    if (ivK != subviewOpLhsOffsets[1] || ivK != subviewOpRhsOffsets[0])
+      return failure();
+
+    Value ivN = loops[1].getInductionVar();
+    if (ivN != subviewOpAccOffsets[1] || ivN != subviewOpRhsOffsets[1])
+      return failure();
+
+    Value ivM = loops[2].getInductionVar();
+    if (ivM != subviewOpLhsOffsets[0] || ivM != subviewOpAccOffsets[0])
+      return failure();
+  }
+
+  return success();
+}
+
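+// Loads the M x (N / vectorSize) accumulator tile as 1-D vectors of length
+// vectorSize before entering the GEMM loop. The traversal order mirrors the
+// FMA emission order in generateNanokernels (row by row when
+// M >= N / vectorSize, N-chunk by N-chunk otherwise) so that the loaded
+// vectors line up with the iteration arguments consumed later.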
+static SmallVector<Value>
+loadAccumulatorBeforeGEMM(Location loc, RewriterBase &rewriter,
+                          Type elementType, unsigned int M, unsigned int N,
+                          unsigned int vectorSize, Value subviewOpAcc) {
+
+  SmallVector<Value> accumulators;
+
+  // Initialize the local variables assuming the M tile is at least as large
+  // as the N tile (counted in vectorSize-wide chunks).
+  unsigned int outerBound = M;
+  unsigned int innerBound = N;
+
+  unsigned int outerStep = 1;
+  unsigned int innerStep = vectorSize;
+
+  bool isNTileLarge = (N / vectorSize) > M;
+  if (isNTileLarge) {
+    outerBound = N;
+    innerBound = M;
+
+    outerStep = vectorSize;
+    innerStep = 1;
+  }
+
+  for (unsigned int i = 0; i < outerBound; i = i + outerStep) {
+    for (unsigned int j = 0; j < innerBound; j = j + innerStep) {
+      Value indexOp_A = arith::ConstantIndexOp::create(rewriter, loc, i);
+      Value indexOp_B = arith::ConstantIndexOp::create(rewriter, loc, j);
+
+      if (isNTileLarge) {
+        indexOp_A = indexOp_B;
+        indexOp_B = arith::ConstantIndexOp::create(rewriter, loc, i);
+      }
+
+      auto valueCRow = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(vectorSize, elementType), subviewOpAcc,
+          ValueRange{indexOp_A, indexOp_B});
+      accumulators.push_back(valueCRow);
+    }
+  }
+
+  return accumulators;
+}
+
+// This function takes matrices A, B, and C (represented as vectors) and
+// generates an equivalent sequence of target-specific nanokernels, returning
+// the final accumulator as output.
+// Based on the M tile, N tile, and vector size, it generates optimized
+// nanokernels under the condition that the reduction and K dimensions of the
+// input matrices are equal to 1.
+//
+// Input: matrix A, matrix B, accumulator as M*(N/vector size) vectors, M tile
+// size, N tile size, vector size.
+//
+// Output:
+// case i: M >= (N/vector size). For example, M=3; N=32; vector size = 16.
+//  load_B0 = load B[0-15] into vector<16xf32>
+//  load_B1 = load B[16-31] into vector<16xf32>
+//  bcst_A0 = load A[0] and broadcast it into vector<16xf32>
+//  o/p_Acc[0] = vector.fma bcst_A0, load_B0, i/p_Acc[0]
+//  o/p_Acc[1] = vector.fma bcst_A0, load_B1, i/p_Acc[1]
+//  bcst_A1 = load A[1] and broadcast it into vector<16xf32>
+//  o/p_Acc[2] = vector.fma bcst_A1, load_B0, i/p_Acc[2]
+//  o/p_Acc[3] = vector.fma bcst_A1, load_B1, i/p_Acc[3]
+//  bcst_A2 = load A[2] and broadcast it into vector<16xf32>
+//  o/p_Acc[4] = vector.fma bcst_A2, load_B0, i/p_Acc[4]
+//  o/p_Acc[5] = vector.fma bcst_A2, load_B1, i/p_Acc[5]
+//
+// case ii: M < (N/vector size). For example, M=2; N=48; vector size = 16.
+//  bcst_A0 = load A[0] and broadcast it into vector<16xf32>
+//  bcst_A1 = load A[1] and broadcast it into vector<16xf32>
+//  load_B0 = load B[0-15] into vector<16xf32>
+//  o/p_Acc[0] = vector.fma bcst_A0, load_B0, i/p_Acc[0]
+//  o/p_Acc[1] = vector.fma bcst_A1, load_B0, i/p_Acc[1]
+//  load_B1 = load B[16-31] into vector<16xf32>
+//  o/p_Acc[2] = vector.fma bcst_A0, load_B1, i/p_Acc[2]
+//  o/p_Acc[3] = vector.fma bcst_A1, load_B1, i/p_Acc[3]
+//  load_B2 = load B[32-47] into vector<16xf32>
+//  o/p_Acc[4] = vector.fma bcst_A0, load_B2, i/p_Acc[4]
+//  o/p_Acc[5] = vector.fma bcst_A1, load_B2, i/p_Acc[5]
+//
+// return o/p_Acc;
+SmallVector<Value>
+generateNanokernels(RewriterBase &rewriter, Location loc, Type elementType,
+                    unsigned int vectorSize, unsigned int vnni, unsigned int M,
+                    unsigned int N, ValueRange acc, Value matA, Value matB,
+                    MatMulType matmulType) {
+
+  SmallVector<Value> accumulators;
+  SmallVector<Value> matLoad;
+  Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+
+  // Start with the assumption that the M tile is the smaller one (i.e.,
+  // N / vectorSize > M) and create the helper variables.
+  unsigned int outerBound = M;
+  unsigned int outerStep = 1;
+
+  unsigned int innerBound = N;
+  unsigned int innerStep = vectorSize;
+
+  Value outerMatrix = matA;
+  Value innerMatrix = matB;
+
+  unsigned int outerVectSize = vnni;
+  unsigned int innerVectSize = vectorSize;
+
+  unsigned int fmaBound = M;
+
+  // Update the helper variables if the N tile is the smaller (or equal) one.
+  bool isNTileLarge = (N / vectorSize) > M;
+  if (!isNTileLarge) {
+    outerBound = N;
+    innerBound = M;
+
+    outerStep = vectorSize;
+    innerStep = 1;
+
+    outerMatrix = matB;
+    innerMatrix = matA;
+
+    outerVectSize = vectorSize;
+    innerVectSize = vnni;
+
+    fmaBound = N / vectorSize;
+  }
+
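+  // Resulting configuration:
+  //   N / vectorSize > M: the outer loop broadcasts the M elements of A, the
+  //     inner loop streams B in vectorSize-wide chunks, and M FMAs are
+  //     emitted per chunk of B.
+  //   otherwise: the outer loop loads the N / vectorSize chunks of B, the
+  //     inner loop broadcasts A one element at a time, and N / vectorSize
+  //     FMAs are emitted per element of A.
+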
+  // Load all the elements of the A or B matrix (whichever drives the outer
+  // loop).
+  for (unsigned int i = 0; i < outerBound; i = i + outerStep) {
+    Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
+    Value valueRow;
+
+    if (isNTileLarge) {
+
+      // Assuming a batch-reduce matmul, initialize the reduction, M, and K
+      // dimension indices.
+      SmallVector<Value> index = {c0, indexOp_i, c0};
+
+      // Remove reduction dimension if it is a batch matmul
+      if (matmulType == MatMulType::Batch) {
+        index.erase(index.begin());
+      }
+
+      // A Matrix load + broadcast
+      Value row = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(outerVectSize, elementType),
+          outerMatrix, index);
+      valueRow = vector::BroadcastOp::create(
+          rewriter, loc, VectorType::get(vectorSize, rewriter.getF32Type()),
+          row);
+    } else {
+
+      // Assuming a batch-reduce matmul, initialize the reduction, K, and N
+      // dimension indices.
+      SmallVector<Value> index = {c0, c0, indexOp_i};
+
+      // Remove the reduction dimension if it is a batch matmul.
+      if (matmulType == MatMulType::Batch) {
+        index.erase(index.begin());
+      }
+
+      // B Matrix load.
+      valueRow = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(outerVectSize, elementType),
+          outerMatrix, index);
+    }
+
+    matLoad.push_back(valueRow);
+  }
+
+  // Load the elements of the A/B matrix one at a time and compute the FMAs.
+  for (unsigned int j = 0, k = 0; j < innerBound; j = j + innerStep) {
+    Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
+    Value valueRow;
+
+    if (!isNTileLarge) {
+      SmallVector<Value> index = {c0, indexOp_j, c0};
+      if (matmulType == MatMulType::Batch) {
+        index.erase(index.begin());
+      }
+
+      // A Matrix load + broadcast
+      Value row = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(innerVectSize, elementType),
+          innerMatrix, ValueRange(index));
+      valueRow = vector::BroadcastOp::create(
+          rewriter, loc, VectorType::get(vectorSize, rewriter.getF32Type()),
+          row);
+    } else {
+
+      SmallVector<Value> index = {c0, c0, indexOp_j};
+      if (matmulType == MatMulType::Batch) {
+        index.erase(index.begin());
+      }
+
+      // B Matrix load
+      valueRow = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(innerVectSize, elementType),
+          innerMatrix, index);
+    }
+
+    // FMAs
+    for (unsigned int i = 0; i < fmaBound; i = i + 1) {
+      auto fmaOdd =
+          vector::FMAOp::create(rewriter, loc, matLoad[i], valueRow, acc[k]);
+      k++;
+      accumulators.push_back(fmaOdd);
+    }
+  }
+
+  return accumulators;
+}
+
+// Re-creates the K-dimension loop with the accumulator as iter_args when
+// lowering a batch-reduce vector contraction to target-specific nanokernels.
+scf::ForOp createGEMMLoopsWithAccAsIterArgs(
+    RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
+    vector::TransferReadOp vectorReadOpLhs,
+    vector::TransferReadOp vectorReadOpRhs, Value ivNewReductionForOp,
+    Type elementType, unsigned int vectorSize, unsigned int vnni,
+    unsigned int M, unsigned int N, ValueRange iterArgsNewReductionForOp,
+    MatMulType matmulType) {
+  auto newKForOp = scf::ForOp::create(
+      rewriter, kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
+      kForOp.getStep(), iterArgsNewReductionForOp,
+      [&](OpBuilder &rewriterNewKForOp, Location locNewKForOp,
+          Value ivNewKForOp, ValueRange iterArgsNewKForOp) {
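+        // The LHS/RHS bases are assumed to be produced by memref.subview ops
+        // whose only dynamic operands are the offsets (operand 0 being the
+        // source memref). With the batch-reduce layout LHS (%r, %m, %k) and
+        // RHS (%r, %k, %n), operands 1 and 3 of the LHS and operands 1 and 2
+        // of the RHS are remapped to the new reduction and K induction
+        // variables.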
+        IRMapping mapping;
+        mapping.map(vectorReadOpLhs.getBase().getDefiningOp()->getOperand(1),
+                    ivNewReductionForOp);
+        mapping.map(vectorReadOpLhs.getBase().getDefiningOp()->getOperand(3),
+                    ivNewKForOp);
+        auto lhsClone = rewriterNewKForOp.clone(
+            *vectorReadOpLhs.getBase().getDefiningOp(), mapping);
+
+        IRMapping rhsMapping;
+        rhsMapping.map(vectorReadOpRhs.getBase().getDefiningOp()->getOperand(1),
+                       ivNewReductionForOp);
+        rhsMapping.map(vectorReadOpRhs.getBase().getDefiningOp()->getOperand(2),
+                       ivNewKForOp);
+        auto rhsClone = rewriterNewKForOp.clone(
+            *vectorReadOpRhs.getBase().getDefiningOp(), rhsMapping);
+
+        auto evenFMAs = generateNanokernels(
+            rewriter, kForOp.getLoc(), elementType, vectorSize, vnni, M, N,
+            iterArgsNewKForOp, lhsClone->getResult(0), rhsClone->getResult(0),
+            matmulType);
+
+        scf::YieldOp::create(rewriterNewKForOp, locNewKForOp, evenFMAs);
+      });
+
+  return newKForOp;
+}
+
+// Re-creates the K-dimension loop with the accumulator as iter_args when
+// lowering a batch vector contraction to target-specific nanokernels.
+scf::ForOp createGEMMLoopsWithAccAsIterArgs(
+    RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
+    vector::TransferReadOp vectorReadOpLhs,
+    vector::TransferReadOp vectorReadOpRhs, Type elementType,
+    unsigned int vectorSize, unsigned int vnni, unsigned int M, unsigned int N,
+    ValueRange iterArgsNewReductionForOp, MatMulType matmulType) {
+
+  auto newKForOp = scf::ForOp::create(
+      rewriter, kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
+      kForOp.getStep(), iterArgsNewReductionForOp,
+      [&](OpBuilder &rewriterNewKForOp, Location locNewKForOp,
+          Value ivNewKForOp, ValueRange iterArgsNewKForOp) {
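+        // As above, the bases are assumed to be memref.subview ops whose only
+        // dynamic operands are the offsets. With the batch layout LHS (%m, %k)
+        // and RHS (%k, %n), operand 2 of the LHS and operand 1 of the RHS are
+        // remapped to the new K induction variable.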
+        IRMapping mapping;
+        mapping.map(vectorReadOpLhs.getBase().getDefiningOp()->getOperand(2),
+                    ivNewKForOp);
+        auto lhsClone = rewriterNewKForOp.clone(
+            *vectorReadOpLhs.getBase().getDefiningOp(), mapping);
+
+        IRMapping rhsMapping;
+        rhsMapping.map(vectorReadOpRhs.getBase().getDefiningOp()->getOperand(1),
+                       ivNewKForOp);
+        auto rhsClone = rewriterNewKForOp.clone(
+            *vectorReadOpRhs.getBase().getDefiningOp(), rhsMapping);
+
+        auto evenFMAs =
+            generateNanokernels(rewriter, loc, elementType, vectorSize, vnni, M,
+                                N, iterArgsNewKForOp, lhsClone->getResult(0),
+                                rhsClone->getResult(0), matmulType);
+
+        scf::YieldOp::create(rewriterNewKForOp, locNewKForOp, evenFMAs);
+      });
+
+  return newKForOp;
+}
+
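+// Packs the individual FMA result vectors back into the flat accumulator
+// vector expected by the contraction. Each length-vectorSize result is
+// inserted at its row-major offset (row * N + chunk * vectorSize), following
+// the same order in which generateNanokernels produced the FMAs.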
+Value mergeAccumulatedVectorAsMatrix(RewriterBase &rewriter, Location loc,
+                                     VectorType vecType,
+                                     SmallVector<Value> FMAs, Value accVec,
+                                     unsigned int vectorSize, unsigned int M,
+                                     unsigned int N) {
+
+  auto strides = rewriter.getI64ArrayAttr({1});
+  bool isNTileLarge = (N / vectorSize) > M;
+  if (isNTileLarge) {
+    for (unsigned int j = 0, k = 0; j < (N / vectorSize); j++) {
+      for (unsigned int i = 0; i < M; i++) {
+        unsigned int off = (j * vectorSize) + (i * N);
+        auto offsets = rewriter.getI64ArrayAttr({off});
+        accVec = vector::InsertStridedSliceOp::create(
+            rewriter, loc, vecType, FMAs[k], accVec, offsets, strides);
+        k++;
+      }
+    }
+
+  } else {
+    for (unsigned int i = 0, k = 0; i < M * N; i = i + vectorSize) {
+      auto offsets = rewriter.getI64ArrayAttr({i});
+      accVec = vector::InsertStridedSliceOp::create(
+          rewriter, loc, vecType, FMAs[k], accVec, offsets, strides);
+      k++;
+    }
+  }
+  return accVec;
+}
+
+// Rewriter pattern for the vector.contract operation.
+// Input: vector.contract with tiled dimensions (batch or batch-reduce matmul)
+// Matching Pattern:
----------------
adam-smnk wrote:

Are prologue and epilogue ops supported in any way?

https://github.com/llvm/llvm-project/pull/163382

