[Mlir-commits] [mlir] b31413a - [MLIR][AArch64] Simplify LowerContractionToSVEI8MMPattern.cpp:getExtOperand (NFC) (#144909)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 23 01:38:05 PDT 2025


Author: Momchil Velikov
Date: 2025-06-23T09:38:01+01:00
New Revision: b31413a96603cd904281368b6f5f8e36836a7cac

URL: https://github.com/llvm/llvm-project/commit/b31413a96603cd904281368b6f5f8e36836a7cac
DIFF: https://github.com/llvm/llvm-project/commit/b31413a96603cd904281368b6f5f8e36836a7cac.diff

LOG: [MLIR][AArch64] Simplify LowerContractionToSVEI8MMPattern.cpp:getExtOperand (NFC) (#144909)

I just recently learned about `isSignlessInteger`; use it instead of
comparing against types obtained via `rewriter.getI<N>Type()`.
This also brings the function closer to a similar function in
`LowerContractionToNeonI8MMPattern.cpp` (formerly `LowerContractionToSMMLAPattern.cpp`),
which would help any future effort to unify these patterns.

Added: 
    

Modified: 
    mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
index b1233c5c06eb4..a1209fe8230e2 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -39,7 +39,7 @@ namespace {
 //
 // Return success only for extensions from `i8` to `i32`.
 template <typename Op>
-std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
+std::optional<Value> getExtOperand(Value v) {
 
   static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
                 "Must be instantiated with either sign- or zero- extension op");
@@ -50,7 +50,7 @@ std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
   if (!extOp) {
     if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
       auto vTy = cast<VectorType>(v.getType());
-      if (vTy.getElementType() != i8Ty)
+      if (!vTy.getElementType().isSignlessInteger(8))
         return {};
       return v;
     }
@@ -61,11 +61,11 @@ std::optional<Value> getExtOperand(Value v, Type i8Ty, Type i32Ty) {
   // operation type, check it's extended from `i8` to `i32`.
   auto inOp = extOp.getIn();
   auto inTy = dyn_cast<VectorType>(inOp.getType());
-  if (!inTy || inTy.getElementType() != i8Ty)
+  if (!inTy || !inTy.getElementType().isSignlessInteger(8))
     return {};
 
   auto outTy = dyn_cast<VectorType>(extOp.getType());
-  if (!outTy || outTy.getElementType() != i32Ty)
+  if (!outTy || !outTy.getElementType().isSignlessInteger(32))
     return {};
 
   return inOp;
@@ -199,27 +199,23 @@ class LowerContractionToSVEI8MMPattern
     // operands are supported, but they are lowered to different operations.
     // Determine which is the appropriate operation to lower to.
     MMLA mmlaOp = MMLA::Signed;
-    auto maybeLhs = getExtOperand<arith::ExtSIOp>(
-        op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
     if (!maybeLhs) {
       mmlaOp = MMLA::Unsigned;
-      maybeLhs = getExtOperand<arith::ExtUIOp>(
-          op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+      maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
     }
     if (!maybeLhs)
       return rewriter.notifyMatchFailure(
           op, "LHS is not a sign- or zero- extended i8");
 
-    auto maybeRhs = getExtOperand<arith::ExtSIOp>(
-        op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+    auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
     if (maybeRhs) {
       if (mmlaOp == MMLA::Unsigned)
         mmlaOp = MMLA::Mixed;
     } else {
       if (mmlaOp == MMLA::Signed)
         mmlaOp = MMLA::MixedSwapped;
-      maybeRhs = getExtOperand<arith::ExtUIOp>(
-          op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+      maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
     }
     if (!maybeRhs)
       return rewriter.notifyMatchFailure(


        


More information about the Mlir-commits mailing list