[Mlir-commits] [mlir] [MLIR][AArch64] Lower vector.contract to Neon FEAT_BF16 operations (PR #148198)
llvmlistbot at llvm.org
Fri Jul 11 03:37:29 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-neon
Author: Momchil Velikov (momchil-velikov)
This is split into two commits:
* refactor I8MM lowering to make it easier to add ...
* ... BF16 lowering
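
For illustration, here is the kind of contraction the BF16 patterns target - a bf16 matmul accumulating into f32. This is a minimal sketch: the value names are hypothetical and, by analogy with the I8MM path, the shapes are assumed to be multiples of the BFMMLA sub-tile, with the RHS in transposed NxK layout:

```mlir
// Hypothetical input: bf16 x bf16 -> f32 contraction, RHS indexed as NxK.
%res = vector.contract {
    indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                     affine_map<(m, n, k) -> (n, k)>,
                     affine_map<(m, n, k) -> (m, n)>],
    iterator_types = ["parallel", "parallel", "reduction"],
    kind = #vector.kind<add>
  } %lhs, %rhs, %acc
  : vector<4x4xbf16>, vector<4x4xbf16> into vector<4x4xf32>
```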
---
Patch is 66.73 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148198.diff
12 Files Affected:
- (modified) mlir/include/mlir/Conversion/Passes.td (+4)
- (modified) mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td (+13-2)
- (modified) mlir/include/mlir/Dialect/ArmNeon/Transforms.h (+2-2)
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+3-1)
- (modified) mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp (+6-1)
- (modified) mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt (+1-1)
- (added) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp (+499)
- (removed) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp (-364)
- (modified) mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp (+1-1)
- (added) mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir (+225)
- (added) mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir (+176)
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir (+1-1)
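
Per output sub-tile, the new LowerContractToNeonPatterns.cpp collapses the 2-D operand tiles to the 1-D vectors the ArmNeon ops expect and emits a single BFMMLA. A sketch of the emitted op follows - the syntax is assumed by analogy with arm_neon.intr.smmla; the vector-bfmmla.mlir tests in the diff below are the authoritative form:

```mlir
// Assumed: 2x4 bf16 LHS/RHS tiles and a 2x2 f32 accumulator tile,
// each collapsed to a 1-D vector before the matrix-multiply-accumulate.
%sub = arm_neon.intr.bfmmla %acc, %lhs, %rhs
         : vector<8xbf16> to vector<4xf32>
```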
``````````diff
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 5a864865adffc..4f304b39a0528 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1437,6 +1437,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of Arm FEAT_I8MM instructions while lowering "
"the vector dialect.">,
+ Option<"armBF16", "enable-arm-bf16",
+ "bool", /*default=*/"false",
+ "Enables the use of Arm FEAT_BF16 instructions while lowering "
+ "the vector dialect.">,
Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
index bcaca7da967fa..35747126d3db1 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
@@ -17,8 +17,19 @@ def ApplyArmNeonContractionToI8MMPatternsOp
"apply_patterns.arm_neon.vector_contract_to_i8mm",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
- Indicates that vector.contract operations should be lowered to
- finer-grained vector primitives from the ArmNeon dialect.
+ Indicates that vector contract operations should be lowered to
+    ArmNeon dialect operations mapping to instructions from FEAT_I8MM.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
+def ApplyArmNeonContractionToBFMMLAPatternsOp
+ : Op<Transform_Dialect, "apply_patterns.arm_neon.vector_contract_to_bfmmla",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector contract operations should be lowered to
+    ArmNeon dialect operations mapping to instructions from FEAT_BF16.
}];
let assemblyFormat = "attr-dict";
diff --git a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
index 2f0f634a96770..08065a3b25266 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
@@ -13,8 +13,8 @@ namespace mlir {
class RewritePatternSet;
namespace arm_neon {
-void populateLowerContractionToNeonI8MMPatternPatterns(
- RewritePatternSet &patterns);
+void populateLowerContractionToNeonI8MMPatterns(RewritePatternSet &patterns);
+void populateLowerContractionToNeonBFMMLAPatterns(RewritePatternSet &patterns);
} // namespace arm_neon
} // namespace mlir
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 549d0210af7ad..1045824c437ab 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -84,10 +84,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorGatherLoweringPatterns(patterns);
if (armI8MM) {
if (armNeon)
- arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
+ arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
if (armSVE)
populateLowerContractionToSVEI8MMPatternPatterns(patterns);
}
+ if (armBF16 && armNeon)
+ arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
index d07e6a52d8b5f..d069bde6d9979 100644
--- a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
+++ b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
@@ -20,7 +20,12 @@ using namespace mlir;
void transform::ApplyArmNeonContractionToI8MMPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
+ arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
+}
+
+void transform::ApplyArmNeonContractionToBFMMLAPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
index 06bafde451cbb..368dacac7b835 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
add_mlir_dialect_library(MLIRArmNeonTransforms
- LowerContractionToNeonI8MMPattern.cpp
+ LowerContractToNeonPatterns.cpp
DEPENDS
MLIRArmNeonIncGen
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
new file mode 100644
index 0000000000000..06746daa8075b
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
@@ -0,0 +1,499 @@
+//===- LowerContractToNeonPatterns.cpp - Contract to I8MM/BF16 --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to operations
+// that map to instructions from the Neon FEAT_I8MM and FEAT_BF16 extensions.
+//
+// TODO: There may be opportunities to unify this with a similar pattern
+// for SVE. See:
+// https://github.com/llvm/llvm-project/issues/145559
+// LowerContractionToSVEI8MMPattern.cpp
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-neon"
+
+using namespace mlir;
+using namespace mlir::arm_neon;
+
+namespace {
+/// Get the operand of a `vector.contract`. This function is intended to
+/// abstract away from the particular way a value is extended before feeding it
+/// into the `vector.contract` - via zero-extend or an explicit or implicit
+/// sign-extend (for implicit sign-extension see `vector.contract`
+/// documentation).
+///
+/// The template parameter `Op` indicates the extension operation (explicit or
+/// implicit) for which we are checking.
+///
+/// Return success only for extensions from `iN` (N <= 8) to `i32`.
+template <typename Op>
+std::optional<Value> getExtOperand(Value v) {
+
+ static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
+                "Must be instantiated with either sign- or zero-extension op");
+
+ // If the operand is not defined by an explicit extend operation of the
+  // accepted operation type, allow for an implicit sign-extension.
+ auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+ if (!extOp) {
+ if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
+ auto eltTy = cast<VectorType>(v.getType()).getElementType();
+ if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
+ return {};
+ return v;
+ }
+ return {};
+ }
+
+ // If the operand is defined by an explicit extend operation of the accepted
+ // operation type, check it's extended from `iN` (N <= 8) to `i32`.
+ auto inOp = extOp.getIn();
+ auto inTy = dyn_cast<VectorType>(inOp.getType());
+ if (!inTy)
+ return {};
+ auto inEltTy = inTy.getElementType();
+ if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
+ return {};
+
+ auto outTy = dyn_cast<VectorType>(extOp.getType());
+ if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
+ return {};
+
+ return inOp;
+}
+
+/// Helper function to extend a vector with `iN` (N < 8) elements to a vector
+/// of `i8`. Perform sign extension if the parameter `signExt` is true, zero
+/// extension otherwise.
+Value extendSmallIntVector(Location loc, VectorType srcTy, Value val,
+ bool signExt, PatternRewriter &rewriter) {
+ Type targetTy = srcTy.clone(rewriter.getI8Type());
+ return signExt ? rewriter.createOrFold<arith::ExtSIOp>(loc, targetTy, val)
+ : rewriter.createOrFold<arith::ExtUIOp>(loc, targetTy, val);
+}
+
+class VectorContractRewriter {
+protected:
+ // Designate the operation (resp. instruction) used to do sub-tile matrix
+ // multiplications.
+ enum class MMLA {
+ Nop,
+ SignedInt, // smmla
+ UnsignedInt, // ummla
+ MixedInt, // usmmla
+ Bfloat // bfmmla
+ };
+
+ // Lower-level operation to be emitted.
+ MMLA mmlaOp = MMLA::Nop;
+
+ // Indicate if the operands for the ArmNeon dialect operation need to be
+  // swapped. Currently this is needed in order to emulate a "summla"
+ // operation.
+ bool swapOperands = false;
+
+ // The operand tiles. These are not necessarily the operands of
+ // `vector.contract`, for example they could be operands to `arith.extsi`
+ // that is in turn fed into `vector.contract`.
+ Value lhs;
+ Value rhs;
+ Value acc;
+
+ // The dimensions logically corresponding to matrix multiplication of
+ // MxK * KxN -> MxN. The operands and the result do not necessarily have these
+ // shapes, for example RHS could be NxK with a transposing indexing map.
+ int64_t dimM = 0;
+ int64_t dimN = 0;
+ int64_t dimK = 0;
+
+  // Unroll iteration bounds. See documentation for `StaticTileOffsetRange`.
+ SmallVector<int64_t> iterationBounds;
+
+  // Sub-tile shape. The algorithm handles operand shapes that are multiples
+  // of this shape.
+ SmallVector<int64_t> subTileShape;
+
+ // Create the matrix multiply and accumulate operation according to `mmlaOp`.
+ Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
+ Value lhs, Value rhs) {
+
+ if (swapOperands)
+ std::swap(lhs, rhs);
+ switch (mmlaOp) {
+ case MMLA::SignedInt:
+ return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc,
+ lhs, rhs);
+ case MMLA::UnsignedInt:
+ return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc,
+ lhs, rhs);
+ case MMLA::MixedInt:
+ return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
+ lhs, rhs);
+ case MMLA::Bfloat:
+ return rewriter.create<arm_neon::BfmmlaOp>(loc, acc.getType(), acc, lhs,
+ rhs);
+ case MMLA::Nop:
+ llvm_unreachable("Uninitialized operation type");
+ }
+ }
+
+ // Check common preconditions for applying the patterns and initialize
+ // logical dimensions.
+ LogicalResult matchAndInit(vector::ContractionOp op,
+ PatternRewriter &rewriter) {
+ // Check iterator types for matrix multiplication.
+ SmallVector<vector::IteratorType> itTypes = op.getIteratorTypesArray();
+ if (!((itTypes.size() == 3 &&
+ (itTypes[0] == vector::IteratorType::parallel &&
+ itTypes[1] == vector::IteratorType::parallel &&
+ itTypes[2] == vector::IteratorType::reduction)) ||
+ (itTypes.size() == 2 &&
+ (itTypes[0] == vector::IteratorType::parallel &&
+ itTypes[1] == vector::IteratorType::reduction))))
+ return rewriter.notifyMatchFailure(
+ op, "iterator types do not correspond to matrix multiplication");
+
+ // Avoid 0-D vectors and 1-D rhs:
+ VectorType lhsType = op.getLhsType();
+ VectorType rhsType = op.getRhsType();
+ if (!lhsType.hasRank() || !rhsType.hasRank() || lhsType.getRank() > 2 ||
+ rhsType.getRank() != 2)
+ return rewriter.notifyMatchFailure(op, "Invalid operand rank");
+
+ // This codegen does not work for scalable vectors. Return failure so this
+ // pattern is not accidentally chosen over patterns that lower to ArmSVE.
+ if (lhsType.isScalable() || rhsType.isScalable())
+ return rewriter.notifyMatchFailure(op,
+ "Not applicable to scalable vectors");
+
+ // Initialize dimensions and check for a matching K dimension.
+ dimM = lhsType.getDimSize(0);
+ dimN = rhsType.getDimSize(0);
+ dimK = rhsType.getDimSize(1);
+
+ int64_t lhsDimK;
+ if (lhsType.getRank() == 1) {
+ dimM = 1;
+ lhsDimK = lhsType.getDimSize(0);
+ } else {
+ lhsDimK = lhsType.getDimSize(1);
+ }
+
+ if (lhsDimK != dimK)
+ return rewriter.notifyMatchFailure(op, "Dimensions mismatch");
+
+ return success();
+ }
+
+public:
+ void rewrite(vector::ContractionOp op, PatternRewriter &rewriter) {
+ // Create some convenience types.
+ auto inputElementType = cast<ShapedType>(lhs.getType()).getElementType();
+ auto accElementType = cast<ShapedType>(acc.getType()).getElementType();
+ auto inputExpandedType =
+ VectorType::get({2, subTileShape.back()}, inputElementType);
+ auto outputExpandedType = VectorType::get({2, 2}, accElementType);
+
+ // One-dimensional representation of logical sub-tiles as required by the
+ // ArmNeon ops.
+ auto collapsedInputType =
+ VectorType::get(inputExpandedType.getNumElements(), inputElementType);
+ auto collapsedOutputType =
+ VectorType::get(outputExpandedType.getNumElements(), accElementType);
+
+ // Get indexing maps for a more concise/convenient access.
+ auto indexingMaps = op.getIndexingMapsArray();
+ AffineMap &lhsPermutationMap = indexingMaps[0];
+ AffineMap &rhsPermutationMap = indexingMaps[1];
+ AffineMap &accPermutationMap = indexingMaps[2];
+
+ Location loc = op.getLoc();
+
+    // Initial accumulator for the final result. This is the full, un-tiled
+    // result into which the sub-tile results are inserted.
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
+
+ SmallVector<int64_t, 3> loopOrder = {0, 1};
+ if (iterationBounds.size() == 3)
+ loopOrder.push_back(2);
+
+ // Keep track of the previous accumulator when tiling over K.
+ Value kAcc;
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(iterationBounds, subTileShape, loopOrder)) {
+ // Helper to compute the new shape of each operand and extract the slice.
+ auto extractOperand = [&](Value operand, AffineMap permutationMap,
+ ArrayRef<int64_t> operandOffsets) {
+ SmallVector<int64_t> operandShape = applyPermutationMap(
+ permutationMap, ArrayRef<int64_t>(subTileShape));
+ SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
+ return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, operand, operandOffsets, operandShape, operandStrides);
+ };
+
+ // Extract tiled lhs, rhs, and acc
+ SmallVector<int64_t> lhsOffsets =
+ applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
+ Value tiledLhs = extractOperand(lhs, lhsPermutationMap, lhsOffsets);
+ SmallVector<int64_t> rhsOffsets =
+ applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
+ Value tiledRhs = extractOperand(rhs, rhsPermutationMap, rhsOffsets);
+ SmallVector<int64_t> accOffsets =
+ applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
+ Value tiledAcc = extractOperand(acc, accPermutationMap, accOffsets);
+
+      // With vecmat, the tiled LHS and ACC contain only one of the two rows
+      // required along dimM. Expand their shapes to match the ArmNeon op.
+ if (dimM == 1) {
+ auto expandRowVector = [&](Value tiledOperand,
+ VectorType expandedTypeType) {
+ auto emptyOperand = rewriter.create<arith::ConstantOp>(
+ loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
+ SmallVector<int64_t> offsets(
+ cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
+ SmallVector<int64_t> strides(
+ cast<ShapedType>(tiledOperand.getType()).getRank(), 1);
+ return rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, tiledOperand, emptyOperand, offsets, strides);
+ };
+ tiledLhs = expandRowVector(tiledLhs, inputExpandedType);
+ tiledAcc = expandRowVector(tiledAcc, outputExpandedType);
+ }
+
+ // Transpose ACC if doing signed by unsigned multiplication, because we're
+ // using the instruction for unsigned by signed multiplication with
+ // reversed operands.
+ if (swapOperands)
+ tiledAcc = rewriter.create<vector::TransposeOp>(
+ loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
+
+ // Collapse tiled operands to 1D vectors required by the ArmNeon ops
+ auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
+ tiledLhs.getLoc(), collapsedInputType, tiledLhs);
+ auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
+ tiledRhs.getLoc(), collapsedInputType, tiledRhs);
+
+ bool initialKAcc = offsets.back() == 0;
+ Value collapsedRes;
+ if (!initialKAcc) {
+ collapsedRes = kAcc;
+ } else {
+ collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
+ tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
+ }
+
+ // Insert contract op
+ kAcc =
+ createMMLA(rewriter, loc, collapsedRes, collapsedLhs, collapsedRhs);
+
+ // Reshape output back to 2D
+ Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
+ kAcc.getLoc(), tiledAcc.getType(), kAcc);
+
+      // Because of the reversed operands, the result is obtained transposed.
+      // Transpose it back.
+ if (swapOperands)
+ tiledRes = rewriter.create<vector::TransposeOp>(
+ loc, tiledRes, ArrayRef<int64_t>({1, 0}));
+
+ // With vecmat, only one row of tiled ACC can be inserted into the final
+ // result
+ if (dimM == 1)
+ tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
+
+      // Insert the tiled result back into the non-tiled result of the
+      // contract op.
+ SmallVector<int64_t> strides(
+ cast<ShapedType>(tiledRes.getType()).getRank(), 1);
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, tiledRes, result, accOffsets, strides);
+ }
+
+ rewriter.replaceOp(op, result);
+ }
+};
+
+class VectorContractRewriterI8MM : public VectorContractRewriter {
+public:
+ LogicalResult matchAndInit(vector::ContractionOp op,
+ PatternRewriter &rewriter) {
+ if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
+ return failure();
+
+    // For tiling, the unrolling patterns can handle any input shape that is
+    // a multiple of [2, 2, 8].
+ if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 8 != 0)
+ return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");
+
+    // Check that the inputs are sign-/zero-extensions from iN (N <= 8) to
+    // i32 and get the values before the extension. All four signed/unsigned
+    // combinations of input operands are supported, but they are lowered to
+    // different operations. Determine the appropriate operation to lower to.
+ mmlaOp = MMLA::SignedInt;
+ auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
+ if (!maybeLhs) {
+ mmlaOp = MMLA::UnsignedInt;
+ maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
+ }
+ if (!maybeLhs)
+ return rewriter.notifyMatc...
[truncated]
``````````
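
For transform-dialect testing, the new pattern descriptor op can be exercised roughly as follows - a sketch modeled on the existing i8mm transform tests, with the surrounding named-sequence boilerplate assumed:

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      %module: !transform.any_op {transform.readonly}) {
    %func = transform.structured.match ops{["func.func"]} in %module
      : (!transform.any_op) -> !transform.any_op
    // Apply the BFMMLA lowering patterns to each function.
    transform.apply_patterns to %func {
      transform.apply_patterns.arm_neon.vector_contract_to_bfmmla
    } : !transform.any_op
    transform.yield
  }
}
```

In the conversion pass, the same lowering is gated by the new `enable-arm-bf16` option of `convert-vector-to-llvm`, and per the pass change it only fires when the Arm Neon path is also enabled (`armBF16 && armNeon`).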
https://github.com/llvm/llvm-project/pull/148198