[Mlir-commits] [mlir] [MLIR][AArch64] Simplify LowerContractionToSVEI8MMPattern.cpp:getExtOperand (NFC) (PR #144909)
Momchil Velikov
llvmlistbot at llvm.org
Thu Jun 19 08:14:18 PDT 2025
https://github.com/momchil-velikov created https://github.com/llvm/llvm-project/pull/144909
I only recently learned about `isSignlessInteger`; this patch uses it instead of comparing element types against types obtained via `rewriter.getI<N>Type()`.
It also makes `getExtOperand` more similar to the corresponding function in https://github.com/llvm/llvm-project/pull/144698
From 7c5c8eecfcce07d5552efc2d5b00f0e468ff7e8b Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Thu, 19 Jun 2025 15:03:30 +0000
Subject: [PATCH] [MLIR][AArch64] Simplify
LowerContractionToSVEI8MMPattern.cpp:getExtOperand (NFC)
I only recently learned about `isSignlessInteger`; use it instead
of comparing element types against types obtained via `rewriter.getI<N>Type()`.
---
.../LowerContractionToSVEI8MMPattern.cpp | 20 ++++++++-----------
1 file changed, 8 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
index b1233c5c06eb4..a1209fe8230e2 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -39,7 +39,7 @@ namespace {
//
// Return success only for extensions from `i8` to `i32`.
template <typename Op>
-std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
+std::optional<Value> getExtOperand(Value v) {
static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
"Must be instantiated with either sign- or zero- extension op");
@@ -50,7 +50,7 @@ std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
if (!extOp) {
if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
auto vTy = cast<VectorType>(v.getType());
- if (vTy.getElementType() != i8Ty)
+ if (!vTy.getElementType().isSignlessInteger(8))
return {};
return v;
}
@@ -61,11 +61,11 @@ std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
// operation type, check it's extended from `i8` to `i32`.
auto inOp = extOp.getIn();
auto inTy = dyn_cast<VectorType>(inOp.getType());
- if (!inTy || inTy.getElementType() != i8Ty)
+ if (!inTy || !inTy.getElementType().isSignlessInteger(8))
return {};
auto outTy = dyn_cast<VectorType>(extOp.getType());
- if (!outTy || outTy.getElementType() != i32Ty)
+ if (!outTy || !outTy.getElementType().isSignlessInteger(32))
return {};
return inOp;
@@ -199,27 +199,23 @@ class LowerContractionToSVEI8MMPattern
// operands are supported, but they are lowered to different operations.
// Determine which is the appropriate operation to lower to.
MMLA mmlaOp = MMLA::Signed;
- auto maybeLhs = getExtOperand<arith::ExtSIOp>(
- op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
if (!maybeLhs) {
mmlaOp = MMLA::Unsigned;
- maybeLhs = getExtOperand<arith::ExtUIOp>(
- op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
}
if (!maybeLhs)
return rewriter.notifyMatchFailure(
op, "LHS is not a sign- or zero- extended i8");
- auto maybeRhs = getExtOperand<arith::ExtSIOp>(
- op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
if (maybeRhs) {
if (mmlaOp == MMLA::Unsigned)
mmlaOp = MMLA::Mixed;
} else {
if (mmlaOp == MMLA::Signed)
mmlaOp = MMLA::MixedSwapped;
- maybeRhs = getExtOperand<arith::ExtUIOp>(
- op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
}
if (!maybeRhs)
return rewriter.notifyMatchFailure(
More information about the Mlir-commits
mailing list