[Mlir-commits] [mlir] b31413a - [MLIR][AArch64] Simplify LowerContractionToSVEI8MMPattern.cpp:getExtOperand (NFC) (#144909)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 23 01:38:05 PDT 2025
Author: Momchil Velikov
Date: 2025-06-23T09:38:01+01:00
New Revision: b31413a96603cd904281368b6f5f8e36836a7cac
URL: https://github.com/llvm/llvm-project/commit/b31413a96603cd904281368b6f5f8e36836a7cac
DIFF: https://github.com/llvm/llvm-project/commit/b31413a96603cd904281368b6f5f8e36836a7cac.diff
LOG: [MLIR][AArch64] Simplify LowerContractionToSVEI8MMPattern.cpp:getExtOperand (NFC) (#144909)
I only recently learned about `isSignlessInteger`; use it instead of
comparing against types obtained via `rewriter.getI<N>Type()`.
This also brings the function closer to a similar one in
`LowerContractionToNeonI8MMPattern.cpp` (formerly `LowerContractionToSMMLAPattern.cpp`),
which would help any future effort to unify these patterns.
Added:
Modified:
mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
index b1233c5c06eb4..a1209fe8230e2 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -39,7 +39,7 @@ namespace {
//
// Return success only for extensions from `i8` to `i32`.
template <typename Op>
-std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
+std::optional<Value> getExtOperand(Value v) {
static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
"Must be instantiated with either sign- or zero- extension op");
@@ -50,7 +50,7 @@ std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
if (!extOp) {
if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
auto vTy = cast<VectorType>(v.getType());
- if (vTy.getElementType() != i8Ty)
+ if (!vTy.getElementType().isSignlessInteger(8))
return {};
return v;
}
@@ -61,11 +61,11 @@ std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
// operation type, check it's extended from `i8` to `i32`.
auto inOp = extOp.getIn();
auto inTy = dyn_cast<VectorType>(inOp.getType());
- if (!inTy || inTy.getElementType() != i8Ty)
+ if (!inTy || !inTy.getElementType().isSignlessInteger(8))
return {};
auto outTy = dyn_cast<VectorType>(extOp.getType());
- if (!outTy || outTy.getElementType() != i32Ty)
+ if (!outTy || !outTy.getElementType().isSignlessInteger(32))
return {};
return inOp;
@@ -199,27 +199,23 @@ class LowerContractionToSVEI8MMPattern
// operands are supported, but they are lowered to
// different operations.
// Determine which is the appropriate operation to lower to.
MMLA mmlaOp = MMLA::Signed;
- auto maybeLhs = getExtOperand<arith::ExtSIOp>(
- op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
if (!maybeLhs) {
mmlaOp = MMLA::Unsigned;
- maybeLhs = getExtOperand<arith::ExtUIOp>(
- op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
}
if (!maybeLhs)
return rewriter.notifyMatchFailure(
op, "LHS is not a sign- or zero- extended i8");
- auto maybeRhs = getExtOperand<arith::ExtSIOp>(
- op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
if (maybeRhs) {
if (mmlaOp == MMLA::Unsigned)
mmlaOp = MMLA::Mixed;
} else {
if (mmlaOp == MMLA::Signed)
mmlaOp = MMLA::MixedSwapped;
- maybeRhs = getExtOperand<arith::ExtUIOp>(
- op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
}
if (!maybeRhs)
return rewriter.notifyMatchFailure(
More information about the Mlir-commits
mailing list