[Mlir-commits] [mlir] [mlir][x86vector] Lower vector.contract to sequence of FMAs (PR #163382)
Adam Siemieniuk
llvmlistbot at llvm.org
Mon Oct 27 07:25:31 PDT 2025
================
@@ -0,0 +1,659 @@
+//===- NanoKernels.cpp - Lower matmul to Nanokernels ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements matmul rewrites as nanokernels tailored to the target
+// machine for FP32 (for selected batch or batch-reduce matmul patterns) and
+// BF16 (TODO) types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+// Enum to represent the type of matmul operation
+enum class MatMulType { Batch, BatchReduce, Others };
+
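+// For illustration (a simplified sketch, names assumed): for a batch-reduce
+// matmul the helper below expects a register-tiled nest of the form
+// (outermost to innermost) M -> N -> reduce -> K, roughly:
+//   scf.for %m ... {
+//     scf.for %n ... {
+//       scf.for %red ... {
+//         scf.for %k ... { ... vector.contract ... }
+//       }
+//     }
+//   }
+// It returns the loops innermost-first; for a batch matmul the reduce loop is
+// absent.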
+static FailureOr<SmallVector<scf::ForOp>>
+getTiledMatmulLoopNest(vector::ContractionOp contractOp,
+ MatMulType matmulType) {
+ SmallVector<scf::ForOp> list;
+ Operation *current = contractOp;
+ unsigned int dimCount = matmulType == MatMulType::BatchReduce ? 4 : 3;
+
+ // Walk up the register-tiled loop structure of the batch (or batch-reduce)
+ // matmul (M->N->(reduce)->K).
+ for (unsigned int i = 0; i < dimCount; i++) {
+ Operation *parent = current->getParentOfType<scf::ForOp>();
+ if (!parent)
+ return failure();
+ list.push_back(dyn_cast<scf::ForOp>(parent));
+ current = parent;
+ }
+ return list;
+}
+
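+// For illustration (a simplified sketch, names assumed): the check below
+// verifies that the subview offsets of LHS, RHS, and accumulator line up with
+// the induction variables of the surrounding loops (given innermost-first).
+// For a batch-reduce matmul the expected access pattern is roughly:
+//   %lhs = memref.subview %A[%red, %m, %k] ... // A indexed by [reduce, M, K]
+//   %rhs = memref.subview %B[%red, %k, %n] ... // B indexed by [reduce, K, N]
+//   %acc = memref.subview %C[%m, %n] ...       // C indexed by [M, N]
+// For a batch matmul the reduce offset is dropped from A and B.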
+static LogicalResult checkMatmulLoopAndSubviewOffsetsMatching(
+ SmallVector<scf::ForOp> loops, SmallVector<memref::SubViewOp> subviews,
+ MatMulType matmulType) {
+ auto subviewOpLhsOffsets = subviews[0].getOffsets();
+ auto subviewOpRhsOffsets = subviews[1].getOffsets();
+ auto subviewOpAccOffsets = subviews[2].getOffsets();
+
+ if (matmulType == MatMulType::BatchReduce) {
+ Value ivK = loops[0].getInductionVar();
+ if (ivK != subviewOpLhsOffsets[2] || ivK != subviewOpRhsOffsets[1])
+ return failure();
+
+ Value ivReduction = loops[1].getInductionVar();
+ if (ivReduction != subviewOpLhsOffsets[0] ||
+ ivReduction != subviewOpRhsOffsets[0])
+ return failure();
+
+ Value ivN = loops[2].getInductionVar();
+ if (ivN != subviewOpAccOffsets[1] || ivN != subviewOpRhsOffsets[2])
+ return failure();
+
+ Value ivM = loops[3].getInductionVar();
+ if (ivM != subviewOpLhsOffsets[1] || ivM != subviewOpAccOffsets[0])
+ return failure();
+ }
+
+ if (matmulType == MatMulType::Batch) {
+ Value ivK = loops[0].getInductionVar();
+ if (ivK != subviewOpLhsOffsets[1] || ivK != subviewOpRhsOffsets[0])
+ return failure();
+
+ Value ivN = loops[1].getInductionVar();
+ if (ivN != subviewOpAccOffsets[1] || ivN != subviewOpRhsOffsets[1])
+ return failure();
+
+ Value ivM = loops[2].getInductionVar();
+ if (ivM != subviewOpLhsOffsets[0] || ivM != subviewOpAccOffsets[0])
+ return failure();
+ }
+
+ return success();
+}
+
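+// Loads the accumulator tile from memory as M*(N/vectorSize) 1-D vectors
+// before entering the reduction/K loops. For illustration, assuming M=2,
+// N=32, vectorSize=16 (names assumed), the emitted loads are roughly:
+//   %acc0 = vector.load %accSubview[%c0, %c0]  : ..., vector<16xf32>
+//   %acc1 = vector.load %accSubview[%c0, %c16] : ..., vector<16xf32>
+//   %acc2 = vector.load %accSubview[%c1, %c0]  : ..., vector<16xf32>
+//   %acc3 = vector.load %accSubview[%c1, %c16] : ..., vector<16xf32>
+// The traversal order (M-major vs. N-major) matches the FMA emission order in
+// generateNanokernels so that the loaded vectors line up with the iter_args.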
+static SmallVector<Value>
+loadAccumulatorBeforeGEMM(Location loc, RewriterBase &rewriter,
+ Type elementType, unsigned int M, unsigned int N,
+ unsigned int vectorSize, Value subviewOpAcc) {
+
+ SmallVector<Value> accumulators;
+
+ // Initialize local variables assuming the M tile is larger than the N tile
+ // (i.e., M >= N / vectorSize).
+ unsigned int outerBound = M;
+ unsigned int innerBound = N;
+
+ unsigned int outerStep = 1;
+ unsigned int innerStep = vectorSize;
+
+ bool isNTileLarge = (N / vectorSize) > M;
+ if (isNTileLarge) {
+ outerBound = N;
+ innerBound = M;
+
+ outerStep = vectorSize;
+ innerStep = 1;
+ }
+
+ for (unsigned int i = 0; i < outerBound; i = i + outerStep) {
+ for (unsigned int j = 0; j < innerBound; j = j + innerStep) {
+ Value indexOp_A = arith::ConstantIndexOp::create(rewriter, loc, i);
+ Value indexOp_B = arith::ConstantIndexOp::create(rewriter, loc, j);
+
+ if (isNTileLarge) {
+ indexOp_A = indexOp_B;
+ indexOp_B = arith::ConstantIndexOp::create(rewriter, loc, i);
+ }
+
+ auto valueCRow = vector::LoadOp::create(
+ rewriter, loc, VectorType::get(vectorSize, elementType), subviewOpAcc,
+ ValueRange{indexOp_A, indexOp_B});
+ accumulators.push_back(valueCRow);
+ }
+ }
+
+ return accumulators;
+}
+
+// This function takes matrices A and B (as memref subviews) and the
+// accumulator C (represented as M*(N/vector size) vectors) and generates an
+// equivalent sequence of target-specific nanokernels.
+// It returns the final accumulator values as output.
+// Based on the M tile, N tile, and vector size, it generates optimized
+// nanokernels under the condition that the reduction and K dimensions
+// of the input matrices are equal to 1.
+//
+// Input: Matrix A, Matrix B, Accumulator as M*(N/vector size) vectors, M tile
+// size, N tile size, Vector size.
+//
+// Output:
+// case i: M >= (N/vector size). For example, M=3; N=32; vector size = 16.
+// load_B0 = load B[0-15] into vector<16xf32>
+// load_B1 = load B[16-31] into vector<16xf32>
+// bcst_A0 = load A[0] and broadcast it into vector<16xf32>
+// o/p_Acc[0] = vector.fma bcst_A0, load_B0, i/p_Acc[0]
+// o/p_Acc[1] = vector.fma bcst_A0, load_B1, i/p_Acc[1]
+// bcst_A1 = load A[1] and broadcast it into vector<16xf32>
+// o/p_Acc[2] = vector.fma bcst_A1, load_B0, i/p_Acc[2]
+// o/p_Acc[3] = vector.fma bcst_A1, load_B1, i/p_Acc[3]
+// bcst_A2 = load A[2] and broadcast it into vector<16xf32>
+// o/p_Acc[4] = vector.fma bcst_A2, load_B0, i/p_Acc[4]
+// o/p_Acc[5] = vector.fma bcst_A2, load_B1, i/p_Acc[5]
+//
+// case ii: M < (N/vector size). For example, M=2; N=48; vector size = 16.
+// bcst_A0 = load A[0] and broadcast it into vector<16xf32>
+// bcst_A1 = load A[1] and broadcast it into vector<16xf32>
+// load_B0 = load B[0-15] into vector<16xf32>
+// o/p_Acc[0] = vector.fma bcst_A0, load_B0, i/p_Acc[0]
+// o/p_Acc[1] = vector.fma bcst_A1, load_B0, i/p_Acc[1]
+// load_B1 = load B[16-31] into vector<16xf32>
+// o/p_Acc[2] = vector.fma bcst_A0, load_B1, i/p_Acc[2]
+// o/p_Acc[3] = vector.fma bcst_A1, load_B1, i/p_Acc[3]
+// load_B2 = load B[32-47] into vector<16xf32>
+// o/p_Acc[4] = vector.fma bcst_A0, load_B2, i/p_Acc[4]
+// o/p_Acc[5] = vector.fma bcst_A1, load_B2, i/p_Acc[5]
+//
+// return o/p_Acc;
+SmallVector<Value>
+generateNanokernels(RewriterBase &rewriter, Location loc, Type elementType,
+ unsigned int vectorSize, unsigned int vnni, unsigned int M,
+ unsigned int N, ValueRange acc, Value matA, Value matB,
+ MatMulType matmulType) {
+
+ SmallVector<Value> accumulators;
+ SmallVector<Value> matLoad;
+ Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+
+ // Start with the assumption that the M tile is the smaller one
+ // (M < N / vectorSize) and initialize the helper variables accordingly.
+ unsigned int outerBound = M;
+ unsigned int outerStep = 1;
+
+ unsigned int innerBound = N;
+ unsigned int innerStep = vectorSize;
+
+ Value outerMatrix = matA;
+ Value innerMatrix = matB;
+
+ unsigned int outerVectSize = vnni;
+ unsigned int innerVectSize = vectorSize;
+
+ unsigned int fmaBound = M;
+
+ // Update the helper variables if the N tile is the smaller one.
+ bool isNTileLarge = (N / vectorSize) > M;
+ if (!isNTileLarge) {
+ outerBound = N;
+ innerBound = M;
+
+ outerStep = vectorSize;
+ innerStep = 1;
+
+ outerMatrix = matB;
+ innerMatrix = matA;
+
+ outerVectSize = vectorSize;
+ innerVectSize = vnni;
+
+ fmaBound = N / vectorSize;
+ }
+
+ // Preload all elements of whichever matrix (A or B) needs fewer vector loads.
+ for (unsigned int i = 0; i < outerBound; i = i + outerStep) {
+ Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
+ Value valueRow;
+
+ if (isNTileLarge) {
+
+ // Assuming a batch-reduce matmul, initialize the reduction, M, and K
+ // dimension indices.
+ SmallVector<Value> index = {c0, indexOp_i, c0};
+
+ // Remove reduction dimension if it is a batch matmul
+ if (matmulType == MatMulType::Batch) {
+ index.erase(index.begin());
+ }
+
+ // A Matrix load + broadcast
+ Value row = vector::LoadOp::create(
+ rewriter, loc, VectorType::get(outerVectSize, elementType),
+ outerMatrix, index);
+ valueRow = vector::BroadcastOp::create(
+ rewriter, loc, VectorType::get(vectorSize, rewriter.getF32Type()),
+ row);
+ } else {
+
+ // Assuming a batch-reduce matmul, initialize the reduction, K, and N
+ // dimension indices.
+ SmallVector<Value> index = {c0, c0, indexOp_i};
+
+ // Remove reduction dimension if it is a batch matmul
+ if (matmulType == MatMulType::Batch) {
+ index.erase(index.begin());
+ }
+
+ // B Matrix load.
+ valueRow = vector::LoadOp::create(
+ rewriter, loc, VectorType::get(outerVectSize, elementType),
+ outerMatrix, index);
+ }
+
+ matLoad.push_back(valueRow);
+ }
+
+ // Load elements of the other matrix (A or B) one at a time and compute FMAs.
+ for (unsigned int j = 0, k = 0; j < innerBound; j = j + innerStep) {
+ Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
+ Value valueRow;
+
+ if (!isNTileLarge) {
+ SmallVector<Value> index = {c0, indexOp_j, c0};
+ if (matmulType == MatMulType::Batch) {
+ index.erase(index.begin());
+ }
+
+ // A Matrix load + broadcast
+ Value row = vector::LoadOp::create(
+ rewriter, loc, VectorType::get(innerVectSize, elementType),
+ innerMatrix, ValueRange(index));
+ valueRow = vector::BroadcastOp::create(
+ rewriter, loc, VectorType::get(vectorSize, rewriter.getF32Type()),
+ row);
+ } else {
+
+ SmallVector<Value> index = {c0, c0, indexOp_j};
+ if (matmulType == MatMulType::Batch) {
+ index.erase(index.begin());
+ }
+
+ // B Matrix load
+ valueRow = vector::LoadOp::create(
+ rewriter, loc, VectorType::get(innerVectSize, elementType),
+ innerMatrix, index);
+ }
+
+ // FMAs
+ for (unsigned int i = 0; i < fmaBound; i = i + 1) {
+ auto fmaOdd =
+ vector::FMAOp::create(rewriter, loc, matLoad[i], valueRow, acc[k]);
+ k++;
+ accumulators.push_back(fmaOdd);
+ }
+ }
+
+ return accumulators;
+}
+
+// Re-creates the K-dimension loop with the accumulator as iter_args for
+// lowering a batch-reduce vector contraction to target-specific nanokernels.
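+//
+// For illustration (a simplified sketch, names assumed), the re-created loop
+// looks roughly like:
+//   %res:4 = scf.for %k = %lb to %ub step %step
+//       iter_args(%acc0 = %a0, %acc1 = %a1, %acc2 = %a2, %acc3 = %a3)
+//       -> (vector<16xf32>, vector<16xf32>, vector<16xf32>, vector<16xf32>) {
+//     // LHS/RHS subviews cloned with their reduction and K offsets remapped
+//     // to the new induction variables.
+//     %lhs = memref.subview %A[%red, %m, %k] ...
+//     %rhs = memref.subview %B[%red, %k, %n] ...
+//     // Nanokernel FMAs (see generateNanokernels above).
+//     scf.yield %fma0, %fma1, %fma2, %fma3
+//   }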
+scf::ForOp createGEMMLoopsWithAccAsIterArgs(
+ RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
+ vector::TransferReadOp vectorReadOpLhs,
+ vector::TransferReadOp vectorReadOpRhs, Value ivNewReductionForOp,
+ Type elementType, unsigned int vectorSize, unsigned int vnni,
+ unsigned int M, unsigned int N, ValueRange iterArgsNewReductionForOp,
+ MatMulType matmulType) {
+ auto newKForOp = scf::ForOp::create(
+ rewriter, kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
+ kForOp.getStep(), iterArgsNewReductionForOp,
+ [&](OpBuilder &rewriterNewKForOp, Location locNewKForOp,
+ Value ivNewKForOp, ValueRange iterArgsNewKForOp) {
+ IRMapping mapping;
+ mapping.map(vectorReadOpLhs.getBase().getDefiningOp()->getOperand(1),
+ ivNewReductionForOp);
+ mapping.map(vectorReadOpLhs.getBase().getDefiningOp()->getOperand(3),
+ ivNewKForOp);
+ auto lhsClone = rewriterNewKForOp.clone(
+ *vectorReadOpLhs.getBase().getDefiningOp(), mapping);
+
+ IRMapping rhsMapping;
+ rhsMapping.map(vectorReadOpRhs.getBase().getDefiningOp()->getOperand(1),
+ ivNewReductionForOp);
+ rhsMapping.map(vectorReadOpRhs.getBase().getDefiningOp()->getOperand(2),
+ ivNewKForOp);
+ auto rhsClone = rewriterNewKForOp.clone(
+ *vectorReadOpRhs.getBase().getDefiningOp(), rhsMapping);
+
+ auto evenFMAs = generateNanokernels(
+ rewriter, kForOp.getLoc(), elementType, vectorSize, vnni, M, N,
+ iterArgsNewKForOp, lhsClone->getResult(0), rhsClone->getResult(0),
+ matmulType);
+
+ scf::YieldOp::create(rewriterNewKForOp, locNewKForOp, evenFMAs);
+ });
+
+ return newKForOp;
+}
+
+// Re-creates the K-dimension loop with the accumulator as iter_args for
+// lowering a batch vector contraction to target-specific nanokernels;
+// analogous to the batch-reduce variant above, but only the K offset is
+// remapped.
+scf::ForOp createGEMMLoopsWithAccAsIterArgs(
+ RewriterBase &rewriter, Location loc, scf::ForOp kForOp,
+ vector::TransferReadOp vectorReadOpLhs,
+ vector::TransferReadOp vectorReadOpRhs, Type elementType,
+ unsigned int vectorSize, unsigned int vnni, unsigned int M, unsigned int N,
+ ValueRange iterArgsNewReductionForOp, MatMulType matmulType) {
+
+ auto newKForOp = scf::ForOp::create(
+ rewriter, kForOp.getLoc(), kForOp.getLowerBound(), kForOp.getUpperBound(),
+ kForOp.getStep(), iterArgsNewReductionForOp,
+ [&](OpBuilder &rewriterNewKForOp, Location locNewKForOp,
+ Value ivNewKForOp, ValueRange iterArgsNewKForOp) {
+ IRMapping mapping;
+ mapping.map(vectorReadOpLhs.getBase().getDefiningOp()->getOperand(2),
+ ivNewKForOp);
+ auto lhsClone = rewriterNewKForOp.clone(
+ *vectorReadOpLhs.getBase().getDefiningOp(), mapping);
+
+ IRMapping rhsMapping;
+ rhsMapping.map(vectorReadOpRhs.getBase().getDefiningOp()->getOperand(1),
+ ivNewKForOp);
+ auto rhsClone = rewriterNewKForOp.clone(
+ *vectorReadOpRhs.getBase().getDefiningOp(), rhsMapping);
+
+ auto evenFMAs =
+ generateNanokernels(rewriter, loc, elementType, vectorSize, vnni, M,
+ N, iterArgsNewKForOp, lhsClone->getResult(0),
+ rhsClone->getResult(0), matmulType);
+
+ scf::YieldOp::create(rewriterNewKForOp, locNewKForOp, evenFMAs);
+ });
+
+ return newKForOp;
+}
+
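+// Inserts the per-vector FMA results back into the accumulator vector with
+// vector.insert_strided_slice; the slice offset mirrors the order in which
+// the accumulators were produced. For illustration (values assumed): with
+// M=2, N=32, vectorSize=16 (N tile not larger) the offsets are 0, 16, 32, 48;
+// with M=2, N=48, vectorSize=16 (N tile larger) they are 0, 48, 16, 64, 32,
+// 80.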
+Value mergeAccumulatedVectorAsMatrix(RewriterBase &rewriter, Location loc,
+ VectorType vecType,
+ SmallVector<Value> FMAs, Value accVec,
+ unsigned int vectorSize, unsigned int M,
+ unsigned int N) {
+
+ auto strides = rewriter.getI64ArrayAttr({1});
+ bool isNTileLarge = (N / vectorSize) > M;
+ if (isNTileLarge) {
+ for (unsigned int j = 0, k = 0; j < (N / vectorSize); j++) {
+ for (unsigned int i = 0; i < M; i++) {
+ unsigned int off = (j * vectorSize) + (i * N);
+ auto offsets = rewriter.getI64ArrayAttr({off});
+ accVec = vector::InsertStridedSliceOp::create(
+ rewriter, loc, vecType, FMAs[k], accVec, offsets, strides);
+ k++;
+ }
+ }
+
+ } else {
+ for (unsigned int i = 0, k = 0; i < M * N; i = i + vectorSize) {
+ auto offsets = rewriter.getI64ArrayAttr({i});
+ accVec = vector::InsertStridedSliceOp::create(
+ rewriter, loc, vecType, FMAs[k], accVec, offsets, strides);
+ k++;
+ }
+ }
+ return accVec;
+}
+
+// Rewriter pattern for vector.contract operation.
+// Input: vector.contract with tiled dimensions (batch or batch-matmul)
+// Matching Pattern:
+// scf.for (0 to M) step m_tile {
+// scf.for (0 to N) step n_tile {
+// - Subview of Accumulator matrix - e.g., acc : memref<m_tilexn_tilexf32>
+// - %read = vector.transfer_read memref<m_tilexn_tilexf32> to
+// vector<m_tilexn_tilexf32>
+// %1 = scf.for (0 to reduce) iter_args_reduce=%read step reduce_tile {
+// %2 = scf.for (0 to K) iter_args_k = %iter_args_reduce step k_tile {
+// - Subview of A and B matrix
+// - Vector transfer read of A and B
+// - %acc = vector.contract %read_A %read_B %iter_args_k
+// scf.yield %acc
+// }
+// scf.yield %2
+// }
+// vector.transfer_write %1 into accumulator matrix
+// }
+// }
+//
+//
+// Rewrite IR:
+// scf.for (0 to M) step m_tile {
+// scf.for (0 to N) step n_tile {
+// - Subview of Accumulator matrix - eg., acc : memref<m_tilexn_tilexf32>
+// - %a = (n_tile / vector_size) * m_tile;
+// // load the accumulator matrix as vector
+// - %0 = load acc[0][0-15] into vector<16xf32>
+// - %1 = load acc[0][16-31] into vector<16xf32>
+// - %2 = load acc[1][0-15] into vector<16xf32>
+// .
+// .
+// .
+// - %a = load acc[m_tile-1][*-n_tile-1] into vector<16xf32>
+// %3 = scf.for (0 to reduce) iter_args_reduce=%0 to %a step reduce_tile {
+// %4 = scf.for (0 to K) iter_args_k = %iter_args_reduce step k_tile {
+// - emit nanokernels (as shown in the comments above the
+// generateNanokernels function)
+// scf.yield %acc[0] to %acc[a-1]
+// }
+// scf.yield %4: [0] to [a-1]
+// }
+// %5 = vector.insert %3: [0] to [a-1] into vector<m_tilexn_tilexf32>
+// vector.transfer_write %5 into accumulator matrix
+// }
+// }
+struct VectorContractNanokernelLowering
+ : public OpRewritePattern<vector::ContractionOp> {
+ VectorContractNanokernelLowering(MLIRContext *context,
+ std::optional<unsigned> vecSize)
+ : OpRewritePattern<vector::ContractionOp>(context),
+ userVectorSize(vecSize) {}
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+
+ auto loc = contractOp.getLoc();
+
+ unsigned int vectorSize = 8;
+ if (userVectorSize)
+ vectorSize = *userVectorSize;
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD) {
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind");
+ }
+
+ SmallVector<vector::IteratorType> contractIteratorTypes =
+ contractOp.getIteratorTypesArray();
+
+ unsigned int reductionCount =
+ std::count(contractIteratorTypes.begin(), contractIteratorTypes.end(),
+ vector::IteratorType::reduction);
+
+ MatMulType matmulType = MatMulType::Others;
+
+ if (reductionCount == 1)
+ matmulType = MatMulType::Batch;
+
+ if (reductionCount == 2)
+ matmulType = MatMulType::BatchReduce;
+
+ if ((matmulType != MatMulType::BatchReduce) &&
+ (matmulType != MatMulType::Batch))
+ return rewriter.notifyMatchFailure(
+ contractOp, "Expects batch-reduce or batch matmuls");
+
+ // Get the M, N, K, and batch-reduce loops
+ auto loops = getTiledMatmulLoopNest(contractOp, matmulType);
+ if (failed(loops))
+ return rewriter.notifyMatchFailure(
+ contractOp, "Invalid loop nest in contract pattern");
+
+ auto nestedLoops = *loops;
+ scf::ForOp kForOp = nestedLoops[0];
+ scf::ForOp reductionForOp;
+
+ if (contractOp.getAcc().getDefiningOp<vector::TransferReadOp>()) {
+ return rewriter.notifyMatchFailure(
+ contractOp, "The Accumulator matrix should be hoisted outside the K "
+ "or reduction loop");
+ }
+
+ vector::TransferReadOp vectorReadOpAcc;
+
+ if (matmulType == MatMulType::BatchReduce) {
+ reductionForOp = nestedLoops[1];
+ vectorReadOpAcc = reductionForOp.getInitArgs()[0]
+ .getDefiningOp<vector::TransferReadOp>();
+ }
+
+ if (matmulType == MatMulType::Batch) {
+ vectorReadOpAcc =
+ kForOp.getInitArgs()[0].getDefiningOp<vector::TransferReadOp>();
+ }
+
+ auto vectorReadOpLhs =
+ contractOp.getLhs().getDefiningOp<vector::TransferReadOp>();
+ auto vectorReadOpRhs =
+ contractOp.getRhs().getDefiningOp<vector::TransferReadOp>();
+
+ if (!vectorReadOpAcc || !vectorReadOpLhs || !vectorReadOpRhs)
+ return failure();
+
+ auto subviewOpAcc =
+ vectorReadOpAcc.getOperand(0).getDefiningOp<memref::SubViewOp>();
----------------
adam-smnk wrote:
Why are subviews required?
https://github.com/llvm/llvm-project/pull/163382
More information about the Mlir-commits mailing list