[Mlir-commits] [mlir] [mlir][x86vector] Lower vector.contract to sequence of FMAs (PR #163382)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 27 04:00:13 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Arun Thangamani (arun-thmn)

<details>
<summary>Changes</summary>

This PR lowers `vector.contract` ops that represent a `batch-reduce` or `batch` matmul (only these forms are supported) to a sequence of FMAs whose vector width is given by `vector_size`.

---

Patch is 48.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/163382.diff


12 Files Affected:

- (modified) mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt (+2) 
- (added) mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt (+4) 
- (added) mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h (+31) 
- (added) mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td (+37) 
- (modified) mlir/include/mlir/Dialect/X86Vector/Transforms.h (+12) 
- (modified) mlir/lib/Dialect/X86Vector/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt (+17) 
- (added) mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp (+61) 
- (modified) mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp (+659) 
- (modified) mlir/lib/RegisterAllExtensions.cpp (+2) 
- (added) mlir/test/Dialect/X86Vector/vector-contract-to-nanokernels.mlir (+215) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
index 0fe01824b8248..bbe8e4eb892dd 100644
--- a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
@@ -3,3 +3,5 @@ add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector)
 
 add_mlir_interface(X86VectorInterfaces)
 add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen)
+
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..6f377e10fa8f8
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS X86VectorTransformOps.td)
+mlir_tablegen(X86VectorTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(X86VectorTransformOps.cpp.inc -gen-op-defs)
+add_mlir_dialect_tablegen_target(MLIRX86VectorTransformOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
new file mode 100644
index 0000000000000..e1d8b8762e799
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h
@@ -0,0 +1,31 @@
+//===- X86VectorTransformOps.h - X86Vector transform ops --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
+#define MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+//===----------------------------------------------------------------------===//
+// X86Vector Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace x86vector {
+/// Adds the x86vector transform dialect extension, which registers the
+/// x86vector transform ops with the transform dialect, to `registry`.
+void registerTransformDialectExtension(DialectRegistry &registry);
+
+} // namespace x86vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
new file mode 100644
index 0000000000000..9db2b36a2a8aa
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -0,0 +1,37 @@
+//===- X86VectorTransformOps.td - X86Vector transform ops ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef X86VECTOR_TRANSFORM_OPS
+#define X86VECTOR_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/IR/RegionKindInterface.td"
+
+def ApplyVectorContractNanokernelLoweringPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.x86vector.vector_contract_nanokernel_lowering",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects patterns that lower a tiled batch or batch-reduce matmul,
+    expressed as `vector.contract`, into a target-specific nanokernel: a
+    sequence of FMAs on vectors of `vector_size` elements.
+
+    The optional `vector_size` attribute (default: 8) selects the vector
+    length used for the generated loads and FMAs.
+  }];
+
+  let arguments = (ins DefaultValuedAttr<I64Attr, "8">:$vector_size);
+
+  let assemblyFormat = [{
+    (`vector_size` `=` $vector_size^)? attr-dict
+  }];
+}
+
+
+#endif // X86VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
index d54111ca41e69..6410c12265f12 100644
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -11,6 +11,10 @@
 
 #include "mlir/IR/Value.h"
 
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+
 namespace mlir {
 
 class ImplicitLocOpBuilder;
@@ -79,6 +83,14 @@ struct MaskHelper {
   }
 };
 
+//===----------------------------------------------------------------------===//
+/// Populates `patterns` with rewrites that lower a tiled batch or
+/// batch-reduce matmul, expressed as `vector.contract`, into a sequence of
+/// target-specific FMA nanokernels operating on vectors of `vectorSize`
+/// elements (default 8).
+void populateVectorContractNanokernelLoweringPatterns(
+    RewritePatternSet &patterns, std::optional<unsigned> vectorSize = 8);
+
 //===----------------------------------------------------------------------===//
 /// Helpers extracted from:
 ///   - clang/lib/Headers/avxintrin.h
diff --git a/mlir/lib/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/lib/Dialect/X86Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..f4c9f8a05acbc
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRX86VectorTransformOps
+  X86VectorTransformOps.cpp
+
+  DEPENDS
+  MLIRX86VectorTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRVectorDialect
+  MLIRSideEffectInterfaces
+  MLIRTransformDialect
+  MLIRTransformDialectUtils
+  MLIRX86VectorDialect
+  MLIRX86VectorTransforms
+  )
diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
new file mode 100644
index 0000000000000..e003e3ad7cd08
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp
@@ -0,0 +1,61 @@
+//===- X86VectorTransformOps.cpp - X86Vector transform ops ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+using namespace mlir;
+using namespace mlir::x86vector;
+using namespace mlir::transform;
+
+/// Hands the op's `vector_size` attribute to the x86vector nanokernel
+/// lowering entry point, which adds the corresponding patterns to `patterns`.
+void mlir::transform::ApplyVectorContractNanokernelLoweringPatternsOp::
+    populatePatterns(RewritePatternSet &patterns) {
+  const unsigned vectorSize = getVectorSize();
+  x86vector::populateVectorContractNanokernelLoweringPatterns(patterns,
+                                                              vectorSize);
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Transform dialect extension that registers the x86vector transform ops.
+/// Dialects passed to declareGeneratedDialect are loaded along with the
+/// transform dialect, so ops of those dialects can be created while the
+/// transform ops are applied.
+class X86VectorTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          X86VectorTransformDialectExtension> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      X86VectorTransformDialectExtension)
+
+  X86VectorTransformDialectExtension() {
+    // The nanokernel lowering patterns create arith ops (index constants) and
+    // vector ops (e.g. vector.load, vector.broadcast). Declare both dialects
+    // so they are loaded even when the payload IR does not already use them.
+    declareGeneratedDialect<arith::ArithDialect>();
+    declareGeneratedDialect<vector::VectorDialect>();
+    declareGeneratedDialect<x86vector::X86VectorDialect>();
+    declareGeneratedDialect<LLVM::LLVMDialect>();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
+
+/// Adds X86VectorTransformDialectExtension, which registers the x86vector
+/// transform ops, to `registry`.
+void mlir::x86vector::registerTransformDialectExtension(
+    DialectRegistry &registry) {
+  registry.addExtensions<X86VectorTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index c51266afe9e8f..da377763331f2 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRX86VectorTransforms
   AVXTranspose.cpp
   LegalizeForLLVMExport.cpp
+  NanoKernels.cpp
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
new file mode 100644
index 0000000000000..4d0906a2ec057
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/NanoKernels.cpp
@@ -0,0 +1,659 @@
+//===- NanoKernels.cpp - Lower matmul to nanokernels ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements matmul rewrites as nanokernels with respect to target
+// machine for FP32 (for selective batch or batch-reduce matmul patterns) and
+// BF16 (TODO) types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+
+// Kind of matmul-like contraction being lowered; selects the expected loop
+// nest depth and subview offset layout.
+enum class MatMulType {
+  Batch,       // batch matmul
+  BatchReduce, // batch-reduce matmul
+  Others       // any other contraction
+};
+
+// Walks outward from `contractOp` and collects its enclosing scf.for loops,
+// innermost first: four loops for a batch-reduce matmul, three otherwise.
+// checkMatmulLoopAndSubviewOffsetsMatching interprets them as
+// [K, reduction, N, M] (batch-reduce) or [K, N, M] (batch).
+// Fails if fewer enclosing scf.for loops exist.
+static FailureOr<SmallVector<scf::ForOp>>
+getTiledMatmulLoopNest(vector::ContractionOp contractOp,
+                       MatMulType matmulType) {
+  unsigned int dimCount = matmulType == MatMulType::BatchReduce ? 4 : 3;
+  SmallVector<scf::ForOp> list;
+  list.reserve(dimCount);
+  Operation *current = contractOp;
+
+  // It is register tiled loop structure on batch (or reduce) matmul
+  // (M->N->(reduce)->K).
+  for (unsigned int i = 0; i < dimCount; i++) {
+    // getParentOfType already yields a typed scf::ForOp (null when there is
+    // no such ancestor), so no extra cast is needed.
+    scf::ForOp parent = current->getParentOfType<scf::ForOp>();
+    if (!parent)
+      return failure();
+    list.push_back(parent);
+    current = parent;
+  }
+  return list;
+}
+
+// Checks that the tiled loop nest and the LHS/RHS/accumulator subviews agree:
+// each loop induction variable must appear as the subview offset at the
+// position the matmul type expects. `loops` is ordered innermost first (as
+// produced by getTiledMatmulLoopNest); `subviews` holds the LHS, RHS and
+// accumulator subviews, in that order.
+//
+//  BatchReduce: loops = [K, reduction, N, M];
+//               LHS offsets [reduction, M, K], RHS [reduction, K, N],
+//               ACC [M, N].
+//  Batch:       loops = [K, N, M];
+//               LHS offsets [M, K], RHS [K, N], ACC [M, N].
+//
+// Returns failure on any mismatch. It also returns failure, instead of
+// indexing out of range, when there are too few subviews, loops or offsets
+// to check. Other matmul types are accepted without checks.
+//
+// NOTE(review): SubViewOp::getOffsets() returns only the dynamic offset
+// operands, so the positional indexing assumes every checked offset is
+// dynamic.
+static LogicalResult checkMatmulLoopAndSubviewOffsetsMatching(
+    SmallVector<scf::ForOp> loops, SmallVector<memref::SubViewOp> subviews,
+    MatMulType matmulType) {
+  if (subviews.size() < 3)
+    return failure();
+
+  auto subviewOpLhsOffsets = subviews[0].getOffsets();
+  auto subviewOpRhsOffsets = subviews[1].getOffsets();
+  auto subviewOpAccOffsets = subviews[2].getOffsets();
+
+  if (matmulType == MatMulType::BatchReduce) {
+    // Guard the positional accesses below: loops[0..3], LHS/RHS offsets
+    // [0..2] and accumulator offsets [0..1].
+    if (loops.size() < 4 || subviewOpLhsOffsets.size() < 3 ||
+        subviewOpRhsOffsets.size() < 3 || subviewOpAccOffsets.size() < 2)
+      return failure();
+
+    Value ivK = loops[0].getInductionVar();
+    if (ivK != subviewOpLhsOffsets[2] || ivK != subviewOpRhsOffsets[1])
+      return failure();
+
+    Value ivReduction = loops[1].getInductionVar();
+    if (ivReduction != subviewOpLhsOffsets[0] ||
+        ivReduction != subviewOpRhsOffsets[0])
+      return failure();
+
+    Value ivN = loops[2].getInductionVar();
+    if (ivN != subviewOpAccOffsets[1] || ivN != subviewOpRhsOffsets[2])
+      return failure();
+
+    Value ivM = loops[3].getInductionVar();
+    if (ivM != subviewOpLhsOffsets[1] || ivM != subviewOpAccOffsets[0])
+      return failure();
+  }
+
+  if (matmulType == MatMulType::Batch) {
+    // Guard the positional accesses below: loops[0..2] and offsets [0..1] of
+    // every subview.
+    if (loops.size() < 3 || subviewOpLhsOffsets.size() < 2 ||
+        subviewOpRhsOffsets.size() < 2 || subviewOpAccOffsets.size() < 2)
+      return failure();
+
+    Value ivK = loops[0].getInductionVar();
+    if (ivK != subviewOpLhsOffsets[1] || ivK != subviewOpRhsOffsets[0])
+      return failure();
+
+    Value ivN = loops[1].getInductionVar();
+    if (ivN != subviewOpAccOffsets[1] || ivN != subviewOpRhsOffsets[1])
+      return failure();
+
+    Value ivM = loops[2].getInductionVar();
+    if (ivM != subviewOpLhsOffsets[0] || ivM != subviewOpAccOffsets[0])
+      return failure();
+  }
+
+  return success();
+}
+
+// Emits vector loads of the accumulator tile `subviewOpAcc` (indexed as
+// [row, col], with M rows and N columns) in chunks of `vectorSize` elements
+// of `elementType`. Returns the loaded vectors in emission order:
+//  - if (N / vectorSize) > M: for each column chunk, all M rows are loaded;
+//  - otherwise: for each row, all column chunks are loaded.
+//
+// NOTE(review): the column loop steps by vectorSize up to N. If N is not a
+// multiple of vectorSize, the trailing partial chunk is loaded as a full
+// vector; presumably callers guarantee divisibility -- confirm.
+static SmallVector<Value>
+loadAccumulatorBeforeGEMM(Location loc, RewriterBase &rewriter,
+                          Type elementType, unsigned int M, unsigned int N,
+                          unsigned int vectorSize, Value subviewOpAcc) {
+  // vectorSize is both a divisor and the column loop step: zero would divide
+  // by zero and never advance the loop.
+  assert(vectorSize != 0 && "vectorSize must be non-zero");
+
+  bool isNTileLarge = (N / vectorSize) > M;
+  VectorType accVectorType = VectorType::get(vectorSize, elementType);
+  SmallVector<Value> accumulators;
+
+  // Loads one accumulator vector at constant indices [row, col]. Only the
+  // two index constants that are actually used get created.
+  auto loadAt = [&](unsigned int row, unsigned int col) {
+    Value rowIdx = arith::ConstantIndexOp::create(rewriter, loc, row);
+    Value colIdx = arith::ConstantIndexOp::create(rewriter, loc, col);
+    Value acc = vector::LoadOp::create(rewriter, loc, accVectorType,
+                                       subviewOpAcc,
+                                       ValueRange{rowIdx, colIdx});
+    accumulators.push_back(acc);
+  };
+
+  if (isNTileLarge) {
+    for (unsigned int col = 0; col < N; col += vectorSize)
+      for (unsigned int row = 0; row < M; ++row)
+        loadAt(row, col);
+  } else {
+    for (unsigned int row = 0; row < M; ++row)
+      for (unsigned int col = 0; col < N; col += vectorSize)
+        loadAt(row, col);
+  }
+
+  return accumulators;
+}
+
+// This function takes matrices A, B, and C (represented as vectors)
+// and generates equivalent target-specific nanokernels.
+// It returns the final accumulator as output.
+// Based on the M tile, N tile, and vector size, it generates optimized
+// nanokernels under the condition that the reduction and K dimension
+// of the input matrices are equal to 1.
+//
+// Input: Matrix A, Matrix B, Accumulator as M*(N/vector size) vectors, M tile
+// size, N tile size, Vector size.
+//
+// Output:
+// case i: M >= (N/vector size). For example, M=3; N=32; vector size = 16.
+//  load_B0 = load B[0-15] into vector<16xf32>
+//  load_B1 = load B[16-31] into vector<16xf32>
+//  bcst_A0 = load A[0] and broadcast it into vector<16xf32>
+//  o/p_Acc[0] = vector.fma bcst_A0, load_B0, i/p_Acc[0]
+//  o/p_Acc[1] = vector.fma bcst_A0, load_B1, i/p_Acc[1]
+//  bcst_A1 = load A[1] and broadcast it into vector<16xf32>
+//  o/p_Acc[2] = vector.fma bcst_A1, load_B0, i/p_Acc[2]
+//  o/p_Acc[3] = vector.fma bcst_A1, load_B1, i/p_Acc[3]
+//  bcst_A2 = load A[2] and broadcast it into vector<16xf32>
+//  o/p_Acc[4] = vector.fma bcst_A2, load_B0, i/p_Acc[4]
+//  o/p_Acc[5] = vector.fma bcst_A2, load_B1, i/p_Acc[5]
+//
+// case ii: M < (N/vector size). For example, M=2; N=48; vector size = 16.
+//  bcst_A0 = load A[0] and broadcast it into vector<16xf32>
+//  bcst_A1 = load A[1] and broadcast it into vector<16xf32>
+//  load_B0 = load B[0-15] into vector<16xf32>
+//  o/p_Acc[0] = vector.fma bcst_A0, load_B0, i/p_Acc[0]
+//  o/p_Acc[1] = vector.fma bcst_A1, load_B0, i/p_Acc[1]
+//  load_B1 = load B[16-31] into vector<16xf32>
+//  o/p_Acc[2] = vector.fma bcst_A0, load_B1, i/p_Acc[2]
+//  o/p_Acc[3] = vector.fma bcst_A1, load_B1, i/p_Acc[3]
+//  load_B2 = load B[32-47] into vector<16xf32>
+//  o/p_Acc[4] = vector.fma bcst_A0, load_B2, i/p_Acc[4]
+//  o/p_Acc[5] = vector.fma bcst_A1, load_B2, i/p_Acc[5]
+//
+// return o/p_Acc;
+SmallVector<Value>
+generateNanokernels(RewriterBase &rewriter, Location loc, Type elementType,
+                    unsigned int vectorSize, unsigned int vnni, unsigned int M,
+                    unsigned int N, ValueRange acc, Value matA, Value matB,
+                    MatMulType matmulType) {
+
+  SmallVector<Value> accumulators;
+  SmallVector<Value> matLoad;
+  Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+
+  // Start with assumption that M tile size is smaller and create  the
+  // helper variables
+  unsigned int outerBound = M;
+  unsigned int outerStep = 1;
+
+  unsigned int innerBound = N;
+  unsigned int innerStep = vectorSize;
+
+  Value outerMatrix = matA;
+  Value innerMatrix = matB;
+
+  unsigned int outerVectSize = vnni;
+  unsigned int innerVectSize = vectorSize;
+
+  unsigned int fmaBound = M;
+
+  // update helper variables if N tile size is smaller
+  bool isNTileLarge = (N / vectorSize) > M;
+  if (!isNTileLarge) {
+    outerBound = N;
+    innerBound = M;
+
+    outerStep = vectorSize;
+    innerStep = 1;
+
+    outerMatrix = matB;
+    innerMatrix = matA;
+
+    outerVectSize = vectorSize;
+    innerVectSize = vnni;
+
+    fmaBound = N / vectorSize;
+  }
+
+  // Load all the element of A or B matrix
+  for (unsigned int i = 0; i < outerBound; i = i + outerStep) {
+    Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
+    Value valueRow;
+
+    if (isNTileLarge) {
+
+      // With the assumption as batch-reduce matmul initialize reduction, M, and
+      // K dimension.
+      SmallVector<Value> index = {c0, indexOp_i, c0};
+
+      // Remove reduction dimension if it is a batch matmul
+      if (matmulType == MatMulType::Batch) {
+        index.erase(index.begin());
+      }
+
+      // A Matrix load + broadcast
+      Value row = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(outerVectSize, elementType),
+          outerMatrix, index);
+      valueRow = vector::BroadcastOp::create(
+          rewriter, loc, VectorType::get(vectorSize, rewriter.getF32Type()),
+          row);
+    } else {
+
+      // With the assumption as batch-reduce matmul initialize reduction, K, and
+      // N dimension.
+      SmallVector<Value> index = {c0, c0, indexOp_i};
+
+      // Remove reduction dimension if it is a batch matmul
+      if (matmulType == MatMulType::Batch) {
+        index.erase(index.begin());
+      }
+
+      // B Matrix load.
+      valueRow = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(outerVectSize, elementType),
+          outerMatrix, index);
+    }
+
+    matLoad.push_back(valueRow);
+  }
+
+  // Load elements of A/B Matrix one at a time and compute FMA
+  for (unsigned int j = 0, k = 0; j < innerBound; j = j + innerStep) {
+    Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
+    Value valueRow;
+
+    if (!isNTileLarge) {
+      SmallVector<Value> index = {c0, indexOp_j, c0};
+      if (matmulType == MatMulType::Batch) {
+        index.erase(index.begin());
+      }
+
+      // A Matrix load + broadcast
+      Value row = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(innerVectSize, elementType),
+          innerMatrix, ValueRange(index));
+      valueRow = vector::BroadcastOp::create(
+          rewriter, loc, VectorType::get(vectorSize, rewriter.getF32Type()),
+          row);
+    } else {
+
+      SmallVector<Value> index = {c0, c0, indexOp_j};
+      if (matmulType == MatMulType::Batch) {
+        index.erase(index.begin());
+      }
+
+      // B Matrix load
+      valueRow = vector::LoadOp::create(
+          rewriter, loc, VectorType::get(innerVectSize, elementType),
+          innerMatrix, index);
+    }
+
+    // FMAs
+    for (unsigned int i = 0; i < fmaBound; i = i + 1) {
+...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/163382


More information about the Mlir-commits mailing list