[Mlir-commits] [mlir] [mlir][x86vector] Lower vector.contract to sequence of FMAs (PR #163382)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 27 04:00:13 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Arun Thangamani (arun-thmn)
<details>
<summary>Changes</summary>
This PR lowers `vector.contract` (only `batch-reduce` or `batch` matmul) to a sequence of `FMAs` with respect to the `vector_size`.
---
Patch is 48.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/163382.diff
12 Files Affected:
- (modified) mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt (+2)
- (added) mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt (+4)
- (added) mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h (+31)
- (added) mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td (+37)
- (modified) mlir/include/mlir/Dialect/X86Vector/Transforms.h (+12)
- (modified) mlir/lib/Dialect/X86Vector/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt (+17)
- (added) mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp (+61)
- (modified) mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp (+659)
- (modified) mlir/lib/RegisterAllExtensions.cpp (+2)
- (added) mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir (+215)
``````````diff
diff --git a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
index 0fe01824b8248..bbe8e4eb892dd 100644
--- a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
@@ -3,3 +3,5 @@ add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector)
add_mlir_interface(X86VectorInterfaces)
add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen)
+
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..6f377e10fa8f8
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS X86VectorTransformOps.td)
+mlir_tablegen(X86VectorTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(X86VectorTransformOps.cpp.inc -gen-op-defs)
+add_mlir_dialect_tablegen_target(MLIRX86VectorTransformOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
new file mode 100644
index 0000000000000..e1d8b8762e799
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
@@ -0,0 +1,31 @@
+//===- X86VectorTransformOps.h - X86Vector transform ops --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
+#define MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// X86Vector Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace x86vector {
+void registerTransformDialectExtension(DialectRegistry &registry);
+
+} // namespace x86vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
new file mode 100644
index 0000000000000..9db2b36a2a8aa
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -0,0 +1,37 @@
+//===- X86VectorTransformOps.td - X86Vector transform ops ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef X86VECTOR_TRANSFORM_OPS
+#define X86VECTOR_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/IR/RegionKindInterface.td"
+
+def ApplyVectorContractNanokernelLoweringPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.x86vector.vector_contract_nanokernel_lowering",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector contract operation can be lowered to target
+ specific nanokernels.
+ }];
+
+ let arguments = (ins DefaultValuedAttr<I64Attr, "8">:$vector_size);
+
+ let assemblyFormat = [{
+ (`vector_size` `=` $vector_size^)? attr-dict
+ }];
+
+}
+
+
+#endif // X86VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index d54111ca41e69..6410c12265f12 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -11,6 +11,10 @@
#include "mlir/IR/Value.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+
namespace mlir {
class ImplicitLocOpBuilder;
@@ -79,6 +83,14 @@ struct MaskHelper {
}
};
+//===----------------------------------------------------------------------===//
+// Transforms a scheduled pattern to lower a tiled batch or batch-reduce
+// vector contraction into a sequence of nanokernels.
+// The transformation is tailored to the target machine architecture
+// and guided by the user-specified vector size.
+void populateVectorContractNanokernelLoweringPatterns(
+ RewritePatternSet &patterns, std::optional<unsigned> vectorSize = 8);
+
//===----------------------------------------------------------------------===//
/// Helpers extracted from:
/// - clang/lib/Headers/avxintrin.h
diff --git a/mlir/lib/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/lib/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..f4c9f8a05acbc
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRX86VectorTransformOps
+ X86VectorTransformOps.cpp
+
+ DEPENDS
+ MLIRX86VectorTransformOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRVectorDialect
+ MLIRSideEffectInterfaces
+ MLIRTransformDialect
+ MLIRTransformDialectUtils
+ MLIRX86VectorDialect
+ MLIRX86VectorTransforms
+ )
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
new file mode 100644
index 0000000000000..e003e3ad7cd08
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -0,0 +1,61 @@
+//===- X86VectorTransformOps.cpp - Implementation of Vector transform ops
+//--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+using namespace mlir;
+using namespace mlir::x86vector;
+using namespace mlir::transform;
+
+void mlir::transform::ApplyVectorContractNanokernelLoweringPatternsOp::
+ populatePatterns(RewritePatternSet &patterns) {
+ x86vector::populateVectorContractNanokernelLoweringPatterns(patterns,
+ getVectorSize());
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+class X86VectorTransformDialectExtension
+ : public transform::TransformDialectExtension<
+ X86VectorTransformDialectExtension> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ X86VectorTransformDialectExtension)
+
+ X86VectorTransformDialectExtension() {
+ declareGeneratedDialect<x86vector::X86VectorDialect>();
+ declareGeneratedDialect<LLVM::LLVMDialect>();
+ registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+ >();
+ }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+
+void mlir::x86vector::registerTransformDialectExtension(
+ DialectRegistry &registry) {
+ registry.addExtensions<X86VectorTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index c51266afe9e8f..da377763331f2 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRX86VectorTransforms
AVXTranspose.cpp
LegalizeForLLVMExport.cpp
+ NanoKernels.cpp
LINK_LIBS PUBLIC
MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
new file mode 100644
index 0000000000000..4d0906a2ec057
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
@@ -0,0 +1,659 @@
+//===- NanoKernels.cpp - Lower matmul to Nanokernels ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements matmul rewrites as nanokernels with respect to target
+// machine for FP32 (for selective batch or batch-reduce matmul patterns) and
+// BF16 (TODO) types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+// Enum to represent the type of matmul operation.
+// getTiledMatmulLoopNest looks for four enclosing `scf.for` loops when the
+// kind is `BatchReduce` and for three loops for every other kind.
+// NOTE(review): `Others` presumably marks contractions this lowering does not
+// handle - confirm against the classification code.
+enum class MatMulType { Batch, BatchReduce, Others };
+
+/// Collects the `scf.for` loops that enclose `contractOp`, innermost first.
+///
+/// Walks outwards from the contraction with `getParentOfType<scf::ForOp>()`
+/// and gathers four loops for a batch-reduce matmul, three otherwise.
+/// Returns failure as soon as no further enclosing `scf.for` is found.
+static FailureOr<SmallVector<scf::ForOp>>
+getTiledMatmulLoopNest(vector::ContractionOp contractOp,
+                       MatMulType matmulType) {
+  // Register-tiled loop structure on batch (or batch-reduce) matmul:
+  // M -> N -> (reduce) -> K.
+  const unsigned numLoops = (matmulType == MatMulType::BatchReduce) ? 4 : 3;
+
+  SmallVector<scf::ForOp> loopNest;
+  Operation *cursor = contractOp;
+  while (loopNest.size() < numLoops) {
+    scf::ForOp enclosing = cursor->getParentOfType<scf::ForOp>();
+    if (!enclosing)
+      return failure();
+
+    loopNest.push_back(enclosing);
+    // Continue the search from the loop that was just recorded.
+    cursor = enclosing;
+  }
+  return loopNest;
+}
+
+/// Checks that the operand subviews are offset by the induction variables of
+/// the enclosing tiled loop nest.
+///
+/// `subviews` is read as {LHS, RHS, accumulator}. `loops` is read as
+/// {K, reduction, N, M} for a batch-reduce matmul and as {K, N, M} for a batch
+/// matmul. The expected offsets are:
+///   batch-reduce: LHS[reduction, M, K], RHS[reduction, K, N], ACC[M, N]
+///   batch:        LHS[M, K],            RHS[K, N],            ACC[M, N]
+/// Any other `matmulType` is accepted without inspecting the loops.
+///
+/// Offsets are compared positionally through `getMixedOffsets()`, so a static
+/// offset or a subview with fewer offsets than expected makes the check fail
+/// instead of indexing past the list of dynamic offset operands.
+///
+/// Returns success when every expected offset is the matching induction
+/// variable; returns failure otherwise, including when fewer than three
+/// subviews or too few loops are supplied.
+static LogicalResult checkMatmulLoopAndSubviewOffsetsMatching(
+    SmallVector<scf::ForOp> loops, SmallVector<memref::SubViewOp> subviews,
+    MatMulType matmulType) {
+  if (subviews.size() < 3)
+    return failure();
+
+  const bool isBatchReduce = matmulType == MatMulType::BatchReduce;
+  if (!isBatchReduce && matmulType != MatMulType::Batch)
+    return success();
+
+  // A batch-reduce matmul carries one extra (reduction) loop and one extra
+  // leading dimension on the LHS and RHS subviews.
+  const unsigned lead = isBatchReduce ? 1 : 0;
+  if (loops.size() < 3 + lead)
+    return failure();
+
+  SmallVector<OpFoldResult> lhsOffsets = subviews[0].getMixedOffsets();
+  SmallVector<OpFoldResult> rhsOffsets = subviews[1].getMixedOffsets();
+  SmallVector<OpFoldResult> accOffsets = subviews[2].getMixedOffsets();
+
+  // True when `offsets[pos]` exists and is exactly the SSA value `iv`.
+  auto isOffsetAt = [](ArrayRef<OpFoldResult> offsets, unsigned pos,
+                       Value iv) {
+    return pos < offsets.size() &&
+           llvm::dyn_cast_if_present<Value>(offsets[pos]) == iv;
+  };
+
+  Value ivK = loops[0].getInductionVar();
+  if (!isOffsetAt(lhsOffsets, lead + 1, ivK) ||
+      !isOffsetAt(rhsOffsets, lead, ivK))
+    return failure();
+
+  if (isBatchReduce) {
+    Value ivReduction = loops[1].getInductionVar();
+    if (!isOffsetAt(lhsOffsets, 0, ivReduction) ||
+        !isOffsetAt(rhsOffsets, 0, ivReduction))
+      return failure();
+  }
+
+  Value ivN = loops[1 + lead].getInductionVar();
+  if (!isOffsetAt(accOffsets, 1, ivN) ||
+      !isOffsetAt(rhsOffsets, lead + 1, ivN))
+    return failure();
+
+  Value ivM = loops[2 + lead].getInductionVar();
+  if (!isOffsetAt(lhsOffsets, lead, ivM) || !isOffsetAt(accOffsets, 0, ivM))
+    return failure();
+
+  return success();
+}
+
+/// Emits the vector loads of the accumulator tile and returns the loaded
+/// values.
+///
+/// One `vector.load` of `vectorSize` elements of `elementType` is created from
+/// `subviewOpAcc` for every (row, column-chunk) pair of the M x N tile, where
+/// column chunks start at multiples of `vectorSize`. The returned vectors are
+/// in emission order:
+///  - when (N / vectorSize) > M, the column chunks form the outer loop and the
+///    rows the inner loop;
+///  - otherwise the rows form the outer loop and the column chunks the inner
+///    loop.
+///
+/// `vectorSize` must be non-zero: it is used as a divisor and as a loop step.
+static SmallVector<Value>
+loadAccumulatorBeforeGEMM(Location loc, RewriterBase &rewriter,
+                          Type elementType, unsigned int M, unsigned int N,
+                          unsigned int vectorSize, Value subviewOpAcc) {
+
+  SmallVector<Value> accumulators;
+
+  // Choose the traversal order from the tile shape.
+  const bool isNTileLarge = (N / vectorSize) > M;
+  const unsigned int outerBound = isNTileLarge ? N : M;
+  const unsigned int innerBound = isNTileLarge ? M : N;
+  const unsigned int outerStep = isNTileLarge ? vectorSize : 1;
+  const unsigned int innerStep = isNTileLarge ? 1 : vectorSize;
+
+  // The loaded vector type does not depend on the loop position.
+  VectorType accType = VectorType::get(vectorSize, elementType);
+
+  for (unsigned int outer = 0; outer < outerBound; outer += outerStep) {
+    for (unsigned int inner = 0; inner < innerBound; inner += innerStep) {
+      // The accumulator is always indexed as [row (M), column (N)], whichever
+      // of the two dimensions drives the outer loop.
+      const unsigned int row = isNTileLarge ? inner : outer;
+      const unsigned int col = isNTileLarge ? outer : inner;
+
+      // Each index constant is created exactly once, so no unused
+      // arith.constant is left behind in the N-large case.
+      Value rowIndex = arith::ConstantIndexOp::create(rewriter, loc, row);
+      Value colIndex = arith::ConstantIndexOp::create(rewriter, loc, col);
+
+      accumulators.push_back(
+          vector::LoadOp::create(rewriter, loc, accType, subviewOpAcc,
+                                 ValueRange{rowIndex, colIndex}));
+    }
+  }
+
+  return accumulators;
+}
+
+// This function takes matrices A, B, and C (represented as vectors)
+// and generates equivalent target-specific nanokernels.
+// It returns the final accumulator as output.
+// Based on the M tile, N tile, and vector size, it generates optimized
+// nanokernels under the condition that the reduction and K dimension
+// of the input matrices are equal to 1.
+//
+// Input: Matrix A, Matrix B, Accumulator as M*(N/vector size) vectors, M tile
+// size, N tile size, Vector size.
+//
+// Output:
+// case i: M >= (N/vector size). For example, M=3; N=32; vector size = 16.
+// load_B0 = load B[0-15] into vector<16xf32>
+// load_B1 = load B[16-31] into vector<16xf32>
+// bcst_A0 = load A[0] and broadcast it into vector<16xf32>
+// o/p_Acc[0] = vector.fma bcst_A0, load_B0, i/p_Acc[0]
+// o/p_Acc[1] = vector.fma bcst_A0, load_B1, i/p_Acc[1]
+// bcst_A1 = load A[1] and broadcast it into vector<16xf32>
+// o/p_Acc[2] = vector.fma bcst_A1, load_B0, i/p_Acc[2]
+// o/p_Acc[3] = vector.fma bcst_A1, load_B1, i/p_Acc[3]
+// bcst_A2 = load A[2] and broadcast it into vector<16xf32>
+// o/p_Acc[4] = vector.fma bcst_A2, load_B0, i/p_Acc[4]
+// o/p_Acc[5] = vector.fma bcst_A2, load_B1, i/p_Acc[5]
+//
+// case ii: M < (N/vector size). For example, M=2; N=48; vector size = 16.
+// bcst_A0 = load A[0] and broadcast it into vector<16xf32>
+// bcst_A1 = load A[1] and broadcast it into vector<16xf32>
+// load_B0 = load B[0-15] into vector<16xf32>
+// o/p_Acc[0] = vector.fma bcst_A0, load_B0, i/p_Acc[0]
+// o/p_Acc[1] = vector.fma bcst_A1, load_B0, i/p_Acc[1]
+// load_B1 = load B[16-31] into vector<16xf32>
+// o/p_Acc[2] = vector.fma bcst_A0, load_B1, i/p_Acc[2]
+// o/p_Acc[3] = vector.fma bcst_A1, load_B1, i/p_Acc[3]
+// load_B2 = load B[32-47] into vector<16xf32>
+// o/p_Acc[4] = vector.fma bcst_A0, load_B2, i/p_Acc[4]
+// o/p_Acc[5] = vector.fma bcst_A1, load_B2, i/p_Acc[5]
+//
+// return o/p_Acc;
+SmallVector<Value>
+generateNanokernels(RewriterBase &rewriter, Location loc, Type elementType,
+ unsigned int vectorSize, unsigned int vnni, unsigned int M,
+ unsigned int N, ValueRange acc, Value matA, Value matB,
+ MatMulType matmulType) {
+
+ SmallVector<Value> accumulators;
+ SmallVector<Value> matLoad;
+ Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+
+ // Start with assumption that M tile size is smaller and create the
+ // helper variables
+ unsigned int outerBound = M;
+ unsigned int outerStep = 1;
+
+ unsigned int innerBound = N;
+ unsigned int innerStep = vectorSize;
+
+ Value outerMatrix = matA;
+ Value innerMatrix = matB;
+
+ unsigned int outerVectSize = vnni;
+ unsigned int innerVectSize = vectorSize;
+
+ unsigned int fmaBound = M;
+
+ // update helper variables if N tile size is smaller
+ bool isNTileLarge = (N / vectorSize) > M;
+ if (!isNTileLarge) {
+ outerBound = N;
+ innerBound = M;
+
+ outerStep = vectorSize;
+ innerStep = 1;
+
+ outerMatrix = matB;
+ innerMatrix = matA;
+
+ outerVectSize = vectorSize;
+ innerVectSize = vnni;
+
+ fmaBound = N / vectorSize;
+ }
+
+ // Load all the element of A or B matrix
+ for (unsigned int i = 0; i < outerBound; i = i + outerStep) {
+ Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
+ Value valueRow;
+
+ if (isNTileLarge) {
+
+ // With the assumption as batch-reduce matmul initialize reduction, M, and
+ // K dimension.
+ SmallVector<Value> index = {c0, indexOp_i, c0};
+
+ // Remove reduction dimension if it is a batch matmul
+ if (matmulType == MatMulType::Batch) {
+ index.erase(index.begin());
+ }
+
+ // A Matrix load + broadcast
+ Value row = vector::LoadOp::create(
+ rewriter, loc, VectorType::get(outerVectSize, elementType),
+ outerMatrix, index);
+ valueRow = vector::BroadcastOp::create(
+ rewriter, loc, VectorType::get(vectorSize, rewriter.getF32Type()),
+ row);
+ } else {
+
+ // With the assumption as batch-reduce matmul initialize reduction, K, and
+ // N dimension.
+ SmallVector<Value> index = {c0, c0, indexOp_i};
+
+ // Remove reduction dimension if it is a batch matmul
+ if (matmulType == MatMulType::Batch) {
+ index.erase(index.begin());
+ }
+
+ // B Matrix load.
+ valueRow = vector::LoadOp::create(
+ rewriter, loc, VectorType::get(outerVectSize, elementType),
+ outerMatrix, index);
+ }
+
+ matLoad.push_back(valueRow);
+ }
+
+ // Load elements of A/B Matrix one at a time and compute FMA
+ for (unsigned int j = 0, k = 0; j < innerBound; j = j + innerStep) {
+ Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
+ Value valueRow;
+
+ if (!isNTileLarge) {
+ SmallVector<Value> index = {c0, indexOp_j, c0};
+ if (matmulType == MatMulType::Batch) {
+ index.erase(index.begin());
+ }
+
+ // A Matrix load + broadcast
+ Value row = vector::LoadOp::create(
+ rewriter, loc, VectorType::get(innerVectSize, elementType),
+ innerMatrix, ValueRange(index));
+ valueRow = vector::BroadcastOp::create(
+ rewriter, loc, VectorType::get(vectorSize, rewriter.getF32Type()),
+ row);
+ } else {
+
+ SmallVector<Value> index = {c0, c0, indexOp_j};
+ if (matmulType == MatMulType::Batch) {
+ index.erase(index.begin());
+ }
+
+ // B Matrix load
+ valueRow = vector::LoadOp::create(
+ rewriter, loc, VectorType::get(innerVectSize, elementType),
+ innerMatrix, index);
+ }
+
+ // FMAs
+ for (unsigned int i = 0; i < fmaBound; i = i + 1) {
+...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/163382
More information about the Mlir-commits
mailing list