[clang] [llvm] [mlir] [AArch64][SME] Improve codegen for aarch64.sme.cnts* when not in streaming mode (PR #154761)
Kerry McLaughlin via llvm-commits
llvm-commits at lists.llvm.org
Thu Sep 11 07:02:36 PDT 2025
https://github.com/kmclaughlin-arm updated https://github.com/llvm/llvm-project/pull/154761
>From 625925797e8e7a76471aeaa01150dbee8cf69de5 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Wed, 20 Aug 2025 09:33:17 +0000
Subject: [PATCH 01/11] RDSVL tests
---
.../CodeGen/AArch64/sme-intrinsics-rdsvl.ll | 49 +++++++++++++++++++
1 file changed, 49 insertions(+)
diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
index 5d10d7e13da14..b799f98981520 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
@@ -40,6 +40,55 @@ define i64 @sme_cntsd() {
ret i64 %v
}
+define i64 @sme_cntsb_mul() {
+; CHECK-LABEL: sme_cntsb_mul:
+; CHECK: // %bb.0:
+; CHECK-NEXT: rdsvl x8, #1
+; CHECK-NEXT: lsl x0, x8, #1
+; CHECK-NEXT: ret
+ %v = call i64 @llvm.aarch64.sme.cntsb()
+ %res = mul i64 %v, 2
+ ret i64 %res
+}
+
+define i64 @sme_cntsh_mul() {
+; CHECK-LABEL: sme_cntsh_mul:
+; CHECK: // %bb.0:
+; CHECK-NEXT: rdsvl x8, #1
+; CHECK-NEXT: lsr x8, x8, #1
+; CHECK-NEXT: add x0, x8, x8, lsl #2
+; CHECK-NEXT: ret
+ %v = call i64 @llvm.aarch64.sme.cntsh()
+ %res = mul i64 %v, 5
+ ret i64 %res
+}
+
+define i64 @sme_cntsw_mul() {
+; CHECK-LABEL: sme_cntsw_mul:
+; CHECK: // %bb.0:
+; CHECK-NEXT: rdsvl x8, #1
+; CHECK-NEXT: lsr x8, x8, #2
+; CHECK-NEXT: lsl x9, x8, #3
+; CHECK-NEXT: sub x0, x9, x8
+; CHECK-NEXT: ret
+ %v = call i64 @llvm.aarch64.sme.cntsw()
+ %res = mul i64 %v, 7
+ ret i64 %res
+}
+
+define i64 @sme_cntsd_mul() {
+; CHECK-LABEL: sme_cntsd_mul:
+; CHECK: // %bb.0:
+; CHECK-NEXT: rdsvl x8, #1
+; CHECK-NEXT: lsr x8, x8, #3
+; CHECK-NEXT: add x8, x8, x8, lsl #1
+; CHECK-NEXT: lsl x0, x8, #2
+; CHECK-NEXT: ret
+ %v = call i64 @llvm.aarch64.sme.cntsd()
+ %res = mul i64 %v, 12
+ ret i64 %res
+}
+
declare i64 @llvm.aarch64.sme.cntsb()
declare i64 @llvm.aarch64.sme.cntsh()
declare i64 @llvm.aarch64.sme.cntsw()
>From fd8ff8ab97a5d876811b11c43d1e6d6c19100399 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Wed, 13 Aug 2025 14:13:12 +0000
Subject: [PATCH 02/11] [AArch64][SME] Improve codegen for aarch64.sme.cnts*
when not in streaming mode
---
.../Target/AArch64/AArch64ISelLowering.cpp | 36 ++++++++++---------
.../lib/Target/AArch64/AArch64SMEInstrInfo.td | 23 ++++++++++++
.../CodeGen/AArch64/sme-intrinsics-rdsvl.ll | 20 +++++------
3 files changed, 51 insertions(+), 28 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 23328ed57fb36..6d65d6354b462 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -6266,25 +6266,26 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
case Intrinsic::aarch64_sve_clz:
return DAG.getNode(AArch64ISD::CTLZ_MERGE_PASSTHRU, DL, Op.getValueType(),
Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
- case Intrinsic::aarch64_sme_cntsb:
- return DAG.getNode(AArch64ISD::RDSVL, DL, Op.getValueType(),
- DAG.getConstant(1, DL, MVT::i32));
+ case Intrinsic::aarch64_sme_cntsb: {
+ SDValue Cntd = DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
+ DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64));
+ return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd,
+ DAG.getConstant(8, DL, MVT::i64));
+ }
case Intrinsic::aarch64_sme_cntsh: {
- SDValue One = DAG.getConstant(1, DL, MVT::i32);
- SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, DL, Op.getValueType(), One);
- return DAG.getNode(ISD::SRL, DL, Op.getValueType(), Bytes, One);
+ SDValue Cntd = DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
+ DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64));
+ return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd,
+ DAG.getConstant(4, DL, MVT::i64));
}
case Intrinsic::aarch64_sme_cntsw: {
- SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, DL, Op.getValueType(),
- DAG.getConstant(1, DL, MVT::i32));
- return DAG.getNode(ISD::SRL, DL, Op.getValueType(), Bytes,
- DAG.getConstant(2, DL, MVT::i32));
- }
- case Intrinsic::aarch64_sme_cntsd: {
- SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, DL, Op.getValueType(),
- DAG.getConstant(1, DL, MVT::i32));
- return DAG.getNode(ISD::SRL, DL, Op.getValueType(), Bytes,
- DAG.getConstant(3, DL, MVT::i32));
+ SDValue Cntd = DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
+ DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64));
+ return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd,
+ DAG.getConstant(2, DL, MVT::i64));
}
case Intrinsic::aarch64_sve_cnt: {
SDValue Data = Op.getOperand(3);
@@ -19200,6 +19201,9 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
if (ConstValue.sge(1) && ConstValue.sle(16))
return SDValue();
+ if (getIntrinsicID(N0.getNode()) == Intrinsic::aarch64_sme_cntsd)
+ return SDValue();
+
// Multiplication of a power of two plus/minus one can be done more
// cheaply as shift+add/sub. For now, this is true unilaterally. If
// future CPUs have a cheaper MADD instruction, this may need to be
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 0d8cb3a76d0be..aecfe37cad823 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -127,12 +127,35 @@ def : Pat<(AArch64_requires_za_save), (RequiresZASavePseudo)>;
def SDT_AArch64RDSVL : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisInt<1>]>;
def AArch64rdsvl : SDNode<"AArch64ISD::RDSVL", SDT_AArch64RDSVL>;
+def sme_cntsb_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 8>">;
+def sme_cntsh_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 4>">;
+def sme_cntsw_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 2>">;
+def sme_cntsd_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 1>">;
+
let Predicates = [HasSMEandIsNonStreamingSafe] in {
def RDSVLI_XI : sve_int_read_vl_a<0b0, 0b11111, "rdsvl", /*streaming_sve=*/0b1>;
def ADDSPL_XXI : sve_int_arith_vl<0b1, "addspl", /*streaming_sve=*/0b1>;
def ADDSVL_XXI : sve_int_arith_vl<0b0, "addsvl", /*streaming_sve=*/0b1>;
def : Pat<(AArch64rdsvl (i32 simm6_32b:$imm)), (RDSVLI_XI simm6_32b:$imm)>;
+
+// e.g. cntsb() * imm
+def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsb_imm i64:$imm))),
+ (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm))>;
+def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsh_imm i64:$imm))),
+ (UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 1, 63)>;
+def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsw_imm i64:$imm))),
+ (UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 2, 63)>;
+def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsd_imm i64:$imm))),
+ (UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 3, 63)>;
+
+// e.g. cntsb()
+def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 1))), (UBFMXri (RDSVLI_XI 1), 2, 63)>;
+def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 2))), (UBFMXri (RDSVLI_XI 1), 1, 63)>;
+def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 3))), (RDSVLI_XI 1)>;
+
+// Generic pattern for cntsd (RDSVL #1 >> 3)
+def : Pat<(i64 (int_aarch64_sme_cntsd)), (UBFMXri (RDSVLI_XI 1), 3, 63)>;
}
let Predicates = [HasSME] in {
diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
index b799f98981520..8253db1d488e7 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
@@ -44,7 +44,8 @@ define i64 @sme_cntsb_mul() {
; CHECK-LABEL: sme_cntsb_mul:
; CHECK: // %bb.0:
; CHECK-NEXT: rdsvl x8, #1
-; CHECK-NEXT: lsl x0, x8, #1
+; CHECK-NEXT: lsr x8, x8, #3
+; CHECK-NEXT: lsl x0, x8, #4
; CHECK-NEXT: ret
%v = call i64 @llvm.aarch64.sme.cntsb()
%res = mul i64 %v, 2
@@ -54,9 +55,8 @@ define i64 @sme_cntsb_mul() {
define i64 @sme_cntsh_mul() {
; CHECK-LABEL: sme_cntsh_mul:
; CHECK: // %bb.0:
-; CHECK-NEXT: rdsvl x8, #1
-; CHECK-NEXT: lsr x8, x8, #1
-; CHECK-NEXT: add x0, x8, x8, lsl #2
+; CHECK-NEXT: rdsvl x8, #5
+; CHECK-NEXT: lsr x0, x8, #1
; CHECK-NEXT: ret
%v = call i64 @llvm.aarch64.sme.cntsh()
%res = mul i64 %v, 5
@@ -66,10 +66,8 @@ define i64 @sme_cntsh_mul() {
define i64 @sme_cntsw_mul() {
; CHECK-LABEL: sme_cntsw_mul:
; CHECK: // %bb.0:
-; CHECK-NEXT: rdsvl x8, #1
-; CHECK-NEXT: lsr x8, x8, #2
-; CHECK-NEXT: lsl x9, x8, #3
-; CHECK-NEXT: sub x0, x9, x8
+; CHECK-NEXT: rdsvl x8, #7
+; CHECK-NEXT: lsr x0, x8, #2
; CHECK-NEXT: ret
%v = call i64 @llvm.aarch64.sme.cntsw()
%res = mul i64 %v, 7
@@ -79,10 +77,8 @@ define i64 @sme_cntsw_mul() {
define i64 @sme_cntsd_mul() {
; CHECK-LABEL: sme_cntsd_mul:
; CHECK: // %bb.0:
-; CHECK-NEXT: rdsvl x8, #1
-; CHECK-NEXT: lsr x8, x8, #3
-; CHECK-NEXT: add x8, x8, x8, lsl #1
-; CHECK-NEXT: lsl x0, x8, #2
+; CHECK-NEXT: rdsvl x8, #3
+; CHECK-NEXT: lsr x0, x8, #1
; CHECK-NEXT: ret
%v = call i64 @llvm.aarch64.sme.cntsd()
%res = mul i64 %v, 12
>From 65da7184ac4d9dfe53989f233e77206c0767215e Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Thu, 21 Aug 2025 14:04:18 +0000
Subject: [PATCH 03/11] - Replace cnts[b|h|w] builtins with cntsd intrinsic in
Clang - Remove cnts[b|h|w] intrinsics in LLVM - Add patterns for cntsd
---
clang/include/clang/Basic/arm_sme.td | 15 +++--
clang/lib/CodeGen/TargetBuiltins/ARM.cpp | 30 +++++++++-
.../AArch64/sme-intrinsics/acle_sme_cnt.c | 42 ++++++++------
llvm/include/llvm/IR/IntrinsicsAArch64.td | 9 +--
.../Target/AArch64/AArch64ISelDAGToDAG.cpp | 20 +++++++
.../Target/AArch64/AArch64ISelLowering.cpp | 21 -------
.../lib/Target/AArch64/AArch64SMEInstrInfo.td | 28 ++++-----
.../AArch64/AArch64TargetTransformInfo.cpp | 16 ++----
.../CodeGen/AArch64/sme-intrinsics-rdsvl.ll | 57 ++++++++++---------
.../sme-streaming-interface-remarks.ll | 4 +-
.../AArch64/sme-streaming-interface.ll | 7 ++-
.../sme-intrinsic-opts-counting-elems.ll | 45 ---------------
12 files changed, 136 insertions(+), 158 deletions(-)
diff --git a/clang/include/clang/Basic/arm_sme.td b/clang/include/clang/Basic/arm_sme.td
index a4eb92e76968c..f853122994497 100644
--- a/clang/include/clang/Basic/arm_sme.td
+++ b/clang/include/clang/Basic/arm_sme.td
@@ -156,16 +156,15 @@ let SMETargetGuard = "sme2p1" in {
////////////////////////////////////////////////////////////////////////////////
// SME - Counting elements in a streaming vector
-multiclass ZACount<string n_suffix> {
- def NAME : SInst<"sv" # n_suffix, "nv", "", MergeNone,
- "aarch64_sme_" # n_suffix,
- [IsOverloadNone, IsStreamingCompatible]>;
+multiclass ZACount<string intr, string n_suffix> {
+ def NAME : SInst<"sv"#n_suffix, "nv", "", MergeNone,
+ intr, [IsOverloadNone, IsStreamingCompatible]>;
}
-defm SVCNTSB : ZACount<"cntsb">;
-defm SVCNTSH : ZACount<"cntsh">;
-defm SVCNTSW : ZACount<"cntsw">;
-defm SVCNTSD : ZACount<"cntsd">;
+defm SVCNTSB : ZACount<"", "cntsb">;
+defm SVCNTSH : ZACount<"", "cntsh">;
+defm SVCNTSW : ZACount<"", "cntsw">;
+defm SVCNTSD : ZACount<"aarch64_sme_cntsd", "cntsd">;
////////////////////////////////////////////////////////////////////////////////
// SME - ADDHA/ADDVA
diff --git a/clang/lib/CodeGen/TargetBuiltins/ARM.cpp b/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
index 60413e7b18e85..217232db44b6f 100644
--- a/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
@@ -4304,9 +4304,10 @@ Value *CodeGenFunction::EmitSMELd1St1(const SVETypeFlags &TypeFlags,
// size in bytes.
if (Ops.size() == 5) {
Function *StreamingVectorLength =
- CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsb);
+ CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsd);
llvm::Value *StreamingVectorLengthCall =
- Builder.CreateCall(StreamingVectorLength);
+ Builder.CreateMul(Builder.CreateCall(StreamingVectorLength),
+ llvm::ConstantInt::get(Int64Ty, 8), "svl");
llvm::Value *Mulvl =
Builder.CreateMul(StreamingVectorLengthCall, Ops[4], "mulvl");
// The type of the ptr parameter is void *, so use Int8Ty here.
@@ -4918,6 +4919,31 @@ Value *CodeGenFunction::EmitAArch64SMEBuiltinExpr(unsigned BuiltinID,
// Handle builtins which require their multi-vector operands to be swapped
swapCommutativeSMEOperands(BuiltinID, Ops);
+ auto isCntsBuiltin = [&](int64_t &Mul) {
+ switch (BuiltinID) {
+ default:
+ Mul = 0;
+ return false;
+ case SME::BI__builtin_sme_svcntsb:
+ Mul = 8;
+ return true;
+ case SME::BI__builtin_sme_svcntsh:
+ Mul = 4;
+ return true;
+ case SME::BI__builtin_sme_svcntsw:
+ Mul = 2;
+ return true;
+ }
+ };
+
+ int64_t Mul = 0;
+ if (isCntsBuiltin(Mul)) {
+ llvm::Value *Cntd =
+ Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsd));
+ return Builder.CreateMul(Cntd, llvm::ConstantInt::get(Int64Ty, Mul),
+ "mulsvl", /* HasNUW */ true, /* HasNSW */ true);
+ }
+
// Should not happen!
if (Builtin->LLVMIntrinsic == 0)
return nullptr;
diff --git a/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_cnt.c b/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_cnt.c
index c0b3e1a06b0ff..049c1742e5a9d 100644
--- a/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_cnt.c
+++ b/clang/test/CodeGen/AArch64/sme-intrinsics/acle_sme_cnt.c
@@ -6,49 +6,55 @@
#include <arm_sme.h>
-// CHECK-C-LABEL: define dso_local i64 @test_svcntsb(
+// CHECK-C-LABEL: define dso_local range(i64 0, -9223372036854775808) i64 @test_svcntsb(
// CHECK-C-SAME: ) local_unnamed_addr #[[ATTR0:[0-9]+]] {
// CHECK-C-NEXT: entry:
-// CHECK-C-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-C-NEXT: ret i64 [[TMP0]]
+// CHECK-C-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
+// CHECK-C-NEXT: [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 3
+// CHECK-C-NEXT: ret i64 [[MULSVL]]
//
-// CHECK-CXX-LABEL: define dso_local noundef i64 @_Z12test_svcntsbv(
+// CHECK-CXX-LABEL: define dso_local noundef range(i64 0, -9223372036854775808) i64 @_Z12test_svcntsbv(
// CHECK-CXX-SAME: ) local_unnamed_addr #[[ATTR0:[0-9]+]] {
// CHECK-CXX-NEXT: entry:
-// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-CXX-NEXT: ret i64 [[TMP0]]
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
+// CHECK-CXX-NEXT: [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 3
+// CHECK-CXX-NEXT: ret i64 [[MULSVL]]
//
uint64_t test_svcntsb() {
return svcntsb();
}
-// CHECK-C-LABEL: define dso_local i64 @test_svcntsh(
+// CHECK-C-LABEL: define dso_local range(i64 0, -9223372036854775808) i64 @test_svcntsh(
// CHECK-C-SAME: ) local_unnamed_addr #[[ATTR0]] {
// CHECK-C-NEXT: entry:
-// CHECK-C-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsh()
-// CHECK-C-NEXT: ret i64 [[TMP0]]
+// CHECK-C-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
+// CHECK-C-NEXT: [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 2
+// CHECK-C-NEXT: ret i64 [[MULSVL]]
//
-// CHECK-CXX-LABEL: define dso_local noundef i64 @_Z12test_svcntshv(
+// CHECK-CXX-LABEL: define dso_local noundef range(i64 0, -9223372036854775808) i64 @_Z12test_svcntshv(
// CHECK-CXX-SAME: ) local_unnamed_addr #[[ATTR0]] {
// CHECK-CXX-NEXT: entry:
-// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsh()
-// CHECK-CXX-NEXT: ret i64 [[TMP0]]
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
+// CHECK-CXX-NEXT: [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 2
+// CHECK-CXX-NEXT: ret i64 [[MULSVL]]
//
uint64_t test_svcntsh() {
return svcntsh();
}
-// CHECK-C-LABEL: define dso_local i64 @test_svcntsw(
+// CHECK-C-LABEL: define dso_local range(i64 0, -9223372036854775808) i64 @test_svcntsw(
// CHECK-C-SAME: ) local_unnamed_addr #[[ATTR0]] {
// CHECK-C-NEXT: entry:
-// CHECK-C-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsw()
-// CHECK-C-NEXT: ret i64 [[TMP0]]
+// CHECK-C-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
+// CHECK-C-NEXT: [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 1
+// CHECK-C-NEXT: ret i64 [[MULSVL]]
//
-// CHECK-CXX-LABEL: define dso_local noundef i64 @_Z12test_svcntswv(
+// CHECK-CXX-LABEL: define dso_local noundef range(i64 0, -9223372036854775808) i64 @_Z12test_svcntswv(
// CHECK-CXX-SAME: ) local_unnamed_addr #[[ATTR0]] {
// CHECK-CXX-NEXT: entry:
-// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsw()
-// CHECK-CXX-NEXT: ret i64 [[TMP0]]
+// CHECK-CXX-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.aarch64.sme.cntsd()
+// CHECK-CXX-NEXT: [[MULSVL:%.*]] = shl nuw nsw i64 [[TMP0]], 1
+// CHECK-CXX-NEXT: ret i64 [[MULSVL]]
//
uint64_t test_svcntsw() {
return svcntsw();
diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
index 6d53bf8b172d8..7c9aef52b3acf 100644
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -3147,13 +3147,8 @@ let TargetPrefix = "aarch64" in {
// Counting elements
//
- class AdvSIMD_SME_CNTSB_Intrinsic
- : DefaultAttrsIntrinsic<[llvm_i64_ty], [], [IntrNoMem]>;
-
- def int_aarch64_sme_cntsb : AdvSIMD_SME_CNTSB_Intrinsic;
- def int_aarch64_sme_cntsh : AdvSIMD_SME_CNTSB_Intrinsic;
- def int_aarch64_sme_cntsw : AdvSIMD_SME_CNTSB_Intrinsic;
- def int_aarch64_sme_cntsd : AdvSIMD_SME_CNTSB_Intrinsic;
+ def int_aarch64_sme_cntsd
+ : DefaultAttrsIntrinsic<[llvm_i64_ty], [], [IntrNoMem]>;
//
// PSTATE Functions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index bc786f415b554..4e8255bab9437 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -71,6 +71,9 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
template <signed Low, signed High, signed Scale>
bool SelectRDVLImm(SDValue N, SDValue &Imm);
+ template <signed Low, signed High>
+ bool SelectRDSVLShiftImm(SDValue N, SDValue &Imm);
+
bool SelectArithExtendedRegister(SDValue N, SDValue &Reg, SDValue &Shift);
bool SelectArithUXTXRegister(SDValue N, SDValue &Reg, SDValue &Shift);
bool SelectArithImmed(SDValue N, SDValue &Val, SDValue &Shift);
@@ -937,6 +940,23 @@ bool AArch64DAGToDAGISel::SelectRDVLImm(SDValue N, SDValue &Imm) {
return false;
}
+template <signed Low, signed High>
+bool AArch64DAGToDAGISel::SelectRDSVLShiftImm(SDValue N, SDValue &Imm) {
+ if (!isa<ConstantSDNode>(N))
+ return false;
+
+ int64_t ShlImm = cast<ConstantSDNode>(N)->getSExtValue();
+ if (ShlImm >= 3) {
+ int64_t MulImm = 1 << (ShlImm - 3);
+ if (MulImm >= Low && MulImm <= High) {
+ Imm = CurDAG->getSignedTargetConstant(MulImm, SDLoc(N), MVT::i32);
+ return true;
+ }
+ }
+
+ return false;
+}
+
/// SelectArithExtendedRegister - Select a "extended register" operand. This
/// operand folds in an extend followed by an optional left shift.
bool AArch64DAGToDAGISel::SelectArithExtendedRegister(SDValue N, SDValue &Reg,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 6d65d6354b462..08f0ae0b2f783 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -6266,27 +6266,6 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
case Intrinsic::aarch64_sve_clz:
return DAG.getNode(AArch64ISD::CTLZ_MERGE_PASSTHRU, DL, Op.getValueType(),
Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
- case Intrinsic::aarch64_sme_cntsb: {
- SDValue Cntd = DAG.getNode(
- ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
- DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64));
- return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd,
- DAG.getConstant(8, DL, MVT::i64));
- }
- case Intrinsic::aarch64_sme_cntsh: {
- SDValue Cntd = DAG.getNode(
- ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
- DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64));
- return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd,
- DAG.getConstant(4, DL, MVT::i64));
- }
- case Intrinsic::aarch64_sme_cntsw: {
- SDValue Cntd = DAG.getNode(
- ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
- DAG.getConstant(Intrinsic::aarch64_sme_cntsd, DL, MVT::i64));
- return DAG.getNode(ISD::MUL, DL, MVT::i64, Cntd,
- DAG.getConstant(2, DL, MVT::i64));
- }
case Intrinsic::aarch64_sve_cnt: {
SDValue Data = Op.getOperand(3);
// CTPOP only supports integer operands.
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index aecfe37cad823..3b27203d45585 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -127,10 +127,12 @@ def : Pat<(AArch64_requires_za_save), (RequiresZASavePseudo)>;
def SDT_AArch64RDSVL : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisInt<1>]>;
def AArch64rdsvl : SDNode<"AArch64ISD::RDSVL", SDT_AArch64RDSVL>;
-def sme_cntsb_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 8>">;
-def sme_cntsh_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 4>">;
-def sme_cntsw_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 2>">;
-def sme_cntsd_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 1>">;
+def sme_cntsb_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 8>">;
+def sme_cntsh_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 4>">;
+def sme_cntsw_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 2>">;
+def sme_cntsd_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 1>">;
+
+def sme_cnts_shl_imm : ComplexPattern<i64, 1, "SelectRDSVLShiftImm<1, 31>">;
let Predicates = [HasSMEandIsNonStreamingSafe] in {
def RDSVLI_XI : sve_int_read_vl_a<0b0, 0b11111, "rdsvl", /*streaming_sve=*/0b1>;
@@ -140,21 +142,21 @@ def ADDSVL_XXI : sve_int_arith_vl<0b0, "addsvl", /*streaming_sve=*/0b1>;
def : Pat<(AArch64rdsvl (i32 simm6_32b:$imm)), (RDSVLI_XI simm6_32b:$imm)>;
// e.g. cntsb() * imm
-def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsb_imm i64:$imm))),
+def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsb_mul_imm i64:$imm))),
(RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm))>;
-def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsh_imm i64:$imm))),
+def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsh_mul_imm i64:$imm))),
(UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 1, 63)>;
-def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsw_imm i64:$imm))),
+def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsw_mul_imm i64:$imm))),
(UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 2, 63)>;
-def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsd_imm i64:$imm))),
+def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsd_mul_imm i64:$imm))),
(UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 3, 63)>;
-// e.g. cntsb()
-def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 1))), (UBFMXri (RDSVLI_XI 1), 2, 63)>;
-def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 2))), (UBFMXri (RDSVLI_XI 1), 1, 63)>;
-def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 3))), (RDSVLI_XI 1)>;
+def : Pat<(i64 (shl (int_aarch64_sme_cntsd), (sme_cnts_shl_imm i64:$imm))),
+ (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm))>;
-// Generic pattern for cntsd (RDSVL #1 >> 3)
+// cntsh, cntsw, cntsd
+def : Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 2))), (UBFMXri (RDSVLI_XI 1), 1, 63)>;
+def : Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 1))), (UBFMXri (RDSVLI_XI 1), 2, 63)>;
def : Pat<(i64 (int_aarch64_sme_cntsd)), (UBFMXri (RDSVLI_XI 1), 3, 63)>;
}
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 490f6391c15a0..38958796e2fe1 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -2102,15 +2102,15 @@ instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
}
static std::optional<Instruction *>
-instCombineSMECntsElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts,
+instCombineSMECntsElts(InstCombiner &IC, IntrinsicInst &II,
const AArch64Subtarget *ST) {
if (!ST->isStreaming())
return std::nullopt;
- // In streaming-mode, aarch64_sme_cnts is equivalent to aarch64_sve_cnt
+ // In streaming-mode, aarch64_sme_cntsd is equivalent to aarch64_sve_cntd
// with SVEPredPattern::all
- Value *Cnt = IC.Builder.CreateElementCount(
- II.getType(), ElementCount::getScalable(NumElts));
+ Value *Cnt =
+ IC.Builder.CreateElementCount(II.getType(), ElementCount::getScalable(2));
Cnt->takeName(&II);
return IC.replaceInstUsesWith(II, Cnt);
}
@@ -2825,13 +2825,7 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
case Intrinsic::aarch64_sve_cntb:
return instCombineSVECntElts(IC, II, 16);
case Intrinsic::aarch64_sme_cntsd:
- return instCombineSMECntsElts(IC, II, 2, ST);
- case Intrinsic::aarch64_sme_cntsw:
- return instCombineSMECntsElts(IC, II, 4, ST);
- case Intrinsic::aarch64_sme_cntsh:
- return instCombineSMECntsElts(IC, II, 8, ST);
- case Intrinsic::aarch64_sme_cntsb:
- return instCombineSMECntsElts(IC, II, 16, ST);
+ return instCombineSMECntsElts(IC, II, ST);
case Intrinsic::aarch64_sve_ptest_any:
case Intrinsic::aarch64_sve_ptest_first:
case Intrinsic::aarch64_sve_ptest_last:
diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
index 8253db1d488e7..86d3e42deae09 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
@@ -1,54 +1,56 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme -verify-machineinstrs < %s | FileCheck %s
-define i64 @sme_cntsb() {
-; CHECK-LABEL: sme_cntsb:
+define i64 @cntsb() {
+; CHECK-LABEL: cntsb:
; CHECK: // %bb.0:
; CHECK-NEXT: rdsvl x0, #1
; CHECK-NEXT: ret
- %v = call i64 @llvm.aarch64.sme.cntsb()
- ret i64 %v
+ %1 = call i64 @llvm.aarch64.sme.cntsd()
+ %res = shl nuw nsw i64 %1, 3
+ ret i64 %res
}
-define i64 @sme_cntsh() {
-; CHECK-LABEL: sme_cntsh:
+define i64 @cntsh() {
+; CHECK-LABEL: cntsh:
; CHECK: // %bb.0:
; CHECK-NEXT: rdsvl x8, #1
; CHECK-NEXT: lsr x0, x8, #1
; CHECK-NEXT: ret
- %v = call i64 @llvm.aarch64.sme.cntsh()
- ret i64 %v
+ %1 = call i64 @llvm.aarch64.sme.cntsd()
+ %res = shl nuw nsw i64 %1, 2
+ ret i64 %res
}
-define i64 @sme_cntsw() {
-; CHECK-LABEL: sme_cntsw:
+define i64 @cntsw() {
+; CHECK-LABEL: cntsw:
; CHECK: // %bb.0:
; CHECK-NEXT: rdsvl x8, #1
; CHECK-NEXT: lsr x0, x8, #2
; CHECK-NEXT: ret
- %v = call i64 @llvm.aarch64.sme.cntsw()
- ret i64 %v
+ %1 = call i64 @llvm.aarch64.sme.cntsd()
+ %res = shl nuw nsw i64 %1, 1
+ ret i64 %res
}
-define i64 @sme_cntsd() {
-; CHECK-LABEL: sme_cntsd:
+define i64 @cntsd() {
+; CHECK-LABEL: cntsd:
; CHECK: // %bb.0:
; CHECK-NEXT: rdsvl x8, #1
; CHECK-NEXT: lsr x0, x8, #3
; CHECK-NEXT: ret
- %v = call i64 @llvm.aarch64.sme.cntsd()
- ret i64 %v
+ %res = call i64 @llvm.aarch64.sme.cntsd()
+ ret i64 %res
}
define i64 @sme_cntsb_mul() {
; CHECK-LABEL: sme_cntsb_mul:
; CHECK: // %bb.0:
-; CHECK-NEXT: rdsvl x8, #1
-; CHECK-NEXT: lsr x8, x8, #3
-; CHECK-NEXT: lsl x0, x8, #4
+; CHECK-NEXT: rdsvl x0, #2
; CHECK-NEXT: ret
- %v = call i64 @llvm.aarch64.sme.cntsb()
- %res = mul i64 %v, 2
+ %v = call i64 @llvm.aarch64.sme.cntsd()
+ %shl = shl nuw nsw i64 %v, 3
+ %res = mul i64 %shl, 2
ret i64 %res
}
@@ -58,8 +60,9 @@ define i64 @sme_cntsh_mul() {
; CHECK-NEXT: rdsvl x8, #5
; CHECK-NEXT: lsr x0, x8, #1
; CHECK-NEXT: ret
- %v = call i64 @llvm.aarch64.sme.cntsh()
- %res = mul i64 %v, 5
+ %v = call i64 @llvm.aarch64.sme.cntsd()
+ %shl = shl nuw nsw i64 %v, 2
+ %res = mul i64 %shl, 5
ret i64 %res
}
@@ -69,8 +72,9 @@ define i64 @sme_cntsw_mul() {
; CHECK-NEXT: rdsvl x8, #7
; CHECK-NEXT: lsr x0, x8, #2
; CHECK-NEXT: ret
- %v = call i64 @llvm.aarch64.sme.cntsw()
- %res = mul i64 %v, 7
+ %v = call i64 @llvm.aarch64.sme.cntsd()
+ %shl = shl nuw nsw i64 %v, 1
+ %res = mul i64 %shl, 7
ret i64 %res
}
@@ -85,7 +89,4 @@ define i64 @sme_cntsd_mul() {
ret i64 %res
}
-declare i64 @llvm.aarch64.sme.cntsb()
-declare i64 @llvm.aarch64.sme.cntsh()
-declare i64 @llvm.aarch64.sme.cntsw()
declare i64 @llvm.aarch64.sme.cntsd()
diff --git a/llvm/test/CodeGen/AArch64/sme-streaming-interface-remarks.ll b/llvm/test/CodeGen/AArch64/sme-streaming-interface-remarks.ll
index e1a474d898233..2806f864c7b25 100644
--- a/llvm/test/CodeGen/AArch64/sme-streaming-interface-remarks.ll
+++ b/llvm/test/CodeGen/AArch64/sme-streaming-interface-remarks.ll
@@ -76,14 +76,14 @@ entry:
%Data1 = alloca <vscale x 16 x i8>, align 16
%Data2 = alloca <vscale x 16 x i8>, align 16
%Data3 = alloca <vscale x 16 x i8>, align 16
- %0 = tail call i64 @llvm.aarch64.sme.cntsb()
+ %0 = tail call i64 @llvm.aarch64.sme.cntsd()
call void @foo(ptr noundef nonnull %Data1, ptr noundef nonnull %Data2, ptr noundef nonnull %Data3, i64 noundef %0)
%1 = load <vscale x 16 x i8>, ptr %Data1, align 16
%vecext = extractelement <vscale x 16 x i8> %1, i64 0
ret i8 %vecext
}
-declare i64 @llvm.aarch64.sme.cntsb()
+declare i64 @llvm.aarch64.sme.cntsd()
declare void @foo(ptr noundef, ptr noundef, ptr noundef, i64 noundef)
diff --git a/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll b/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll
index 8c4d57e244e03..505a40c16653b 100644
--- a/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll
+++ b/llvm/test/CodeGen/AArch64/sme-streaming-interface.ll
@@ -366,9 +366,10 @@ define i8 @call_to_non_streaming_pass_sve_objects(ptr nocapture noundef readnone
; CHECK-NEXT: stp d9, d8, [sp, #48] // 16-byte Folded Spill
; CHECK-NEXT: stp x29, x30, [sp, #64] // 16-byte Folded Spill
; CHECK-NEXT: addvl sp, sp, #-3
-; CHECK-NEXT: rdsvl x3, #1
+; CHECK-NEXT: rdsvl x8, #1
; CHECK-NEXT: addvl x0, sp, #2
; CHECK-NEXT: addvl x1, sp, #1
+; CHECK-NEXT: lsr x3, x8, #3
; CHECK-NEXT: mov x2, sp
; CHECK-NEXT: smstop sm
; CHECK-NEXT: bl foo
@@ -386,7 +387,7 @@ entry:
%Data1 = alloca <vscale x 16 x i8>, align 16
%Data2 = alloca <vscale x 16 x i8>, align 16
%Data3 = alloca <vscale x 16 x i8>, align 16
- %0 = tail call i64 @llvm.aarch64.sme.cntsb()
+ %0 = tail call i64 @llvm.aarch64.sme.cntsd()
call void @foo(ptr noundef nonnull %Data1, ptr noundef nonnull %Data2, ptr noundef nonnull %Data3, i64 noundef %0)
%1 = load <vscale x 16 x i8>, ptr %Data1, align 16
%vecext = extractelement <vscale x 16 x i8> %1, i64 0
@@ -421,7 +422,7 @@ entry:
ret void
}
-declare i64 @llvm.aarch64.sme.cntsb()
+declare i64 @llvm.aarch64.sme.cntsd()
declare void @foo(ptr noundef, ptr noundef, ptr noundef, i64 noundef)
declare void @bar(ptr noundef, i64 noundef, i64 noundef, i32 noundef, i32 noundef, float noundef, float noundef, double noundef, double noundef)
diff --git a/llvm/test/Transforms/InstCombine/AArch64/sme-intrinsic-opts-counting-elems.ll b/llvm/test/Transforms/InstCombine/AArch64/sme-intrinsic-opts-counting-elems.ll
index f213c0b53f6ef..c1d12b825b72c 100644
--- a/llvm/test/Transforms/InstCombine/AArch64/sme-intrinsic-opts-counting-elems.ll
+++ b/llvm/test/Transforms/InstCombine/AArch64/sme-intrinsic-opts-counting-elems.ll
@@ -5,48 +5,6 @@
target triple = "aarch64-unknown-linux-gnu"
-define i64 @cntsb() {
-; CHECK-LABEL: @cntsb(
-; CHECK-NEXT: [[OUT:%.*]] = call i64 @llvm.aarch64.sme.cntsb()
-; CHECK-NEXT: ret i64 [[OUT]]
-;
-; CHECK-STREAMING-LABEL: @cntsb(
-; CHECK-STREAMING-NEXT: [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-STREAMING-NEXT: [[OUT:%.*]] = shl nuw i64 [[TMP1]], 4
-; CHECK-STREAMING-NEXT: ret i64 [[OUT]]
-;
- %out = call i64 @llvm.aarch64.sme.cntsb()
- ret i64 %out
-}
-
-define i64 @cntsh() {
-; CHECK-LABEL: @cntsh(
-; CHECK-NEXT: [[OUT:%.*]] = call i64 @llvm.aarch64.sme.cntsh()
-; CHECK-NEXT: ret i64 [[OUT]]
-;
-; CHECK-STREAMING-LABEL: @cntsh(
-; CHECK-STREAMING-NEXT: [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-STREAMING-NEXT: [[OUT:%.*]] = shl nuw i64 [[TMP1]], 3
-; CHECK-STREAMING-NEXT: ret i64 [[OUT]]
-;
- %out = call i64 @llvm.aarch64.sme.cntsh()
- ret i64 %out
-}
-
-define i64 @cntsw() {
-; CHECK-LABEL: @cntsw(
-; CHECK-NEXT: [[OUT:%.*]] = call i64 @llvm.aarch64.sme.cntsw()
-; CHECK-NEXT: ret i64 [[OUT]]
-;
-; CHECK-STREAMING-LABEL: @cntsw(
-; CHECK-STREAMING-NEXT: [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-STREAMING-NEXT: [[OUT:%.*]] = shl nuw i64 [[TMP1]], 2
-; CHECK-STREAMING-NEXT: ret i64 [[OUT]]
-;
- %out = call i64 @llvm.aarch64.sme.cntsw()
- ret i64 %out
-}
-
define i64 @cntsd() {
; CHECK-LABEL: @cntsd(
; CHECK-NEXT: [[OUT:%.*]] = call i64 @llvm.aarch64.sme.cntsd()
@@ -61,8 +19,5 @@ define i64 @cntsd() {
ret i64 %out
}
-declare i64 @llvm.aarch64.sve.cntsb()
-declare i64 @llvm.aarch64.sve.cntsh()
-declare i64 @llvm.aarch64.sve.cntsw()
declare i64 @llvm.aarch64.sve.cntsd()
>From 191a4de7c1bc6a2cde664572709c667a57643715 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Mon, 1 Sep 2025 13:04:42 +0000
Subject: [PATCH 04/11] - Remove cnts[b,h,w] intrinsics from MLIR and fix tests
- Remove ZACount class from arm_sme.td
---
clang/include/clang/Basic/arm_sme.td | 13 +++----
.../Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td | 3 --
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 36 ++++++++++++-------
.../ArmSMEToLLVM/arm-sme-to-llvm.mlir | 17 ++++++---
mlir/test/Target/LLVMIR/arm-sme-invalid.mlir | 2 +-
mlir/test/Target/LLVMIR/arm-sme.mlir | 6 ----
6 files changed, 40 insertions(+), 37 deletions(-)
diff --git a/clang/include/clang/Basic/arm_sme.td b/clang/include/clang/Basic/arm_sme.td
index f853122994497..5f6a6eaab80a3 100644
--- a/clang/include/clang/Basic/arm_sme.td
+++ b/clang/include/clang/Basic/arm_sme.td
@@ -156,15 +156,10 @@ let SMETargetGuard = "sme2p1" in {
////////////////////////////////////////////////////////////////////////////////
// SME - Counting elements in a streaming vector
-multiclass ZACount<string intr, string n_suffix> {
- def NAME : SInst<"sv"#n_suffix, "nv", "", MergeNone,
- intr, [IsOverloadNone, IsStreamingCompatible]>;
-}
-
-defm SVCNTSB : ZACount<"", "cntsb">;
-defm SVCNTSH : ZACount<"", "cntsh">;
-defm SVCNTSW : ZACount<"", "cntsw">;
-defm SVCNTSD : ZACount<"aarch64_sme_cntsd", "cntsd">;
+def SVCNTSB : SInst<"svcntsb", "nv", "", MergeNone, "", [IsOverloadNone, IsStreamingCompatible]>;
+def SVCNTSH : SInst<"svcntsh", "nv", "", MergeNone, "", [IsOverloadNone, IsStreamingCompatible]>;
+def SVCNTSW : SInst<"svcntsw", "nv", "", MergeNone, "", [IsOverloadNone, IsStreamingCompatible]>;
+def SVCNTSD : SInst<"svcntsd", "nv", "", MergeNone, "aarch64_sme_cntsd", [IsOverloadNone, IsStreamingCompatible]>;
////////////////////////////////////////////////////////////////////////////////
// SME - ADDHA/ADDVA
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index 06fb8511774e8..4d19fa5415ef0 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -201,9 +201,6 @@ class ArmSME_IntrCountOp<string mnemonic>
/*traits*/[PredOpTrait<"`res` is i64", TypeIsPred<"res", I64>>],
/*numResults=*/1, /*overloadedResults=*/[]>;
-def LLVM_aarch64_sme_cntsb : ArmSME_IntrCountOp<"cntsb">;
-def LLVM_aarch64_sme_cntsh : ArmSME_IntrCountOp<"cntsh">;
-def LLVM_aarch64_sme_cntsw : ArmSME_IntrCountOp<"cntsw">;
def LLVM_aarch64_sme_cntsd : ArmSME_IntrCountOp<"cntsd">;
#endif // ARMSME_INTRINSIC_OPS
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 8a2e3b639aaa7..6b795b18211b2 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -822,7 +822,7 @@ struct OuterProductWideningOpConversion
}
};
-/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
+/// Lower `arm_sme.streaming_vl` to SME CNTSD intrinsic.
///
/// Example:
///
@@ -830,8 +830,10 @@ struct OuterProductWideningOpConversion
///
/// is converted to:
///
-/// %cnt = "arm_sme.intr.cntsh"() : () -> i64
-/// %0 = arith.index_cast %cnt : i64 to index
+/// %cnt = "arm_sme.intr.cntsd"() : () -> i64
+/// %0 = arith.constant 4 : i64
+/// %1 = arith.muli %cnt, %0 : i64
+/// %2 = arith.index_cast %1 : i64 to index
///
struct StreamingVLOpConversion
: public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
@@ -845,15 +847,25 @@ struct StreamingVLOpConversion
auto loc = streamingVlOp.getLoc();
auto i64Type = rewriter.getI64Type();
auto *intrOp = [&]() -> Operation * {
+ auto cntsd = arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
switch (streamingVlOp.getTypeSize()) {
- case arm_sme::TypeSize::Byte:
- return arm_sme::aarch64_sme_cntsb::create(rewriter, loc, i64Type);
- case arm_sme::TypeSize::Half:
- return arm_sme::aarch64_sme_cntsh::create(rewriter, loc, i64Type);
- case arm_sme::TypeSize::Word:
- return arm_sme::aarch64_sme_cntsw::create(rewriter, loc, i64Type);
+ case arm_sme::TypeSize::Byte: {
+ auto mul = arith::ConstantIndexOp::create(rewriter, loc, 8);
+ auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul);
+ return arith::MulIOp::create(rewriter, loc, cntsd, mul64);
+ }
+ case arm_sme::TypeSize::Half: {
+ auto mul = arith::ConstantIndexOp::create(rewriter, loc, 4);
+ auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul);
+ return arith::MulIOp::create(rewriter, loc, cntsd, mul64);
+ }
+ case arm_sme::TypeSize::Word: {
+ auto mul = arith::ConstantIndexOp::create(rewriter, loc, 2);
+ auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul);
+ return arith::MulIOp::create(rewriter, loc, cntsd, mul64);
+ }
case arm_sme::TypeSize::Double:
- return arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
+ return cntsd;
}
llvm_unreachable("unknown type size in StreamingVLOpConversion");
}();
@@ -964,9 +976,7 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
- arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
- arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
- arm_sme::aarch64_sme_cntsd>();
+ arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsd>();
target.addLegalDialect<arith::ArithDialect,
/* The following are used to lower tile spills/fills */
vector::VectorDialect, scf::SCFDialect,
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index 6a4d77e86ab58..4f3c1dad24b76 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -586,9 +586,10 @@ func.func @arm_sme_extract_tile_slice_ver_i128(%tile_slice_index : index) -> vec
// -----
// CHECK-LABEL: @arm_sme_streaming_vl_bytes
-// CHECK: %[[COUNT:.*]] = "arm_sme.intr.cntsb"() : () -> i64
-// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[COUNT]] : i64 to index
-// CHECK: return %[[INDEX_COUNT]] : index
+// CHECK: %[[CONST:.*]] = arith.constant 8 : i64
+// CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64
+// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64
+// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index
func.func @arm_sme_streaming_vl_bytes() -> index {
%svl_b = arm_sme.streaming_vl <byte>
return %svl_b : index
@@ -597,7 +598,10 @@ func.func @arm_sme_streaming_vl_bytes() -> index {
// -----
// CHECK-LABEL: @arm_sme_streaming_vl_half_words
-// CHECK: "arm_sme.intr.cntsh"() : () -> i64
+// CHECK: %[[CONST:.*]] = arith.constant 4 : i64
+// CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64
+// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64
+// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index
func.func @arm_sme_streaming_vl_half_words() -> index {
%svl_h = arm_sme.streaming_vl <half>
return %svl_h : index
@@ -606,7 +610,10 @@ func.func @arm_sme_streaming_vl_half_words() -> index {
// -----
// CHECK-LABEL: @arm_sme_streaming_vl_words
-// CHECK: "arm_sme.intr.cntsw"() : () -> i64
+// CHECK: %[[CONST:.*]] = arith.constant 2 : i64
+// CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64
+// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64
+// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index
func.func @arm_sme_streaming_vl_words() -> index {
%svl_w = arm_sme.streaming_vl <word>
return %svl_w : index
diff --git a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
index 14821da838726..6f5b1d8c5d93d 100644
--- a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
@@ -36,6 +36,6 @@ llvm.func @arm_sme_tile_slice_to_vector_invalid_element_types(
llvm.func @arm_sme_streaming_vl_invalid_return_type() -> i32 {
// expected-error @+1 {{failed to verify that `res` is i64}}
- %res = "arm_sme.intr.cntsb"() : () -> i32
+ %res = "arm_sme.intr.cntsd"() : () -> i32
llvm.return %res : i32
}
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index aedb6730b06bb..0a13a75618a23 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -419,12 +419,6 @@ llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32,
// -----
llvm.func @arm_sme_streaming_vl() {
- // CHECK: call i64 @llvm.aarch64.sme.cntsb()
- %svl_b = "arm_sme.intr.cntsb"() : () -> i64
- // CHECK: call i64 @llvm.aarch64.sme.cntsh()
- %svl_h = "arm_sme.intr.cntsh"() : () -> i64
- // CHECK: call i64 @llvm.aarch64.sme.cntsw()
- %svl_w = "arm_sme.intr.cntsw"() : () -> i64
// CHECK: call i64 @llvm.aarch64.sme.cntsd()
%svl_d = "arm_sme.intr.cntsd"() : () -> i64
llvm.return
>From e2d91e5e53043b916961940a01f2d65e4ac7b752 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Thu, 4 Sep 2025 10:23:08 +0000
Subject: [PATCH 05/11] - Remove lambda from StreamingVLOpConversion - Add
getSizeInBytes helper
---
clang/lib/CodeGen/TargetBuiltins/ARM.cpp | 17 ++++------
.../Target/AArch64/AArch64ISelDAGToDAG.cpp | 3 ++
.../AArch64/AArch64TargetTransformInfo.cpp | 6 ++--
.../include/mlir/Dialect/ArmSME/Utils/Utils.h | 3 ++
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 32 ++++---------------
mlir/lib/Dialect/ArmSME/IR/Utils.cpp | 15 +++++++++
.../ArmSMEToLLVM/arm-sme-to-llvm.mlir | 18 +++++------
7 files changed, 46 insertions(+), 48 deletions(-)
diff --git a/clang/lib/CodeGen/TargetBuiltins/ARM.cpp b/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
index 217232db44b6f..de1bdb335469d 100644
--- a/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
@@ -4919,25 +4919,20 @@ Value *CodeGenFunction::EmitAArch64SMEBuiltinExpr(unsigned BuiltinID,
// Handle builtins which require their multi-vector operands to be swapped
swapCommutativeSMEOperands(BuiltinID, Ops);
- auto isCntsBuiltin = [&](int64_t &Mul) {
+ auto isCntsBuiltin = [&]() {
switch (BuiltinID) {
default:
- Mul = 0;
- return false;
+ return 0;
case SME::BI__builtin_sme_svcntsb:
- Mul = 8;
- return true;
+ return 8;
case SME::BI__builtin_sme_svcntsh:
- Mul = 4;
- return true;
+ return 4;
case SME::BI__builtin_sme_svcntsw:
- Mul = 2;
- return true;
+ return 2;
}
};
- int64_t Mul = 0;
- if (isCntsBuiltin(Mul)) {
+ if (auto Mul = isCntsBuiltin()) {
llvm::Value *Cntd =
Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsd));
return Builder.CreateMul(Cntd, llvm::ConstantInt::get(Int64Ty, Mul),
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 4e8255bab9437..8af10ef8dadc9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -940,6 +940,9 @@ bool AArch64DAGToDAGISel::SelectRDVLImm(SDValue N, SDValue &Imm) {
return false;
}
+// Given cntsd = (rdsvl, #1) >> 3, attempt to return a suitable multiplier
+// for RDSVL to calculate the streaming vector length in bytes * N. i.e.
+// rdsvl, #(ShlImm - 3)
template <signed Low, signed High>
bool AArch64DAGToDAGISel::SelectRDSVLShiftImm(SDValue N, SDValue &Imm) {
if (!isa<ConstantSDNode>(N))
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 38958796e2fe1..d4c7cb11a70a3 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -2102,8 +2102,8 @@ instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
}
static std::optional<Instruction *>
-instCombineSMECntsElts(InstCombiner &IC, IntrinsicInst &II,
- const AArch64Subtarget *ST) {
+instCombineSMECntsd(InstCombiner &IC, IntrinsicInst &II,
+ const AArch64Subtarget *ST) {
if (!ST->isStreaming())
return std::nullopt;
@@ -2825,7 +2825,7 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
case Intrinsic::aarch64_sve_cntb:
return instCombineSVECntElts(IC, II, 16);
case Intrinsic::aarch64_sme_cntsd:
- return instCombineSMECntsElts(IC, II, ST);
+ return instCombineSMECntsd(IC, II, ST);
case Intrinsic::aarch64_sve_ptest_any:
case Intrinsic::aarch64_sve_ptest_first:
case Intrinsic::aarch64_sve_ptest_last:
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..b57b27de4e1de 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -32,6 +32,9 @@ namespace mlir::arm_sme {
constexpr unsigned MinStreamingVectorLengthInBits = 128;
+/// Return the size represented by arm_sme::TypeSize in bytes.
+unsigned getSizeInBytes(TypeSize type);
+
/// Return minimum number of elements for the given element `type` in
/// a vector of SVL bits.
unsigned getSMETileSliceMinNumElts(Type type);
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 6b795b18211b2..a36f8f09ceada 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -846,31 +846,13 @@ struct StreamingVLOpConversion
ConversionPatternRewriter &rewriter) const override {
auto loc = streamingVlOp.getLoc();
auto i64Type = rewriter.getI64Type();
- auto *intrOp = [&]() -> Operation * {
- auto cntsd = arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
- switch (streamingVlOp.getTypeSize()) {
- case arm_sme::TypeSize::Byte: {
- auto mul = arith::ConstantIndexOp::create(rewriter, loc, 8);
- auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul);
- return arith::MulIOp::create(rewriter, loc, cntsd, mul64);
- }
- case arm_sme::TypeSize::Half: {
- auto mul = arith::ConstantIndexOp::create(rewriter, loc, 4);
- auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul);
- return arith::MulIOp::create(rewriter, loc, cntsd, mul64);
- }
- case arm_sme::TypeSize::Word: {
- auto mul = arith::ConstantIndexOp::create(rewriter, loc, 2);
- auto mul64 = arith::IndexCastOp::create(rewriter, loc, i64Type, mul);
- return arith::MulIOp::create(rewriter, loc, cntsd, mul64);
- }
- case arm_sme::TypeSize::Double:
- return cntsd;
- }
- llvm_unreachable("unknown type size in StreamingVLOpConversion");
- }();
- rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
- streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
+ auto cntsd = arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type);
+ auto cntsdIdx = arith::IndexCastOp::create(rewriter, loc,
+ rewriter.getIndexType(), cntsd);
+ auto scale = arith::ConstantIndexOp::create(
+ rewriter, loc,
+ 8 / arm_sme::getSizeInBytes(streamingVlOp.getTypeSize()));
+ rewriter.replaceOpWithNewOp<arith::MulIOp>(streamingVlOp, cntsdIdx, scale);
return success();
}
};
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index e5e1312f0eb04..92f4e4f63c200 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -14,6 +14,21 @@
namespace mlir::arm_sme {
+unsigned getSizeInBytes(TypeSize type) {
+ switch (type) {
+ case arm_sme::TypeSize::Byte:
+ return 1;
+ case arm_sme::TypeSize::Half:
+ return 2;
+ case arm_sme::TypeSize::Word:
+ return 4;
+ case arm_sme::TypeSize::Double:
+ return 8;
+ default:
+ llvm_unreachable("unknown type size");
+ }
+}
+
unsigned getSMETileSliceMinNumElts(Type type) {
assert(isValidSMETileElementType(type) && "invalid tile type!");
return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index 4f3c1dad24b76..fd8910265cd89 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -586,10 +586,10 @@ func.func @arm_sme_extract_tile_slice_ver_i128(%tile_slice_index : index) -> vec
// -----
// CHECK-LABEL: @arm_sme_streaming_vl_bytes
-// CHECK: %[[CONST:.*]] = arith.constant 8 : i64
+// CHECK: %[[CONST:.*]] = arith.constant 8 : index
// CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64
-// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64
-// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index
+// CHECK: %[[CNTSD_IDX:.*]] = arith.index_cast %[[CNTSD]] : i64 to index
+// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD_IDX]], %[[CONST]] : index
func.func @arm_sme_streaming_vl_bytes() -> index {
%svl_b = arm_sme.streaming_vl <byte>
return %svl_b : index
@@ -598,10 +598,10 @@ func.func @arm_sme_streaming_vl_bytes() -> index {
// -----
// CHECK-LABEL: @arm_sme_streaming_vl_half_words
-// CHECK: %[[CONST:.*]] = arith.constant 4 : i64
+// CHECK: %[[CONST:.*]] = arith.constant 4 : index
// CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64
-// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64
-// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index
+// CHECK: %[[CNTSD_IDX:.*]] = arith.index_cast %[[CNTSD]] : i64 to index
+// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD_IDX]], %[[CONST]] : index
func.func @arm_sme_streaming_vl_half_words() -> index {
%svl_h = arm_sme.streaming_vl <half>
return %svl_h : index
@@ -610,10 +610,10 @@ func.func @arm_sme_streaming_vl_half_words() -> index {
// -----
// CHECK-LABEL: @arm_sme_streaming_vl_words
-// CHECK: %[[CONST:.*]] = arith.constant 2 : i64
+// CHECK: %[[CONST:.*]] = arith.constant 2 : index
// CHECK: %[[CNTSD:.*]] = "arm_sme.intr.cntsd"() : () -> i64
-// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD]], %[[CONST]] : i64
-// CHECK: %[[INDEX_COUNT:.*]] = arith.index_cast %[[MUL]] : i64 to index
+// CHECK: %[[CNTSD_IDX:.*]] = arith.index_cast %[[CNTSD]] : i64 to index
+// CHECK: %[[MUL:.*]] = arith.muli %[[CNTSD_IDX]], %[[CONST]] : index
func.func @arm_sme_streaming_vl_words() -> index {
%svl_w = arm_sme.streaming_vl <word>
return %svl_w : index
>From f5de33d9e1df0155b96fcc54543bc754ab044907 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Fri, 5 Sep 2025 10:27:14 +0000
Subject: [PATCH 06/11] - Fix 'default label in switch' build failure
---
mlir/lib/Dialect/ArmSME/IR/Utils.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 92f4e4f63c200..e64ae42204fa0 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -24,9 +24,9 @@ unsigned getSizeInBytes(TypeSize type) {
return 4;
case arm_sme::TypeSize::Double:
return 8;
- default:
- llvm_unreachable("unknown type size");
}
+ llvm_unreachable("unknown type size");
+ return 0;
}
unsigned getSMETileSliceMinNumElts(Type type) {
>From a91144e0f096bcb49bd46eb5286a666dbd45696b Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Fri, 5 Sep 2025 14:02:45 +0000
Subject: [PATCH 07/11] - Fix comments in AArch64ISelDAGToDAG & ArmSMEToLLVM
---
llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp | 5 ++---
mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 6 +++---
2 files changed, 5 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 8af10ef8dadc9..8ab313dfed46a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -940,9 +940,8 @@ bool AArch64DAGToDAGISel::SelectRDVLImm(SDValue N, SDValue &Imm) {
return false;
}
-// Given cntsd = (rdsvl, #1) >> 3, attempt to return a suitable multiplier
-// for RDSVL to calculate the streaming vector length in bytes * N. i.e.
-// rdsvl, #(ShlImm - 3)
+// Given `cntsd = (rdsvl, #1) >> 3`, attempt to return a suitable multiplier
+// for RDSVL to calculate `cntsd << N`, i.e. `rdsvl, #(N - 3)`.
template <signed Low, signed High>
bool AArch64DAGToDAGISel::SelectRDSVLShiftImm(SDValue N, SDValue &Imm) {
if (!isa<ConstantSDNode>(N))
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index a36f8f09ceada..033e9ae1f4d4c 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -831,9 +831,9 @@ struct OuterProductWideningOpConversion
/// is converted to:
///
/// %cnt = "arm_sme.intr.cntsd"() : () -> i64
-/// %0 = arith.constant 4 : i64
-/// %1 = arith.muli %cnt, %0 : i64
-/// %2 = arith.index_cast %1 : i64 to index
+/// %scale = arith.constant 4 : index
+/// %cntIndex = arith.index_cast %cnt : i64 to index
+/// %0 = arith.muli %cntIndex, %scale : index
///
struct StreamingVLOpConversion
: public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
>From 166f030cf2214d6a2fb1c55cd7b5fe981b26bd88 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Fri, 5 Sep 2025 15:14:36 +0000
Subject: [PATCH 08/11] - Fix comment in SelectRDSVLShiftImm
---
llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 8ab313dfed46a..f8cea7d511931 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -941,7 +941,7 @@ bool AArch64DAGToDAGISel::SelectRDVLImm(SDValue N, SDValue &Imm) {
}
// Given `cntsd = (rdsvl, #1) >> 3`, attempt to return a suitable multiplier
-// for RDSVL to calculate `cntsd << N`, i.e. `rdsvl, #(N - 3)`.
+// for RDSVL to calculate `cntsd << N`, i.e. `rdsvl, #(1 << (N - 3))`.
template <signed Low, signed High>
bool AArch64DAGToDAGISel::SelectRDSVLShiftImm(SDValue N, SDValue &Imm) {
if (!isa<ConstantSDNode>(N))
>From 773cc76fffc2233dce573b8bd50883290904d8e3 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Mon, 8 Sep 2025 16:28:41 +0000
Subject: [PATCH 09/11] - Add lowering of cntsd back into
LowerINTRINSIC_WO_CHAIN - Remove patterns for cntsd and add for AArch64rdsvl
- Add nsw/nuw flags to mul of cntsd in EmitSMELd1St1
---
clang/lib/CodeGen/TargetBuiltins/ARM.cpp | 3 ++-
.../Target/AArch64/AArch64ISelDAGToDAG.cpp | 14 +++++---------
.../Target/AArch64/AArch64ISelLowering.cpp | 13 ++++++++++---
.../lib/Target/AArch64/AArch64SMEInstrInfo.td | 19 ++-----------------
.../CodeGen/AArch64/sme-intrinsics-rdsvl.ll | 19 ++++++++-----------
5 files changed, 27 insertions(+), 41 deletions(-)
diff --git a/clang/lib/CodeGen/TargetBuiltins/ARM.cpp b/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
index de1bdb335469d..734d925c0bb7c 100644
--- a/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/ARM.cpp
@@ -4307,7 +4307,8 @@ Value *CodeGenFunction::EmitSMELd1St1(const SVETypeFlags &TypeFlags,
CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsd);
llvm::Value *StreamingVectorLengthCall =
Builder.CreateMul(Builder.CreateCall(StreamingVectorLength),
- llvm::ConstantInt::get(Int64Ty, 8), "svl");
+ llvm::ConstantInt::get(Int64Ty, 8), "svl",
+ /* HasNUW */ true, /* HasNSW */ true);
llvm::Value *Mulvl =
Builder.CreateMul(StreamingVectorLengthCall, Ops[4], "mulvl");
// The type of the ptr parameter is void *, so use Int8Ty here.
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index f8cea7d511931..1c20a8240d688 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -940,20 +940,16 @@ bool AArch64DAGToDAGISel::SelectRDVLImm(SDValue N, SDValue &Imm) {
return false;
}
-// Given `cntsd = (rdsvl, #1) >> 3`, attempt to return a suitable multiplier
-// for RDSVL to calculate `cntsd << N`, i.e. `rdsvl, #(1 << (N - 3))`.
+// Returns a suitable RDSVL multiplier from a left shift.
template <signed Low, signed High>
bool AArch64DAGToDAGISel::SelectRDSVLShiftImm(SDValue N, SDValue &Imm) {
if (!isa<ConstantSDNode>(N))
return false;
- int64_t ShlImm = cast<ConstantSDNode>(N)->getSExtValue();
- if (ShlImm >= 3) {
- int64_t MulImm = 1 << (ShlImm - 3);
- if (MulImm >= Low && MulImm <= High) {
- Imm = CurDAG->getSignedTargetConstant(MulImm, SDLoc(N), MVT::i32);
- return true;
- }
+ int64_t MulImm = 1 << cast<ConstantSDNode>(N)->getSExtValue();
+ if (MulImm >= Low && MulImm <= High) {
+ Imm = CurDAG->getSignedTargetConstant(MulImm, SDLoc(N), MVT::i32);
+ return true;
}
return false;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 08f0ae0b2f783..27d9769dd8367 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -6266,6 +6266,16 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
case Intrinsic::aarch64_sve_clz:
return DAG.getNode(AArch64ISD::CTLZ_MERGE_PASSTHRU, DL, Op.getValueType(),
Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
+ case Intrinsic::aarch64_sme_cntsd: {
+ auto Flags = SDNodeFlags();
+ Flags.setNoUnsignedWrap(true);
+ Flags.setNoSignedWrap(true);
+ Flags.setExact(true);
+ SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, DL, Op.getValueType(),
+ DAG.getConstant(1, DL, MVT::i32));
+ return DAG.getNode(ISD::SRL, DL, Op.getValueType(), Bytes,
+ DAG.getConstant(3, DL, MVT::i32), Flags);
+ }
case Intrinsic::aarch64_sve_cnt: {
SDValue Data = Op.getOperand(3);
// CTPOP only supports integer operands.
@@ -19180,9 +19190,6 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
if (ConstValue.sge(1) && ConstValue.sle(16))
return SDValue();
- if (getIntrinsicID(N0.getNode()) == Intrinsic::aarch64_sme_cntsd)
- return SDValue();
-
// Multiplication of a power of two plus/minus one can be done more
// cheaply as shift+add/sub. For now, this is true unilaterally. If
// future CPUs have a cheaper MADD instruction, this may need to be
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 3b27203d45585..5f949843e41c0 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -139,25 +139,10 @@ def RDSVLI_XI : sve_int_read_vl_a<0b0, 0b11111, "rdsvl", /*streaming_sve=*/0b1>
def ADDSPL_XXI : sve_int_arith_vl<0b1, "addspl", /*streaming_sve=*/0b1>;
def ADDSVL_XXI : sve_int_arith_vl<0b0, "addsvl", /*streaming_sve=*/0b1>;
-def : Pat<(AArch64rdsvl (i32 simm6_32b:$imm)), (RDSVLI_XI simm6_32b:$imm)>;
-
-// e.g. cntsb() * imm
-def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsb_mul_imm i64:$imm))),
- (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm))>;
-def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsh_mul_imm i64:$imm))),
- (UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 1, 63)>;
-def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsw_mul_imm i64:$imm))),
- (UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 2, 63)>;
-def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsd_mul_imm i64:$imm))),
- (UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 3, 63)>;
-
-def : Pat<(i64 (shl (int_aarch64_sme_cntsd), (sme_cnts_shl_imm i64:$imm))),
+def : Pat<(i64 (shl (AArch64rdsvl (i32 1)), (sme_cnts_shl_imm i64:$imm))),
(RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm))>;
-// cntsh, cntsw, cntsd
-def : Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 2))), (UBFMXri (RDSVLI_XI 1), 1, 63)>;
-def : Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 1))), (UBFMXri (RDSVLI_XI 1), 2, 63)>;
-def : Pat<(i64 (int_aarch64_sme_cntsd)), (UBFMXri (RDSVLI_XI 1), 3, 63)>;
+def : Pat<(AArch64rdsvl (i32 simm6_32b:$imm)), (RDSVLI_XI simm6_32b:$imm)>;
}
let Predicates = [HasSME] in {
diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
index 86d3e42deae09..06c53d8070781 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-rdsvl.ll
@@ -46,46 +46,43 @@ define i64 @cntsd() {
define i64 @sme_cntsb_mul() {
; CHECK-LABEL: sme_cntsb_mul:
; CHECK: // %bb.0:
-; CHECK-NEXT: rdsvl x0, #2
+; CHECK-NEXT: rdsvl x0, #4
; CHECK-NEXT: ret
%v = call i64 @llvm.aarch64.sme.cntsd()
%shl = shl nuw nsw i64 %v, 3
- %res = mul i64 %shl, 2
+ %res = mul nuw nsw i64 %shl, 4
ret i64 %res
}
define i64 @sme_cntsh_mul() {
; CHECK-LABEL: sme_cntsh_mul:
; CHECK: // %bb.0:
-; CHECK-NEXT: rdsvl x8, #5
-; CHECK-NEXT: lsr x0, x8, #1
+; CHECK-NEXT: rdsvl x0, #4
; CHECK-NEXT: ret
%v = call i64 @llvm.aarch64.sme.cntsd()
%shl = shl nuw nsw i64 %v, 2
- %res = mul i64 %shl, 5
+ %res = mul nuw nsw i64 %shl, 8
ret i64 %res
}
define i64 @sme_cntsw_mul() {
; CHECK-LABEL: sme_cntsw_mul:
; CHECK: // %bb.0:
-; CHECK-NEXT: rdsvl x8, #7
-; CHECK-NEXT: lsr x0, x8, #2
+; CHECK-NEXT: rdsvl x0, #4
; CHECK-NEXT: ret
%v = call i64 @llvm.aarch64.sme.cntsd()
%shl = shl nuw nsw i64 %v, 1
- %res = mul i64 %shl, 7
+ %res = mul nuw nsw i64 %shl, 16
ret i64 %res
}
define i64 @sme_cntsd_mul() {
; CHECK-LABEL: sme_cntsd_mul:
; CHECK: // %bb.0:
-; CHECK-NEXT: rdsvl x8, #3
-; CHECK-NEXT: lsr x0, x8, #1
+; CHECK-NEXT: rdsvl x0, #4
; CHECK-NEXT: ret
%v = call i64 @llvm.aarch64.sme.cntsd()
- %res = mul i64 %v, 12
+ %res = mul nuw nsw i64 %v, 32
ret i64 %res
}
>From d258ba72efef57bbaaea988f2b76231ef7452087 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Wed, 10 Sep 2025 14:09:37 +0000
Subject: [PATCH 10/11] - Remove sme_cnts*_mul_imm patterns
---
llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td | 9 ++-------
1 file changed, 2 insertions(+), 7 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 5f949843e41c0..6313aba9a435e 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -127,19 +127,14 @@ def : Pat<(AArch64_requires_za_save), (RequiresZASavePseudo)>;
def SDT_AArch64RDSVL : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisInt<1>]>;
def AArch64rdsvl : SDNode<"AArch64ISD::RDSVL", SDT_AArch64RDSVL>;
-def sme_cntsb_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 8>">;
-def sme_cntsh_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 4>">;
-def sme_cntsw_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 2>">;
-def sme_cntsd_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 1>">;
-
-def sme_cnts_shl_imm : ComplexPattern<i64, 1, "SelectRDSVLShiftImm<1, 31>">;
+def sme_rdsvl_shl_imm : ComplexPattern<i64, 1, "SelectRDSVLShiftImm<1, 31>">;
let Predicates = [HasSMEandIsNonStreamingSafe] in {
def RDSVLI_XI : sve_int_read_vl_a<0b0, 0b11111, "rdsvl", /*streaming_sve=*/0b1>;
def ADDSPL_XXI : sve_int_arith_vl<0b1, "addspl", /*streaming_sve=*/0b1>;
def ADDSVL_XXI : sve_int_arith_vl<0b0, "addsvl", /*streaming_sve=*/0b1>;
-def : Pat<(i64 (shl (AArch64rdsvl (i32 1)), (sme_cnts_shl_imm i64:$imm))),
+def : Pat<(i64 (shl (AArch64rdsvl (i32 1)), (sme_rdsvl_shl_imm i64:$imm))),
(RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm))>;
def : Pat<(AArch64rdsvl (i32 simm6_32b:$imm)), (RDSVLI_XI simm6_32b:$imm)>;
>From ce9d51cd856867454630565cb65750238e1f679d Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin <kerry.mclaughlin at arm.com>
Date: Thu, 11 Sep 2025 13:57:53 +0000
Subject: [PATCH 11/11] - Remove setNoSignedWrap/setNoUnsignedWrap from cntsd
lowering
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 2 --
1 file changed, 2 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 27d9769dd8367..344ddcea371b7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -6268,8 +6268,6 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
case Intrinsic::aarch64_sme_cntsd: {
auto Flags = SDNodeFlags();
- Flags.setNoUnsignedWrap(true);
- Flags.setNoSignedWrap(true);
Flags.setExact(true);
SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, DL, Op.getValueType(),
DAG.getConstant(1, DL, MVT::i32));
More information about the llvm-commits
mailing list