[Mlir-commits] [mlir] [MLIR][AArch64] Lower `vector.contract` to SVE FEAT_BF16 operations (PR #147052)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jul 4 06:26:54 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-neon
Author: Momchil Velikov (momchil-velikov)
<details>
<summary>Changes</summary>
This patch adds lowering of Bfloat16 widening matrix multiply and accumulate `vector.contract`, by parametrising and refactoring the pattern for 8-bit integers.
---
Patch is 65.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147052.diff
12 Files Affected:
- (modified) mlir/include/mlir/Conversion/Passes.td (+4)
- (modified) mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td (+11-1)
- (modified) mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h (+2)
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+3)
- (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp (+1-1)
- (modified) mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp (+6-1)
- (modified) mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt (+1-1)
- (added) mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp (+607)
- (removed) mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp (-366)
- (added) mlir/test/Dialect/Vector/CPU/ArmSVE/vector-bfmmla.mlir (+105)
- (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-bfmmla.mlir (+201)
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/vector-contract-i8mm.mlir (+3-3)
``````````diff
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 5a864865adffc..4f304b39a0528 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1437,6 +1437,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of Arm FEAT_I8MM instructions while lowering "
"the vector dialect.">,
+ Option<"armBF16", "enable-arm-bf16",
+ "bool", /*default=*/"false",
+ "Enables the use of Arm FEAT_BF16 instructions while lowering "
+ "the vector dialect.">,
Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
index 53784982be6dc..81b3c736b93f3 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.td
@@ -12,7 +12,7 @@ include "mlir/Dialect/Transform/IR/TransformAttrs.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
-def ApplyArmSVELowerContractionPatternsOp
+def ApplyArmSVELowerContractionToI8MMPatternsOp
: Op<Transform_Dialect, "apply_patterns.arm_sve.vector_contract_to_i8mm",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
@@ -23,4 +23,14 @@ def ApplyArmSVELowerContractionPatternsOp
let assemblyFormat = "attr-dict";
}
+def ApplyArmSVELowerContractionToBFMMLAPatternsOp
+ : Op<Transform_Dialect, "apply_patterns.arm_sve.vector_contract_to_bfmmla",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector contraction-like operations should be lowered to
+ finer-grained vector primitives using the ArmSVE dialect.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
#endif // ARMSVE_VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 232e2be29e574..de160dbf8ed94 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -23,6 +23,8 @@ void populateArmSVELegalizeForLLVMExportPatterns(
void populateLowerContractionToSVEI8MMPatternPatterns(
RewritePatternSet &patterns);
+void populateLowerContractionToSVEBFMMLAPatterns(RewritePatternSet &patterns);
+
/// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
/// intrinsics.
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 67c0eca15638a..4d74aabcaa50d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -89,6 +89,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
if (armSVE)
populateLowerContractionToSVEI8MMPatternPatterns(patterns);
}
+ if (armBF16)
+ populateLowerContractionToSVEBFMMLAPatterns(patterns);
+
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
index 7180884c77e98..a95fc51d562c2 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
@@ -12,7 +12,7 @@
// TODO: There may be opportunities to unify this with a similar pattern
// for SVE. See:
// https://github.com/llvm/llvm-project/issues/145559
-// LowerContractionToSVEI8MMPattern.cpp
+// LowerContractToSVEPatterns.cpp
//
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
index b2ca4fc1eaa8c..8572c34c8b12b 100644
--- a/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
+++ b/mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp
@@ -18,11 +18,16 @@ using namespace mlir;
// Apply...PatternsOp
//===----------------------------------------------------------------------===//
-void transform::ApplyArmSVELowerContractionPatternsOp::populatePatterns(
+void transform::ApplyArmSVELowerContractionToI8MMPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
mlir::populateLowerContractionToSVEI8MMPatternPatterns(patterns);
}
+void transform::ApplyArmSVELowerContractionToBFMMLAPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ mlir::populateLowerContractionToSVEBFMMLAPatterns(patterns);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index 65f98b44b1b69..c29eaca244b4a 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -1,7 +1,7 @@
add_mlir_dialect_library(MLIRArmSVETransforms
LegalizeForLLVMExport.cpp
LegalizeVectorStorage.cpp
- LowerContractionToSVEI8MMPattern.cpp
+ LowerContractToSVEPatterns.cpp
DEPENDS
MLIRArmSVEConversionsIncGen
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
new file mode 100644
index 0000000000000..2987287afe9cd
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
@@ -0,0 +1,607 @@
+//===- LowerContractToSVEPatterns.cpp - Contract to I8MM/BF16 ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to operations
+// that map to instructions from the SVE FEAT_I8MM and FEAT_BF16 extensions.
+//
+// TODO: There may be opportunities to unify this with a similar pattern
+// for Neon. See:
+// https://github.com/llvm/llvm-project/issues/145559
+// LowerContractionToNeonI8MMPattern.cpp
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <numeric>
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+
+namespace {
+// Get the operand of a `vector.contract`. This function is intended to abstract
+// away from the particular way a value is extended before feeding it into the
+// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
+// (for implicit sign-extension see `vector.contract` documentation).
+//
+// The template parameter `Op` indicates the extension operation (explicit or
+// implicit) for which we are checking.
+//
+// Return success only for extensions from `i8` to `i32`.
+template <typename Op>
+std::optional<Value> getExtOperand(Value v) {
+
+ static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
+ "Must be instantiated with either sign- or zero- extension op");
+
+ // If the operand is not defined by an explicit extend operation of the
+ // accepted operation type allow for an implicit sign-extension.
+ auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+ if (!extOp) {
+ if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
+ auto vTy = cast<VectorType>(v.getType());
+ if (!vTy.getElementType().isSignlessInteger(8))
+ return {};
+ return v;
+ }
+ return {};
+ }
+
+ // If the operand is defined by an explicit extend operation of the accepted
+ // operation type, check it's extended from `i8` to `i32`.
+ auto inOp = extOp.getIn();
+ auto inTy = dyn_cast<VectorType>(inOp.getType());
+ if (!inTy || !inTy.getElementType().isSignlessInteger(8))
+ return {};
+
+ auto outTy = dyn_cast<VectorType>(extOp.getType());
+ if (!outTy || !outTy.getElementType().isSignlessInteger(32))
+ return {};
+
+ return inOp;
+}
+
+/// This class encapsulates the algorithm and parametrisation (in terms of types
+/// and dimensions) of lowering a `vector.contract` to "primitive" matrix
+/// multiplication operations of the SVE dialect (here "primitive" would mean
+/// corresponding to a single target instruction).
+///
+/// Supported are lowering to FEAT_I8MM `smmla`, `ummla`, and `usmmla`, and to
+/// FEAT_BF16 `bfmmla`. All the transformations are very similar to each
+/// other; for concreteness, the description below is given for `smmla`.
+///
+/// The lowering triggers for a contraction operation that performs a matrix
+/// multiply of two 8-bit integer matrix tiles with logical dimensions
+/// <Mx8> and <8x[N]> for the left-hand side (LHS) and the right-hand side
+/// (RHS), respectively, added to a 32-bit integer accumulator operand (ACC)
+/// with dimensions <Mx[N]>, yielding a <Mx[N]> 32-bit integer result (OUT).
+///
+/// The operands' shapes are such that the operands can be evenly split into
+/// sub-tiles with dimensions as expected by the targeted FEAT_I8MM
+/// instructions. The intent is that M and N are chosen (by higher level
+/// transforms) in such a way as to maximise register usage. The main use case
+/// we envision as of now is MMT4D, thus the RHS operand is expected
+/// pre-transposed.
+///
+/// The matrix multiplication is performed by unrolling the usual tiled matrix
+/// multiplication algorithm using sub-tiles with dimensions <2x8> for the
+/// LHS, <8x[2]> for the RHS, and <2x[2]> for the result and the input
+/// accumulator.
+///
+/// One way to illustrate the operation is as follows:
+///
+/// RHS<8x[N]>: <8x[2]> <8x[2]> ... <8x[2]>
+/// +-----------------------------
+/// LHS<Mx8>: <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+/// <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+/// ... | ... ... ... ...
+/// <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
+///
+/// The RHS operand is unpacked into N/2 values, each representing a sequence
+/// of VSCALE number of sub-tiles with dimensions <8x2>.
+/// The LHS operand is initially unpacked into M/2 values, each representing a
+/// sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
+/// VSCALE times. Multiplying such a replicated LHS sub-tile by the corresponding
+/// RHS sub-tile correctly computes an entire result sub-tile.
+/// The 2x2 sub-tiles of the ACC and OUT have rows that are not adjacent
+/// (in memory or when imposing a row-major layout on the 2D vector value).
+/// Reading the ACC is implemented as reading two consecutive rows and
+/// interleaving them by pairs to obtain a vector having length twice the length
+/// of an ACC row. This vector now is a sequence of one-dimensional tiles with
+/// the exact layout needed by the `smmla`/`bfmmla`/etc instructions, which
+/// tiles are extracted one by one. For illustration, if we have a 2x4 ACC tile
+/// a0 a1 b0 b1
+/// a2 a3 b2 b3
+/// we read the two rows as separate values and then interleave by pairs
+/// to obtain
+/// a0 a1 a2 a3 b0 b1 b2 b3
+/// from which we extract `a0 a1 a2 a3` and `b0 b1 b2 b3`.
+///
+/// Writing the OUT tile is done by the reverse of the above procedure,
+/// concatenate two "flattened" sub-tiles into
+/// c0 c1 c2 c3 d0 d1 d2 d3
+/// deinterleave by pairs to obtain as separate values
+/// c0 c1 d0 d1
+/// c2 c3 d2 d3
+/// which are then inserted into the final result.
+///
+/// Multiplication of a signed LHS by an unsigned RHS is performed by
+/// swapping the order of the operands and emitting an `usmmla` (since there
+/// isn't an `summla` instruction). Therefore each ACC sub-tile needs
+/// to be transposed before the addition and the sum, an OUT sub-tile,
+/// needs to be transposed before insertion into the final result.
+/// This is done very elegantly by a modification of the above to
+/// interleave/deinterleave not by pairs, but by individual elements, e.g.
+/// after ordinary interleave we obtain
+/// a0 a2 a1 a3 b0 b2 b1 b3
+/// which is exactly the desired layout of having each individual 2x2 tile
+/// transposed.
+///
+/// All of the above readily applies to FEAT_BF16 `bfmmla` with the
+/// difference that the shapes of the LHS and RHS are <Mx4> and <4x[N]>,
+/// respectively; that is, the "K" dimension is fixed to 4, instead of 8 (like
+/// for the integer case).
+class VectorContractRewriter {
+protected:
+ // Designate the operation (resp. instruction) used to do sub-tile matrix
+ // multiplications.
+ enum class MMLA {
+ Nop,
+ SignedInt, // smmla
+ UnsignedInt, // ummla
+ MixedInt, // usmmla
+ Bfloat // bfmmla
+ };
+
+ // Lower-level operation to be emitted.
+ MMLA mmlaOp = MMLA::Nop;
+
+ // The operand tiles. These are not necessarily the operands of
+ // `vector.contract`, for example they could be operands to `arith.extsi`
+ // that is in turn fed into `vector.contract`.
+ Value lhs;
+ Value rhs;
+ Value acc;
+
+ // Conventional names for matrix dimensions.
+ int64_t M = 0;
+ int64_t N = 0;
+ int64_t K = 0;
+
+ // Single-dimensional vector types for the operands of the ArmSVE dialect
+ // op.
+ VectorType flatLhsType;
+ VectorType flatRhsType;
+ VectorType flatAccType;
+
+ // Single-dimension vector type for the entire RHS tile.
+ VectorType flatRhsTileType;
+
+ // Vector type having the same number of elements as a row in the
+ // accumulator/output tile and the same element type.
+ VectorType accRowTy;
+
+ // Vector type having twice the number of elements as a row in the
+ // accumulator/output tile and the same element type.
+ VectorType accRowX2Ty;
+
+ // Vector type having half the number of elements as a row in the
+ // accumulator/output tile and an integer element type with twice the bit
+ // width.
+ VectorType accRow64Ty;
+ VectorType accRowX264Ty;
+
+ // Indicate if the operands for the ArmSVE dialect operation need to be
+ // swapped. Currently this is needed in order to emulate an "summla"
+ // operation.
+ bool swapOperands = false;
+
+ // Create the matrix multiply and accumulate operation according to
+ // `mmlaOp`.
+ Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
+ Value lhs, Value rhs);
+
+ // Check general preconditions for applying the transformation, common to the
+ // integer and the bfloat16 case.
+ LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter);
+
+public:
+ VectorContractRewriter() = default;
+
+ // Do the actual rewrite. This member function is shared by both integer and
+ // bfloat16 rewrites.
+ Value rewrite(vector::ContractionOp op, PatternRewriter &rewriter);
+};
+
+Value VectorContractRewriter::createMMLA(PatternRewriter &rewriter,
+ Location loc, Value acc, Value lhs,
+ Value rhs) {
+ if (swapOperands)
+ std::swap(lhs, rhs);
+
+ switch (mmlaOp) {
+ case MMLA::SignedInt:
+ return rewriter.create<arm_sve::SmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+ case MMLA::UnsignedInt:
+ return rewriter.create<arm_sve::UmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+ case MMLA::MixedInt:
+ return rewriter.create<arm_sve::UsmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+ case MMLA::Bfloat:
+ return rewriter.create<arm_sve::BfmmlaOp>(loc, flatAccType, acc, lhs, rhs);
+ default:
+ llvm_unreachable("Uninitialized operation kind");
+ }
+}
+
+LogicalResult VectorContractRewriter::match(vector::ContractionOp op,
+ PatternRewriter &rewriter) {
+ // Check iterator types for matrix multiplication.
+ auto itTypes = op.getIteratorTypesArray();
+ if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+ itTypes[1] != vector::IteratorType::parallel ||
+ itTypes[2] != vector::IteratorType::reduction)
+ return rewriter.notifyMatchFailure(
+ op, "iterator types do not correspond to matrix multiplication");
+
+ // Check permutation maps. For now only accept
+ // lhs: (d0, d1, d2) -> (d0, d2)
+ // rhs: (d0, d1, d2) -> (d1, d2)
+ // acc: (d0, d1, d2) -> (d0, d1)
+ // This corresponds to matrix multiplication with transposed RHS.
+ if (op.getIndexingMapsArray()[0] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+ op.getContext()) ||
+ op.getIndexingMapsArray()[1] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+ op.getContext()) ||
+ op.getIndexingMapsArray()[2] != AffineMap::getMultiDimMapWithTargets(
+ 3, ArrayRef{0u, 1u}, op.getContext()))
+ return rewriter.notifyMatchFailure(op, "non-matching permutation maps");
+
+ // Check the combining kind is addition.
+ if (op.getKind() != vector::CombiningKind::ADD)
+ return rewriter.notifyMatchFailure(op, "combining kind is not an addition");
+
+ return success();
+}
+
+Value VectorContractRewriter::rewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) {
+ Location loc = op.getLoc();
+
+ // Extract LHS sub-tiles with logical shape <2xK>.
+ SmallVector<Value> lhsTile;
+ for (int64_t i = 0; i < M; i += 2) {
+ // Extract two consecutive rows of the LHS tile.
+ auto r0 =
+ rewriter.create<vector::ExtractOp>(loc, lhs, ArrayRef<int64_t>{i});
+ auto r1 =
+ rewriter.create<vector::ExtractOp>(loc, lhs, ArrayRef<int64_t>{i + 1});
+ // Concatenate to obtain a 2 x K x <input-type> flattened sub-tile.
+ SmallVector<int64_t> shuffleIdx(2 * K);
+ std::iota(shuffleIdx.begin(), shuffleIdx.end(), 0);
+ auto t = rewriter.create<vector::ShuffleOp>(loc, r0, r1, shuffleIdx);
+ // Turn it into a scalable vector.
+ auto s = rewriter.create<vector::ScalableInsertOp>(
+ loc, t, rewriter.create<ub::PoisonOp>(loc, flatLhsType), 0);
+ // Replicate the sub-tile VSCALE times to fill the entire vector.
+ auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0);
+ lhsTile.push_back(r);
+ }
+
+ // "Flatten" the RHS tile from <[N]xK> to <[N*K]>.
+ auto rhs = rewriter.create<vector::ShapeCastOp>(this->rhs.getLoc(),
+ flatRhsTileType, this->rhs);
+
+ // Extract the RHS sub-tiles with logical shape <Kx[2]>.
+ SmallVector<Value> rhsTile;
+ for (int64_t j = 0; j < N; j += 2)
+ rhsTile.push_back(rewriter.create<vector::ScalableExtractOp>(
+ loc, flatRhsType, rhs, j * K));
+
+ // Extract and pack the ACC sub-tiles.
+ SmallVector<Value> accTile;
+ for (int64_t i = 0; i < M; i += 2) {
+ // Extract two consecutive rows of the accumulator tile.
+ auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+ ArrayRef<int64_t>{i});
+ auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(),
+ ArrayRef<int64_t>{i + 1});
+ Value accTileVec;
+ if (swapOperands) {
+ // We are performing the operation with swapped LHS and RHS we need to
+ // transpose e...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/147052
More information about the Mlir-commits
mailing list