[mlir] [clang] [llvm] [AArch64][SME] Remove immediate argument restriction for svldr and svstr (PR #68565)

Sam Tebbs via cfe-commits cfe-commits at lists.llvm.org
Wed Nov 15 05:14:26 PST 2023


https://github.com/SamTebbs33 updated https://github.com/llvm/llvm-project/pull/68565

>From 83e20904c206980285c4ee9d0227706803147654 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Fri, 6 Oct 2023 17:09:36 +0100
Subject: [PATCH 01/12] [AArch64][SME] Remove immediate argument restriction
 for svldr and svstr

The svldr_vnum_za and svstr_vnum_za builtins/intrinsics currently
require that the vnum argument be an immediate, since the instructions
take an immediate vector number. However, we emit 0 as the immediate
for the instruction no matter what, and instead modify the base register.

This patch removes that restriction on the argument, so that the
argument can be a non-immediate. If an appropriate immediate (0-15) is
passed to the builtin, CGBuiltin passes it directly to the LLVM
intrinsic; otherwise it modifies the base register, as in the existing
behaviour.
---
 clang/lib/CodeGen/CGBuiltin.cpp               | 45 ++++++++----
 .../aarch64-sme-intrinsics/acle_sme_ldr.c     | 71 ++++++++-----------
 .../aarch64-sme-intrinsics/acle_sme_str.c     | 51 ++++---------
 llvm/include/llvm/IR/IntrinsicsAArch64.td     |  2 +-
 llvm/lib/Target/AArch64/SMEInstrFormats.td    | 10 +--
 .../CostModel/ARM/unaligned_double_load.ll    | 59 +++++++++++++++
 .../CodeGen/AArch64/sme-intrinsics-loads.ll   | 33 +++++++--
 7 files changed, 165 insertions(+), 106 deletions(-)
 create mode 100644 llvm/test/Analysis/CostModel/ARM/unaligned_double_load.ll

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 09309a3937fb613..8444aea8c8ac4b6 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -9815,6 +9815,11 @@ Value *CodeGenFunction::EmitSVEMaskedStore(const CallExpr *E,
   return Store;
 }
 
+Value *CodeGenFunction::EmitTileslice(Value *Offset, Value *Base) {
+  llvm::Value *CastOffset = Builder.CreateIntCast(Offset, Int64Ty, false);
+  return Builder.CreateAdd(Base, CastOffset, "tileslice");
+}
+
 Value *CodeGenFunction::EmitSMELd1St1(const SVETypeFlags &TypeFlags,
                                       SmallVectorImpl<Value *> &Ops,
                                       unsigned IntID) {
@@ -9870,18 +9875,34 @@ Value *CodeGenFunction::EmitSMEZero(const SVETypeFlags &TypeFlags,
 Value *CodeGenFunction::EmitSMELdrStr(const SVETypeFlags &TypeFlags,
                                       SmallVectorImpl<Value *> &Ops,
                                       unsigned IntID) {
-  if (Ops.size() == 3) {
-    Function *Cntsb = CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsb);
-    llvm::Value *CntsbCall = Builder.CreateCall(Cntsb, {}, "svlb");
-
-    llvm::Value *VecNum = Ops[2];
-    llvm::Value *MulVL = Builder.CreateMul(CntsbCall, VecNum, "mulvl");
-
-    Ops[1] = Builder.CreateGEP(Int8Ty, Ops[1], MulVL);
-    Ops[0] = Builder.CreateAdd(
-        Ops[0], Builder.CreateIntCast(VecNum, Int32Ty, true), "tileslice");
-    Ops.erase(&Ops[2]);
-  }
+  if (Ops.size() == 2) {
+    // Intrinsics without a vecnum also use this function, so just provide 0
+    Ops.push_back(Ops[1]);
+    Ops[1] = Builder.getInt32(0);
+  } else {
+    int Imm = -1;
+    if (ConstantInt* C = dyn_cast<ConstantInt>(Ops[2]))
+      if (C->getZExtValue() <= 15)
+          Imm = C->getZExtValue();
+
+    if (Imm != -1) {
+      Ops[2] = Ops[1];
+      Ops[1] = Builder.getInt32(Imm);
+    } else {
+      Function *Cntsb = CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsb);
+      llvm::Value *CntsbCall = Builder.CreateCall(Cntsb, {}, "svlb");
+
+      llvm::Value *VecNum = Ops[2];
+      llvm::Value *MulVL = Builder.CreateMul(
+          CntsbCall,
+          VecNum,
+          "mulvl");
+
+      Ops[2] = Builder.CreateGEP(Int8Ty, Ops[1], MulVL);
+      Ops[1] = Builder.getInt32(0);
+      Ops[0] = Builder.CreateIntCast(EmitTileslice(Ops[0], VecNum), Int32Ty, false);
+    }
+   }
   Function *F = CGM.getIntrinsic(IntID, {});
   return Builder.CreateCall(F, Ops);
 }
diff --git a/clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_ldr.c b/clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_ldr.c
index e85c47072f2df80..8e07cf1d11c19b2 100644
--- a/clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_ldr.c
+++ b/clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_ldr.c
@@ -6,57 +6,46 @@
 
 #include <arm_sme_draft_spec_subject_to_change.h>
 
-// CHECK-C-LABEL: define dso_local void @test_svldr_vnum_za(
-// CHECK-C-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-// CHECK-C-NEXT:  entry:
-// CHECK-C-NEXT:    tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE]], ptr [[PTR]])
-// CHECK-C-NEXT:    ret void
-//
-// CHECK-CXX-LABEL: define dso_local void @_Z18test_svldr_vnum_zajPKv(
-// CHECK-CXX-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-// CHECK-CXX-NEXT:  entry:
-// CHECK-CXX-NEXT:    tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE]], ptr [[PTR]])
-// CHECK-CXX-NEXT:    ret void
+// CHECK-C-LABEL: @test_svldr_vnum_za(
+// CHECK-CXX-LABEL: @_Z18test_svldr_vnum_zajPKv(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE]], i32 0, ptr [[PTR]])
+// CHECK-NEXT:    ret void
 //
 void test_svldr_vnum_za(uint32_t slice_base, const void *ptr) {
   svldr_vnum_za(slice_base, ptr, 0);
 }
 
-// CHECK-C-LABEL: define dso_local void @test_svldr_vnum_za_1(
-// CHECK-C-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-C-NEXT:  entry:
-// CHECK-C-NEXT:    [[SVLB:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-C-NEXT:    [[MULVL:%.*]] = mul i64 [[SVLB]], 15
-// CHECK-C-NEXT:    [[TMP0:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[MULVL]]
-// CHECK-C-NEXT:    [[TILESLICE:%.*]] = add i32 [[SLICE_BASE]], 15
-// CHECK-C-NEXT:    tail call void @llvm.aarch64.sme.ldr(i32 [[TILESLICE]], ptr [[TMP0]])
-// CHECK-C-NEXT:    ret void
-//
-// CHECK-CXX-LABEL: define dso_local void @_Z20test_svldr_vnum_za_1jPKv(
-// CHECK-CXX-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-CXX-NEXT:  entry:
-// CHECK-CXX-NEXT:    [[SVLB:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-CXX-NEXT:    [[MULVL:%.*]] = mul i64 [[SVLB]], 15
-// CHECK-CXX-NEXT:    [[TMP0:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[MULVL]]
-// CHECK-CXX-NEXT:    [[TILESLICE:%.*]] = add i32 [[SLICE_BASE]], 15
-// CHECK-CXX-NEXT:    tail call void @llvm.aarch64.sme.ldr(i32 [[TILESLICE]], ptr [[TMP0]])
-// CHECK-CXX-NEXT:    ret void
+// CHECK-C-LABEL: @test_svldr_vnum_za_1(
+// CHECK-CXX-LABEL: @_Z20test_svldr_vnum_za_1jPKv(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE]], i32 15, ptr [[PTR]])
+// CHECK-NEXT:    ret void
 //
 void test_svldr_vnum_za_1(uint32_t slice_base, const void *ptr) {
   svldr_vnum_za(slice_base, ptr, 15);
 }
 
-// CHECK-C-LABEL: define dso_local void @test_svldr_za(
-// CHECK-C-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-C-NEXT:  entry:
-// CHECK-C-NEXT:    tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE]], ptr [[PTR]])
-// CHECK-C-NEXT:    ret void
+// CHECK-C-LABEL: @test_svldr_vnum_za_var(
+// CHECK-CXX-LABEL: @_Z22test_svldr_vnum_za_varjPKvm(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[SVLB:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
+// CHECK-NEXT:    [[MULVL:%.*]] = mul i64 [[SVLB]], [[VNUM]]
+// CHECK-NEXT:    [[TMP0:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[MULVL]]
+// CHECK-NEXT:    [[TMP1:%.*]] = trunc i64 [[VNUM]] to i32
+// CHECK-NEXT:    [[TMP2:%.*]] = add i32 [[TMP1]], [[SLICE_BASE]]
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.ldr(i32 [[TMP2]], i32 0, ptr [[TMP0]])
+// CHECK-NEXT:    ret void
 //
-// CHECK-CXX-LABEL: define dso_local void @_Z13test_svldr_zajPKv(
-// CHECK-CXX-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-CXX-NEXT:  entry:
-// CHECK-CXX-NEXT:    tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE]], ptr [[PTR]])
-// CHECK-CXX-NEXT:    ret void
+void test_svldr_vnum_za_var(uint32_t slice_base, const void *ptr, uint64_t vnum) {
+  svldr_vnum_za(slice_base, ptr, vnum);
+}
+
+// CHECK-C-LABEL: @test_svldr_za(
+// CHECK-CXX-LABEL: @_Z13test_svldr_zajPKv(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE]], i32 0, ptr [[PTR]])
+// CHECK-NEXT:    ret void
 //
 void test_svldr_za(uint32_t slice_base, const void *ptr) {
   svldr_za(slice_base, ptr);
@@ -87,5 +76,3 @@ void test_svldr_za(uint32_t slice_base, const void *ptr) {
 void test_svldr_vnum_za_var(uint32_t slice_base, const void *ptr, int64_t vnum) {
   svldr_vnum_za(slice_base, ptr, vnum);
 }
-//// NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-// CHECK: {{.*}}
diff --git a/clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_str.c b/clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_str.c
index e53a3c6c57de323..532f570b6aaa444 100644
--- a/clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_str.c
+++ b/clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_str.c
@@ -6,57 +6,32 @@
 
 #include <arm_sme_draft_spec_subject_to_change.h>
 
-// CHECK-C-LABEL: define dso_local void @test_svstr_vnum_za(
-// CHECK-C-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-// CHECK-C-NEXT:  entry:
-// CHECK-C-NEXT:    tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE]], ptr [[PTR]])
-// CHECK-C-NEXT:    ret void
-//
-// CHECK-CXX-LABEL: define dso_local void @_Z18test_svstr_vnum_zajPv(
-// CHECK-CXX-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-// CHECK-CXX-NEXT:  entry:
-// CHECK-CXX-NEXT:    tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE]], ptr [[PTR]])
-// CHECK-CXX-NEXT:    ret void
+// CHECK-C-LABEL: @test_svstr_vnum_za(
+// CHECK-CXX-LABEL: @_Z18test_svstr_vnum_zajPv(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE]], i32 0, ptr [[PTR]])
+// CHECK-NEXT:    ret void
 //
 void test_svstr_vnum_za(uint32_t slice_base, void *ptr) {
   svstr_vnum_za(slice_base, ptr, 0);
 }
 
 // CHECK-C-LABEL: define dso_local void @test_svstr_vnum_za_1(
-// CHECK-C-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-C-NEXT:  entry:
-// CHECK-C-NEXT:    [[SVLB:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-C-NEXT:    [[MULVL:%.*]] = mul i64 [[SVLB]], 15
-// CHECK-C-NEXT:    [[TMP0:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[MULVL]]
-// CHECK-C-NEXT:    [[TILESLICE:%.*]] = add i32 [[SLICE_BASE]], 15
-// CHECK-C-NEXT:    tail call void @llvm.aarch64.sme.str(i32 [[TILESLICE]], ptr [[TMP0]])
-// CHECK-C-NEXT:    ret void
-//
 // CHECK-CXX-LABEL: define dso_local void @_Z20test_svstr_vnum_za_1jPv(
-// CHECK-CXX-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-CXX-NEXT:  entry:
-// CHECK-CXX-NEXT:    [[SVLB:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-CXX-NEXT:    [[MULVL:%.*]] = mul i64 [[SVLB]], 15
-// CHECK-CXX-NEXT:    [[TMP0:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[MULVL]]
-// CHECK-CXX-NEXT:    [[TILESLICE:%.*]] = add i32 [[SLICE_BASE]], 15
-// CHECK-CXX-NEXT:    tail call void @llvm.aarch64.sme.str(i32 [[TILESLICE]], ptr [[TMP0]])
-// CHECK-CXX-NEXT:    ret void
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE]], i32 15, ptr [[PTR]])
+// CHECK-NEXT:    ret void
 //
 void test_svstr_vnum_za_1(uint32_t slice_base, void *ptr) {
   svstr_vnum_za(slice_base, ptr, 15);
 }
 
 // CHECK-C-LABEL: define dso_local void @test_svstr_za(
-// CHECK-C-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-C-NEXT:  entry:
-// CHECK-C-NEXT:    tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE]], ptr [[PTR]])
-// CHECK-C-NEXT:    ret void
-//
 // CHECK-CXX-LABEL: define dso_local void @_Z13test_svstr_zajPv(
-// CHECK-CXX-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-CXX-NEXT:  entry:
-// CHECK-CXX-NEXT:    tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE]], ptr [[PTR]])
-// CHECK-CXX-NEXT:    ret void
+// CHECK-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE]], i32 0, ptr [[PTR]])
+// CHECK-NEXT:    ret void
 //
 void test_svstr_za(uint32_t slice_base, void *ptr) {
   svstr_za(slice_base, ptr);
@@ -87,5 +62,3 @@ void test_svstr_za(uint32_t slice_base, void *ptr) {
 void test_svstr_vnum_za_var(uint32_t slice_base, void *ptr, int64_t vnum) {
   svstr_vnum_za(slice_base, ptr, vnum);
 }
-//// NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-// CHECK: {{.*}}
diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
index 9164604f7d78cbc..543b5b6fa94d40d 100644
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -2680,7 +2680,7 @@ let TargetPrefix = "aarch64" in {
 
   // Spill + fill
   class SME_LDR_STR_Intrinsic
-    : DefaultAttrsIntrinsic<[], [llvm_i32_ty, llvm_ptr_ty]>;
+    : DefaultAttrsIntrinsic<[], [llvm_i32_ty, llvm_ptr_ty, llvm_i32_ty], [ImmArg<ArgIndex<2>>]>;
   def int_aarch64_sme_ldr : SME_LDR_STR_Intrinsic;
   def int_aarch64_sme_str : SME_LDR_STR_Intrinsic;
 
diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td
index f54d898aa69f7cd..143931a001a4ba1 100644
--- a/llvm/lib/Target/AArch64/SMEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -795,8 +795,8 @@ multiclass sme_spill<string opcodestr> {
                   (!cast<Instruction>(NAME) MatrixOp:$ZAt,
                    MatrixIndexGPR32Op12_15:$Rv, sme_elm_idx0_15:$imm4, GPR64sp:$Rn, 0), 1>;
   // base
-  def : Pat<(int_aarch64_sme_str MatrixIndexGPR32Op12_15:$idx, GPR64sp:$base),
-            (!cast<Instruction>(NAME) ZA, $idx, 0, $base, 0)>;
+  def : Pat<(int_aarch64_sme_str MatrixIndexGPR32Op12_15:$idx, sme_elm_idx0_15:$imm, GPR64sp:$base),
+            (!cast<Instruction>(NAME) ZA, $idx, $imm, $base, 0)>;
 }
 
 multiclass sme_fill<string opcodestr> {
@@ -806,7 +806,7 @@ multiclass sme_fill<string opcodestr> {
                    MatrixIndexGPR32Op12_15:$Rv, sme_elm_idx0_15:$imm4, GPR64sp:$Rn, 0), 1>;
   def NAME # _PSEUDO
       : Pseudo<(outs),
-               (ins MatrixIndexGPR32Op12_15:$idx, imm0_15:$imm4,
+               (ins MatrixIndexGPR32Op12_15:$idx, sme_elm_idx0_15:$imm4,
                     GPR64sp:$base), []>,
         Sched<[]> {
     // Translated to actual instruction in AArch64ISelLowering.cpp
@@ -814,8 +814,8 @@ multiclass sme_fill<string opcodestr> {
     let mayLoad = 1;
   }
   // base
-  def : Pat<(int_aarch64_sme_ldr MatrixIndexGPR32Op12_15:$idx, GPR64sp:$base),
-            (!cast<Instruction>(NAME # _PSEUDO) $idx, 0, $base)>;
+  def : Pat<(int_aarch64_sme_ldr MatrixIndexGPR32Op12_15:$idx, sme_elm_idx0_15:$imm, GPR64sp:$base),
+            (!cast<Instruction>(NAME # _PSEUDO) $idx, $imm, $base)>;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/test/Analysis/CostModel/ARM/unaligned_double_load.ll b/llvm/test/Analysis/CostModel/ARM/unaligned_double_load.ll
new file mode 100644
index 000000000000000..8d457220ea9c5ae
--- /dev/null
+++ b/llvm/test/Analysis/CostModel/ARM/unaligned_double_load.ll
@@ -0,0 +1,59 @@
+; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py
+; RUN: opt -passes="print<cost-model>" 2>&1 -disable-output -mtriple=thumbv6m-none-eabi < %s | FileCheck %s --check-prefix=CHECK-NOVEC
+; RUN: opt -passes="print<cost-model>" 2>&1 -disable-output -mtriple=thumbv7m-none-eabi -mcpu=cortex-m4 < %s | FileCheck %s --check-prefix=CHECK-FP
+
+define float @f(ptr %x) {
+; CHECK-NOVEC-LABEL: 'f'
+; CHECK-NOVEC-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %a.0.copyload = load float, ptr %x, align 1
+; CHECK-NOVEC-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: ret float %a.0.copyload
+;
+; CHECK-FP-LABEL: 'f'
+; CHECK-FP-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %a.0.copyload = load float, ptr %x, align 1
+; CHECK-FP-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: ret float %a.0.copyload
+;
+entry:
+  %a.0.copyload = load float, ptr %x, align 1
+  ret float %a.0.copyload
+}
+
+define float @ff(ptr %x, float %f) {
+; CHECK-NOVEC-LABEL: 'ff'
+; CHECK-NOVEC-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: store float %f, ptr %x, align 1
+; CHECK-NOVEC-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: ret float undef
+;
+; CHECK-FP-LABEL: 'ff'
+; CHECK-FP-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: store float %f, ptr %x, align 1
+; CHECK-FP-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: ret float undef
+;
+entry:
+  store float %f, ptr %x, align 1
+  ret float undef
+}
+
+define double @d(ptr %x) {
+; CHECK-NOVEC-LABEL: 'd'
+; CHECK-NOVEC-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %a.0.copyload = load double, ptr %x, align 1
+; CHECK-NOVEC-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: ret double %a.0.copyload
+;
+; CHECK-FP-LABEL: 'd'
+; CHECK-FP-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %a.0.copyload = load double, ptr %x, align 1
+; CHECK-FP-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: ret double %a.0.copyload
+;
+entry:
+  %a.0.copyload = load double, ptr %x, align 1
+  ret double %a.0.copyload
+}
+
+define double @dd(ptr %x, double %f) {
+; CHECK-NOVEC-LABEL: 'dd'
+; CHECK-NOVEC-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: store double %f, ptr %x, align 1
+; CHECK-NOVEC-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: ret double undef
+;
+; CHECK-FP-LABEL: 'dd'
+; CHECK-FP-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: store double %f, ptr %x, align 1
+; CHECK-FP-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: ret double undef
+;
+entry:
+  store double %f, ptr %x, align 1
+  ret double undef
+}
diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll
index c96aca366ed43f2..f5d25a3229a7f82 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll
@@ -252,10 +252,28 @@ define void @ldr(ptr %ptr) {
 ; CHECK-NEXT:    mov w12, wzr
 ; CHECK-NEXT:    ldr za[w12, 0], [x0]
 ; CHECK-NEXT:    ret
-  call void @llvm.aarch64.sme.ldr(i32 0, ptr %ptr)
+  call void @llvm.aarch64.sme.ldr(i32 0, i32 0, ptr %ptr)
   ret void;
 }
 
+define void @ldr_vnum(i32 %tile_slice, ptr %ptr, i64 %vnum) {
+; CHECK-LABEL: ldr_vnum:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    add w12, w2, w0
+; CHECK-NEXT:    madd x8, x8, x2, x1
+; CHECK-NEXT:    ldr za[w12, 0], [x8]
+; CHECK-NEXT:    ret
+entry:
+  %svlb = tail call i64 @llvm.aarch64.sme.cntsb()
+  %mulvl = mul i64 %svlb, %vnum
+  %0 = getelementptr i8, ptr %ptr, i64 %mulvl
+  %1 = trunc i64 %vnum to i32
+  %2 = add i32 %1, %tile_slice
+  tail call void @llvm.aarch64.sme.ldr(i32 %2, i32 0, ptr %0)
+  ret void
+}
+
 define void @ldr_with_off_15(ptr %ptr) {
 ; CHECK-LABEL: ldr_with_off_15:
 ; CHECK:       // %bb.0:
@@ -264,7 +282,7 @@ define void @ldr_with_off_15(ptr %ptr) {
 ; CHECK-NEXT:    ldr za[w12, 0], [x8]
 ; CHECK-NEXT:    ret
   %base = getelementptr i8, ptr %ptr, i64 15
-  call void @llvm.aarch64.sme.ldr(i32 15, ptr %base)
+  call void @llvm.aarch64.sme.ldr(i32 15, i32 0, ptr %base)
   ret void;
 }
 
@@ -278,7 +296,7 @@ define void @ldr_with_off_15mulvl(ptr %ptr) {
   %vscale = call i64 @llvm.vscale.i64()
   %mulvl = mul i64 %vscale, 240
   %base = getelementptr i8, ptr %ptr, i64 %mulvl
-  call void @llvm.aarch64.sme.ldr(i32 15, ptr %base)
+  call void @llvm.aarch64.sme.ldr(i32 15, i32 0, ptr %base)
   ret void;
 }
 
@@ -292,7 +310,7 @@ define void @ldr_with_off_16mulvl(ptr %ptr) {
   %vscale = call i64 @llvm.vscale.i64()
   %mulvl = mul i64 %vscale, 256
   %base = getelementptr i8, ptr %ptr, i64 %mulvl
-  call void @llvm.aarch64.sme.ldr(i32 16, ptr %base)
+  call void @llvm.aarch64.sme.ldr(i32 16, i32 0, ptr %base)
   ret void;
 }
 
@@ -302,13 +320,13 @@ define void @test_ld1_sink_tile0_offset_operand(<vscale x 4 x i1> %pg, ptr %src,
 ; CHECK-LABEL: test_ld1_sink_tile0_offset_operand:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    mov w12, w1
-; CHECK-NEXT:  .LBB14_1: // %for.body
+; CHECK-NEXT:  .LBB15_1: // %for.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    ld1w {za0h.s[w12, 0]}, p0/z, [x0]
 ; CHECK-NEXT:    subs w2, w2, #1
 ; CHECK-NEXT:    ld1w {za0h.s[w12, 1]}, p0/z, [x0]
 ; CHECK-NEXT:    ld1w {za0h.s[w12, 2]}, p0/z, [x0]
-; CHECK-NEXT:    b.ne .LBB14_1
+; CHECK-NEXT:    b.ne .LBB15_1
 ; CHECK-NEXT:  // %bb.2: // %exit
 ; CHECK-NEXT:    ret
 entry:
@@ -341,5 +359,6 @@ declare void @llvm.aarch64.sme.ld1w.vert(<vscale x 4 x i1>, ptr, i32, i32)
 declare void @llvm.aarch64.sme.ld1d.vert(<vscale x 2 x i1>, ptr, i32, i32)
 declare void @llvm.aarch64.sme.ld1q.vert(<vscale x 1 x i1>, ptr, i32, i32)
 
-declare void @llvm.aarch64.sme.ldr(i32, ptr)
+declare void @llvm.aarch64.sme.ldr(i32, i32, ptr)
 declare i64 @llvm.vscale.i64()
+declare i64 @llvm.aarch64.sme.cntsb()

>From a887758764c137c1ade44767aa9b9881c44d86e6 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Mon, 9 Oct 2023 09:52:28 +0100
Subject: [PATCH 02/12] fixup: remove erroneously included file

---
 .../CostModel/ARM/unaligned_double_load.ll    | 59 -------------------
 1 file changed, 59 deletions(-)
 delete mode 100644 llvm/test/Analysis/CostModel/ARM/unaligned_double_load.ll

diff --git a/llvm/test/Analysis/CostModel/ARM/unaligned_double_load.ll b/llvm/test/Analysis/CostModel/ARM/unaligned_double_load.ll
deleted file mode 100644
index 8d457220ea9c5ae..000000000000000
--- a/llvm/test/Analysis/CostModel/ARM/unaligned_double_load.ll
+++ /dev/null
@@ -1,59 +0,0 @@
-; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py
-; RUN: opt -passes="print<cost-model>" 2>&1 -disable-output -mtriple=thumbv6m-none-eabi < %s | FileCheck %s --check-prefix=CHECK-NOVEC
-; RUN: opt -passes="print<cost-model>" 2>&1 -disable-output -mtriple=thumbv7m-none-eabi -mcpu=cortex-m4 < %s | FileCheck %s --check-prefix=CHECK-FP
-
-define float @f(ptr %x) {
-; CHECK-NOVEC-LABEL: 'f'
-; CHECK-NOVEC-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %a.0.copyload = load float, ptr %x, align 1
-; CHECK-NOVEC-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: ret float %a.0.copyload
-;
-; CHECK-FP-LABEL: 'f'
-; CHECK-FP-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %a.0.copyload = load float, ptr %x, align 1
-; CHECK-FP-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: ret float %a.0.copyload
-;
-entry:
-  %a.0.copyload = load float, ptr %x, align 1
-  ret float %a.0.copyload
-}
-
-define float @ff(ptr %x, float %f) {
-; CHECK-NOVEC-LABEL: 'ff'
-; CHECK-NOVEC-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: store float %f, ptr %x, align 1
-; CHECK-NOVEC-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: ret float undef
-;
-; CHECK-FP-LABEL: 'ff'
-; CHECK-FP-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: store float %f, ptr %x, align 1
-; CHECK-FP-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: ret float undef
-;
-entry:
-  store float %f, ptr %x, align 1
-  ret float undef
-}
-
-define double @d(ptr %x) {
-; CHECK-NOVEC-LABEL: 'd'
-; CHECK-NOVEC-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %a.0.copyload = load double, ptr %x, align 1
-; CHECK-NOVEC-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: ret double %a.0.copyload
-;
-; CHECK-FP-LABEL: 'd'
-; CHECK-FP-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %a.0.copyload = load double, ptr %x, align 1
-; CHECK-FP-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: ret double %a.0.copyload
-;
-entry:
-  %a.0.copyload = load double, ptr %x, align 1
-  ret double %a.0.copyload
-}
-
-define double @dd(ptr %x, double %f) {
-; CHECK-NOVEC-LABEL: 'dd'
-; CHECK-NOVEC-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: store double %f, ptr %x, align 1
-; CHECK-NOVEC-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: ret double undef
-;
-; CHECK-FP-LABEL: 'dd'
-; CHECK-FP-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: store double %f, ptr %x, align 1
-; CHECK-FP-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: ret double undef
-;
-entry:
-  store double %f, ptr %x, align 1
-  ret double undef
-}

>From 449b1c1fa5c8825f10281f31e958523341c0bb57 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 19 Oct 2023 11:32:30 +0100
Subject: [PATCH 03/12] fixup! Use DAGToDAG approach

---
 clang/lib/CodeGen/CGBuiltin.cpp               | 37 ++--------
 .../aarch64-sme-intrinsics/acle_sme_ldr.c     | 58 ++++++----------
 .../aarch64-sme-intrinsics/acle_sme_str.c     | 52 +++++++-------
 .../Target/AArch64/AArch64ISelDAGToDAG.cpp    | 54 +++++++++++++++
 llvm/lib/Target/AArch64/SMEInstrFormats.td    |  4 +-
 .../CodeGen/AArch64/sme-intrinsics-loads.ll   | 68 ++++++++++++-------
 .../CodeGen/AArch64/sme-intrinsics-stores.ll  | 54 +++++++++++++--
 7 files changed, 191 insertions(+), 136 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 8444aea8c8ac4b6..ec1c070c5bbd423 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -9815,11 +9815,6 @@ Value *CodeGenFunction::EmitSVEMaskedStore(const CallExpr *E,
   return Store;
 }
 
-Value *CodeGenFunction::EmitTileslice(Value *Offset, Value *Base) {
-  llvm::Value *CastOffset = Builder.CreateIntCast(Offset, Int64Ty, false);
-  return Builder.CreateAdd(Base, CastOffset, "tileslice");
-}
-
 Value *CodeGenFunction::EmitSMELd1St1(const SVETypeFlags &TypeFlags,
                                       SmallVectorImpl<Value *> &Ops,
                                       unsigned IntID) {
@@ -9875,34 +9870,10 @@ Value *CodeGenFunction::EmitSMEZero(const SVETypeFlags &TypeFlags,
 Value *CodeGenFunction::EmitSMELdrStr(const SVETypeFlags &TypeFlags,
                                       SmallVectorImpl<Value *> &Ops,
                                       unsigned IntID) {
-  if (Ops.size() == 2) {
-    // Intrinsics without a vecnum also use this function, so just provide 0
-    Ops.push_back(Ops[1]);
-    Ops[1] = Builder.getInt32(0);
-  } else {
-    int Imm = -1;
-    if (ConstantInt* C = dyn_cast<ConstantInt>(Ops[2]))
-      if (C->getZExtValue() <= 15)
-          Imm = C->getZExtValue();
-
-    if (Imm != -1) {
-      Ops[2] = Ops[1];
-      Ops[1] = Builder.getInt32(Imm);
-    } else {
-      Function *Cntsb = CGM.getIntrinsic(Intrinsic::aarch64_sme_cntsb);
-      llvm::Value *CntsbCall = Builder.CreateCall(Cntsb, {}, "svlb");
-
-      llvm::Value *VecNum = Ops[2];
-      llvm::Value *MulVL = Builder.CreateMul(
-          CntsbCall,
-          VecNum,
-          "mulvl");
-
-      Ops[2] = Builder.CreateGEP(Int8Ty, Ops[1], MulVL);
-      Ops[1] = Builder.getInt32(0);
-      Ops[0] = Builder.CreateIntCast(EmitTileslice(Ops[0], VecNum), Int32Ty, false);
-    }
-   }
+  if (Ops.size() == 2)
+    Ops.push_back(Builder.getInt32(0));
+  else
+    Ops[2] = Builder.CreateIntCast(Ops[2], Int32Ty, true);
   Function *F = CGM.getIntrinsic(IntID, {});
   return Builder.CreateCall(F, Ops);
 }
diff --git a/clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_ldr.c b/clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_ldr.c
index 8e07cf1d11c19b2..9af0778e89c5ec0 100644
--- a/clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_ldr.c
+++ b/clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_ldr.c
@@ -9,7 +9,7 @@
 // CHECK-C-LABEL: @test_svldr_vnum_za(
 // CHECK-CXX-LABEL: @_Z18test_svldr_vnum_zajPKv(
 // CHECK-NEXT:  entry:
-// CHECK-NEXT:    tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE]], i32 0, ptr [[PTR]])
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE:%.*]], ptr [[PTR:%.*]], i32 0)
 // CHECK-NEXT:    ret void
 //
 void test_svldr_vnum_za(uint32_t slice_base, const void *ptr) {
@@ -19,60 +19,40 @@ void test_svldr_vnum_za(uint32_t slice_base, const void *ptr) {
 // CHECK-C-LABEL: @test_svldr_vnum_za_1(
 // CHECK-CXX-LABEL: @_Z20test_svldr_vnum_za_1jPKv(
 // CHECK-NEXT:  entry:
-// CHECK-NEXT:    tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE]], i32 15, ptr [[PTR]])
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE:%.*]], ptr [[PTR:%.*]], i32 15)
 // CHECK-NEXT:    ret void
 //
 void test_svldr_vnum_za_1(uint32_t slice_base, const void *ptr) {
   svldr_vnum_za(slice_base, ptr, 15);
 }
 
-// CHECK-C-LABEL: @test_svldr_vnum_za_var(
-// CHECK-CXX-LABEL: @_Z22test_svldr_vnum_za_varjPKvm(
-// CHECK-NEXT:  entry:
-// CHECK-NEXT:    [[SVLB:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-NEXT:    [[MULVL:%.*]] = mul i64 [[SVLB]], [[VNUM]]
-// CHECK-NEXT:    [[TMP0:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[MULVL]]
-// CHECK-NEXT:    [[TMP1:%.*]] = trunc i64 [[VNUM]] to i32
-// CHECK-NEXT:    [[TMP2:%.*]] = add i32 [[TMP1]], [[SLICE_BASE]]
-// CHECK-NEXT:    tail call void @llvm.aarch64.sme.ldr(i32 [[TMP2]], i32 0, ptr [[TMP0]])
-// CHECK-NEXT:    ret void
-//
-void test_svldr_vnum_za_var(uint32_t slice_base, const void *ptr, uint64_t vnum) {
-  svldr_vnum_za(slice_base, ptr, vnum);
-}
-
 // CHECK-C-LABEL: @test_svldr_za(
 // CHECK-CXX-LABEL: @_Z13test_svldr_zajPKv(
 // CHECK-NEXT:  entry:
-// CHECK-NEXT:    tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE]], i32 0, ptr [[PTR]])
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE:%.*]], ptr [[PTR:%.*]], i32 0)
 // CHECK-NEXT:    ret void
 //
 void test_svldr_za(uint32_t slice_base, const void *ptr) {
   svldr_za(slice_base, ptr);
 }
 
-// CHECK-C-LABEL: define dso_local void @test_svldr_vnum_za_var(
-// CHECK-C-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]], i64 noundef [[VNUM:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-C-NEXT:  entry:
-// CHECK-C-NEXT:    [[SVLB:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-C-NEXT:    [[MULVL:%.*]] = mul i64 [[SVLB]], [[VNUM]]
-// CHECK-C-NEXT:    [[TMP0:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[MULVL]]
-// CHECK-C-NEXT:    [[TMP1:%.*]] = trunc i64 [[VNUM]] to i32
-// CHECK-C-NEXT:    [[TILESLICE:%.*]] = add i32 [[TMP1]], [[SLICE_BASE]]
-// CHECK-C-NEXT:    tail call void @llvm.aarch64.sme.ldr(i32 [[TILESLICE]], ptr [[TMP0]])
-// CHECK-C-NEXT:    ret void
-//
-// CHECK-CXX-LABEL: define dso_local void @_Z22test_svldr_vnum_za_varjPKvl(
-// CHECK-CXX-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]], i64 noundef [[VNUM:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-CXX-NEXT:  entry:
-// CHECK-CXX-NEXT:    [[SVLB:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-CXX-NEXT:    [[MULVL:%.*]] = mul i64 [[SVLB]], [[VNUM]]
-// CHECK-CXX-NEXT:    [[TMP0:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[MULVL]]
-// CHECK-CXX-NEXT:    [[TMP1:%.*]] = trunc i64 [[VNUM]] to i32
-// CHECK-CXX-NEXT:    [[TILESLICE:%.*]] = add i32 [[TMP1]], [[SLICE_BASE]]
-// CHECK-CXX-NEXT:    tail call void @llvm.aarch64.sme.ldr(i32 [[TILESLICE]], ptr [[TMP0]])
-// CHECK-CXX-NEXT:    ret void
+// CHECK-C-LABEL: @test_svldr_vnum_za_var(
+// CHECK-CXX-LABEL: @_Z22test_svldr_vnum_za_varjPKvl(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[TMP0:%.*]] = trunc i64 [[VNUM:%.*]] to i32
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE:%.*]], ptr [[PTR:%.*]], i32 [[TMP0:%.*]])
+// CHECK-NEXT:    ret void
 //
 void test_svldr_vnum_za_var(uint32_t slice_base, const void *ptr, int64_t vnum) {
   svldr_vnum_za(slice_base, ptr, vnum);
 }
+
+// CHECK-C-LABEL: @test_svldr_vnum_za_2(
+// CHECK-CXX-LABEL: @_Z20test_svldr_vnum_za_2jPKv(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.ldr(i32 [[SLICE_BASE:%.*]], ptr [[PTR:%.*]], i32 16)
+// CHECK-NEXT:    ret void
+//
+void test_svldr_vnum_za_2(uint32_t slice_base, const void *ptr) {
+  svldr_vnum_za(slice_base, ptr, 16);
+}
diff --git a/clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_str.c b/clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_str.c
index 532f570b6aaa444..baadfc18563a005 100644
--- a/clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_str.c
+++ b/clang/test/CodeGen/aarch64-sme-intrinsics/acle_sme_str.c
@@ -9,56 +9,50 @@
 // CHECK-C-LABEL: @test_svstr_vnum_za(
 // CHECK-CXX-LABEL: @_Z18test_svstr_vnum_zajPv(
 // CHECK-NEXT:  entry:
-// CHECK-NEXT:    tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE]], i32 0, ptr [[PTR]])
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE:%.*]], ptr [[PTR:%.*]], i32 0)
 // CHECK-NEXT:    ret void
 //
 void test_svstr_vnum_za(uint32_t slice_base, void *ptr) {
   svstr_vnum_za(slice_base, ptr, 0);
 }
 
-// CHECK-C-LABEL: define dso_local void @test_svstr_vnum_za_1(
-// CHECK-CXX-LABEL: define dso_local void @_Z20test_svstr_vnum_za_1jPv(
+// CHECK-C-LABEL: @test_svstr_vnum_za_1(
+// CHECK-CXX-LABEL: @_Z20test_svstr_vnum_za_1jPv(
 // CHECK-NEXT:  entry:
-// CHECK-NEXT:    tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE]], i32 15, ptr [[PTR]])
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE:%.*]], ptr [[PTR:%.*]], i32 15)
 // CHECK-NEXT:    ret void
 //
 void test_svstr_vnum_za_1(uint32_t slice_base, void *ptr) {
   svstr_vnum_za(slice_base, ptr, 15);
 }
 
-// CHECK-C-LABEL: define dso_local void @test_svstr_za(
-// CHECK-CXX-LABEL: define dso_local void @_Z13test_svstr_zajPv(
-// CHECK-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-C-LABEL: @test_svstr_za(
+// CHECK-CXX-LABEL: @_Z13test_svstr_zajPv(
 // CHECK-NEXT:  entry:
-// CHECK-NEXT:    tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE]], i32 0, ptr [[PTR]])
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE:%.*]], ptr [[PTR:%.*]], i32 0)
 // CHECK-NEXT:    ret void
 //
 void test_svstr_za(uint32_t slice_base, void *ptr) {
   svstr_za(slice_base, ptr);
 }
 
-// CHECK-C-LABEL: define dso_local void @test_svstr_vnum_za_var(
-// CHECK-C-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]], i64 noundef [[VNUM:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-C-NEXT:  entry:
-// CHECK-C-NEXT:    [[SVLB:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-C-NEXT:    [[MULVL:%.*]] = mul i64 [[SVLB]], [[VNUM]]
-// CHECK-C-NEXT:    [[TMP0:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[MULVL]]
-// CHECK-C-NEXT:    [[TMP1:%.*]] = trunc i64 [[VNUM]] to i32
-// CHECK-C-NEXT:    [[TILESLICE:%.*]] = add i32 [[TMP1]], [[SLICE_BASE]]
-// CHECK-C-NEXT:    tail call void @llvm.aarch64.sme.str(i32 [[TILESLICE]], ptr [[TMP0]])
-// CHECK-C-NEXT:    ret void
-//
-// CHECK-CXX-LABEL: define dso_local void @_Z22test_svstr_vnum_za_varjPvl(
-// CHECK-CXX-SAME: i32 noundef [[SLICE_BASE:%.*]], ptr noundef [[PTR:%.*]], i64 noundef [[VNUM:%.*]]) local_unnamed_addr #[[ATTR0]] {
-// CHECK-CXX-NEXT:  entry:
-// CHECK-CXX-NEXT:    [[SVLB:%.*]] = tail call i64 @llvm.aarch64.sme.cntsb()
-// CHECK-CXX-NEXT:    [[MULVL:%.*]] = mul i64 [[SVLB]], [[VNUM]]
-// CHECK-CXX-NEXT:    [[TMP0:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[MULVL]]
-// CHECK-CXX-NEXT:    [[TMP1:%.*]] = trunc i64 [[VNUM]] to i32
-// CHECK-CXX-NEXT:    [[TILESLICE:%.*]] = add i32 [[TMP1]], [[SLICE_BASE]]
-// CHECK-CXX-NEXT:    tail call void @llvm.aarch64.sme.str(i32 [[TILESLICE]], ptr [[TMP0]])
-// CHECK-CXX-NEXT:    ret void
+// CHECK-C-LABEL: @test_svstr_vnum_za_var(
+// CHECK-CXX-LABEL: @_Z22test_svstr_vnum_za_varjPvl(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[TMP0:%.*]] = trunc i64 [[VNUM:%.*]] to i32
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE:%.*]], ptr [[PTR:%.*]], i32 [[TMP0:%.*]])
+// CHECK-NEXT:    ret void
 //
 void test_svstr_vnum_za_var(uint32_t slice_base, void *ptr, int64_t vnum) {
   svstr_vnum_za(slice_base, ptr, vnum);
 }
+
+// CHECK-C-LABEL: @test_svstr_vnum_za_2(
+// CHECK-CXX-LABEL: @_Z20test_svstr_vnum_za_2jPv(
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.str(i32 [[SLICE_BASE:%.*]], ptr [[PTR:%.*]], i32 16)
+// CHECK-NEXT:    ret void
+//
+void test_svstr_vnum_za_2(uint32_t slice_base, void *ptr) {
+  svstr_vnum_za(slice_base, ptr, 16);
+}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index abfe14e52509d58..6efa62a38def8d5 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -383,6 +383,7 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
   void SelectPExtPair(SDNode *N, unsigned Opc);
   void SelectWhilePair(SDNode *N, unsigned Opc);
   void SelectCVTIntrinsic(SDNode *N, unsigned NumVecs, unsigned Opcode);
+  void SelectSMELdrStrZA(SDNode *N, bool IsLoad);
   void SelectClamp(SDNode *N, unsigned NumVecs, unsigned Opcode);
   void SelectUnaryMultiIntrinsic(SDNode *N, unsigned NumOutVecs,
                                  bool IsTupleInput, unsigned Opc);
@@ -1749,6 +1750,54 @@ void AArch64DAGToDAGISel::SelectCVTIntrinsic(SDNode *N, unsigned NumVecs,
   CurDAG->RemoveDeadNode(N);
 }
 
+void AArch64DAGToDAGISel::SelectSMELdrStrZA(SDNode *N, bool IsLoad) {
+  // Lower an SME LDR/STR ZA intrinsic to LDR_ZA_PSEUDO or STR_ZA.
+  // If the vector select parameter is an immediate in the range 0-15 then we
+  // can emit it directly into the instruction as it's a legal operand.
+  // Otherwise we must emit 0 as the vector select operand and modify the base
+  // register instead.
+  SDLoc DL(N);
+
+  SDValue VecNum = N->getOperand(4), Base = N->getOperand(3),
+          TileSlice = N->getOperand(2);
+  int Imm = -1;
+  if (auto ImmNode = dyn_cast<ConstantSDNode>(VecNum))
+    Imm = ImmNode->getZExtValue();
+
+  if (Imm >= 0 && Imm <= 15) {
+    // 0-15 is a legal immediate so just pass it directly as a TargetConstant
+    VecNum = CurDAG->getTargetConstant(Imm, DL, MVT::i32);
+  } else {
+    // Get the vector length that will be multiplied by vnum
+    auto SVL = SDValue(
+        CurDAG->getMachineNode(AArch64::RDSVLI_XI, DL, MVT::i64,
+                               CurDAG->getTargetConstant(1, DL, MVT::i32)),
+        0);
+
+    // Multiply SVL and vnum then add it to the base register
+    if (VecNum.getValueType() == MVT::i32)
+      VecNum = Widen(CurDAG, VecNum);
+    SDValue AddOps[] = {SVL, VecNum, Base};
+    auto Add = SDValue(
+        CurDAG->getMachineNode(AArch64::MADDXrrr, DL, MVT::i64, AddOps), 0);
+
+    // The base register has been modified to take vnum into account so just
+    // pass 0
+    VecNum = CurDAG->getTargetConstant(0, DL, MVT::i32);
+    Base = Add;
+  }
+
+  SmallVector<SDValue, 6> Ops = {TileSlice, VecNum, Base};
+  if (!IsLoad) {
+    Ops.insert(Ops.begin(), CurDAG->getRegister(AArch64::ZA, MVT::Other));
+    Ops.push_back(VecNum);
+  }
+  auto LdrStr =
+      CurDAG->getMachineNode(IsLoad ? AArch64::LDR_ZA_PSEUDO : AArch64::STR_ZA,
+                             DL, N->getValueType(0), Ops);
+  ReplaceNode(N, LdrStr);
+}
+
 void AArch64DAGToDAGISel::SelectDestructiveMultiIntrinsic(SDNode *N,
                                                           unsigned NumVecs,
                                                           bool IsZmMulti,
@@ -5671,6 +5720,11 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
     switch (IntNo) {
     default:
       break;
+    case Intrinsic::aarch64_sme_str:
+    case Intrinsic::aarch64_sme_ldr: {
+      SelectSMELdrStrZA(Node, IntNo == Intrinsic::aarch64_sme_ldr);
+      return;
+    }
     case Intrinsic::aarch64_neon_st1x2: {
       if (VT == MVT::v8i8) {
         SelectStore(Node, 2, AArch64::ST1Twov8b);
diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td
index 143931a001a4ba1..5e2b37ce8d37db8 100644
--- a/llvm/lib/Target/AArch64/SMEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -795,7 +795,7 @@ multiclass sme_spill<string opcodestr> {
                   (!cast<Instruction>(NAME) MatrixOp:$ZAt,
                    MatrixIndexGPR32Op12_15:$Rv, sme_elm_idx0_15:$imm4, GPR64sp:$Rn, 0), 1>;
   // base
-  def : Pat<(int_aarch64_sme_str MatrixIndexGPR32Op12_15:$idx, sme_elm_idx0_15:$imm, GPR64sp:$base),
+  def : Pat<(int_aarch64_sme_str MatrixIndexGPR32Op12_15:$idx, GPR64sp:$base, sme_elm_idx0_15:$imm),
             (!cast<Instruction>(NAME) ZA, $idx, $imm, $base, 0)>;
 }
 
@@ -814,7 +814,7 @@ multiclass sme_fill<string opcodestr> {
     let mayLoad = 1;
   }
   // base
-  def : Pat<(int_aarch64_sme_ldr MatrixIndexGPR32Op12_15:$idx, sme_elm_idx0_15:$imm, GPR64sp:$base),
+  def : Pat<(int_aarch64_sme_ldr MatrixIndexGPR32Op12_15:$idx, GPR64sp:$base, sme_elm_idx0_15:$imm),
             (!cast<Instruction>(NAME # _PSEUDO) $idx, $imm, $base)>;
 }
 
diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll
index f5d25a3229a7f82..340b54cc0d2731f 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll
@@ -252,28 +252,10 @@ define void @ldr(ptr %ptr) {
 ; CHECK-NEXT:    mov w12, wzr
 ; CHECK-NEXT:    ldr za[w12, 0], [x0]
 ; CHECK-NEXT:    ret
-  call void @llvm.aarch64.sme.ldr(i32 0, i32 0, ptr %ptr)
+  call void @llvm.aarch64.sme.ldr(i32 0, ptr %ptr, i32 0)
   ret void;
 }
 
-define void @ldr_vnum(i32 %tile_slice, ptr %ptr, i64 %vnum) {
-; CHECK-LABEL: ldr_vnum:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    rdsvl x8, #1
-; CHECK-NEXT:    add w12, w2, w0
-; CHECK-NEXT:    madd x8, x8, x2, x1
-; CHECK-NEXT:    ldr za[w12, 0], [x8]
-; CHECK-NEXT:    ret
-entry:
-  %svlb = tail call i64 @llvm.aarch64.sme.cntsb()
-  %mulvl = mul i64 %svlb, %vnum
-  %0 = getelementptr i8, ptr %ptr, i64 %mulvl
-  %1 = trunc i64 %vnum to i32
-  %2 = add i32 %1, %tile_slice
-  tail call void @llvm.aarch64.sme.ldr(i32 %2, i32 0, ptr %0)
-  ret void
-}
-
 define void @ldr_with_off_15(ptr %ptr) {
 ; CHECK-LABEL: ldr_with_off_15:
 ; CHECK:       // %bb.0:
@@ -282,7 +264,7 @@ define void @ldr_with_off_15(ptr %ptr) {
 ; CHECK-NEXT:    ldr za[w12, 0], [x8]
 ; CHECK-NEXT:    ret
   %base = getelementptr i8, ptr %ptr, i64 15
-  call void @llvm.aarch64.sme.ldr(i32 15, i32 0, ptr %base)
+  call void @llvm.aarch64.sme.ldr(i32 15, ptr %base, i32 0)
   ret void;
 }
 
@@ -296,7 +278,7 @@ define void @ldr_with_off_15mulvl(ptr %ptr) {
   %vscale = call i64 @llvm.vscale.i64()
   %mulvl = mul i64 %vscale, 240
   %base = getelementptr i8, ptr %ptr, i64 %mulvl
-  call void @llvm.aarch64.sme.ldr(i32 15, i32 0, ptr %base)
+  call void @llvm.aarch64.sme.ldr(i32 15, ptr %base, i32 0)
   ret void;
 }
 
@@ -310,7 +292,42 @@ define void @ldr_with_off_16mulvl(ptr %ptr) {
   %vscale = call i64 @llvm.vscale.i64()
   %mulvl = mul i64 %vscale, 256
   %base = getelementptr i8, ptr %ptr, i64 %mulvl
-  call void @llvm.aarch64.sme.ldr(i32 16, i32 0, ptr %base)
+  call void @llvm.aarch64.sme.ldr(i32 16, ptr %base, i32 0)
+  ret void;
+}
+
+define void @ldr_with_off_var(ptr %base, i32 %off) {
+; CHECK-LABEL: ldr_with_off_var:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT:    mov w12, #16 // =0x10
+; CHECK-NEXT:    madd x8, x8, x1, x0
+; CHECK-NEXT:    ldr za[w12, 0], [x8]
+; CHECK-NEXT:    ret
+  call void @llvm.aarch64.sme.ldr(i32 16, ptr %base, i32 %off)
+  ret void;
+}
+
+define void @ldr_with_off_15imm(ptr %base) {
+; CHECK-LABEL: ldr_with_off_15imm:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w12, #16 // =0x10
+; CHECK-NEXT:    ldr za[w12, 15], [x0, #15, mul vl]
+; CHECK-NEXT:    ret
+  call void @llvm.aarch64.sme.ldr(i32 16, ptr %base, i32 15)
+  ret void;
+}
+
+define void @ldr_with_off_16imm(ptr %base) {
+; CHECK-LABEL: ldr_with_off_16imm:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov w12, #16 // =0x10
+; CHECK-NEXT:    madd x8, x8, x12, x0
+; CHECK-NEXT:    ldr za[w12, 0], [x8]
+; CHECK-NEXT:    ret
+  call void @llvm.aarch64.sme.ldr(i32 16, ptr %base, i32 16)
   ret void;
 }
 
@@ -320,13 +337,13 @@ define void @test_ld1_sink_tile0_offset_operand(<vscale x 4 x i1> %pg, ptr %src,
 ; CHECK-LABEL: test_ld1_sink_tile0_offset_operand:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    mov w12, w1
-; CHECK-NEXT:  .LBB15_1: // %for.body
+; CHECK-NEXT:  .LBB17_1: // %for.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    ld1w {za0h.s[w12, 0]}, p0/z, [x0]
 ; CHECK-NEXT:    subs w2, w2, #1
 ; CHECK-NEXT:    ld1w {za0h.s[w12, 1]}, p0/z, [x0]
 ; CHECK-NEXT:    ld1w {za0h.s[w12, 2]}, p0/z, [x0]
-; CHECK-NEXT:    b.ne .LBB15_1
+; CHECK-NEXT:    b.ne .LBB17_1
 ; CHECK-NEXT:  // %bb.2: // %exit
 ; CHECK-NEXT:    ret
 entry:
@@ -359,6 +376,5 @@ declare void @llvm.aarch64.sme.ld1w.vert(<vscale x 4 x i1>, ptr, i32, i32)
 declare void @llvm.aarch64.sme.ld1d.vert(<vscale x 2 x i1>, ptr, i32, i32)
 declare void @llvm.aarch64.sme.ld1q.vert(<vscale x 1 x i1>, ptr, i32, i32)
 
-declare void @llvm.aarch64.sme.ldr(i32, i32, ptr)
+declare void @llvm.aarch64.sme.ldr(i32, ptr, i32)
 declare i64 @llvm.vscale.i64()
-declare i64 @llvm.aarch64.sme.cntsb()
diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll
index 2bb9c3d05b9da5c..b55c2bc78b0fcf0 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll
@@ -252,7 +252,7 @@ define void @str(ptr %ptr) {
 ; CHECK-NEXT:    mov w12, wzr
 ; CHECK-NEXT:    str za[w12, 0], [x0]
 ; CHECK-NEXT:    ret
-  call void @llvm.aarch64.sme.str(i32 0, ptr %ptr)
+  call void @llvm.aarch64.sme.str(i32 0, ptr %ptr, i32 0)
   ret void;
 }
 
@@ -264,7 +264,7 @@ define void @str_with_off_15(ptr %ptr) {
 ; CHECK-NEXT:    str za[w12, 0], [x8]
 ; CHECK-NEXT:    ret
   %base = getelementptr i8, ptr %ptr, i64 15
-  call void @llvm.aarch64.sme.str(i32 15, ptr %base)
+  call void @llvm.aarch64.sme.str(i32 15, ptr %base, i32 0)
   ret void;
 }
 
@@ -278,7 +278,7 @@ define void @str_with_off_15mulvl(ptr %ptr) {
   %vscale = call i64 @llvm.vscale.i64()
   %mulvl = mul i64 %vscale, 240
   %base = getelementptr i8, ptr %ptr, i64 %mulvl
-  call void @llvm.aarch64.sme.str(i32 15, ptr %base)
+  call void @llvm.aarch64.sme.str(i32 15, ptr %base, i32 0)
   ret void;
 }
 
@@ -292,7 +292,47 @@ define void @str_with_off_16mulvl(ptr %ptr) {
   %vscale = call i64 @llvm.vscale.i64()
   %mulvl = mul i64 %vscale, 256
   %base = getelementptr i8, ptr %ptr, i64 %mulvl
-  call void @llvm.aarch64.sme.str(i32 16, ptr %base)
+  call void @llvm.aarch64.sme.str(i32 16, ptr %base, i32 0)
+  ret void;
+}
+
+define void @str_with_off_var(ptr %base, i32 %off) {
+; CHECK-LABEL: str_with_off_var:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT:    mov w12, #16 // =0x10
+; CHECK-NEXT:    madd x8, x8, x1, x0
+; CHECK-NEXT:    str za[w12, 0], [x8]
+; CHECK-NEXT:    ret
+  call void @llvm.aarch64.sme.str(i32 16, ptr %base, i32 %off)
+  ret void;
+}
+
+define void @str_with_off_15imm(ptr %ptr) {
+; CHECK-LABEL: str_with_off_15imm:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w12, #15 // =0xf
+; CHECK-NEXT:    add x8, x0, #15
+; CHECK-NEXT:    str za[w12, 15], [x8, #15, mul vl]
+; CHECK-NEXT:    ret
+  %base = getelementptr i8, ptr %ptr, i64 15
+  call void @llvm.aarch64.sme.str(i32 15, ptr %base, i32 15)
+  ret void;
+}
+
+define void @str_with_off_16imm(ptr %ptr) {
+; CHECK-LABEL: str_with_off_16imm:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov w9, #16 // =0x10
+; CHECK-NEXT:    add x10, x0, #15
+; CHECK-NEXT:    madd x8, x8, x9, x10
+; CHECK-NEXT:    mov w12, #15 // =0xf
+; CHECK-NEXT:    str za[w12, 0], [x8]
+; CHECK-NEXT:    ret
+  %base = getelementptr i8, ptr %ptr, i64 15
+  call void @llvm.aarch64.sme.str(i32 15, ptr %base, i32 16)
   ret void;
 }
 
@@ -302,13 +342,13 @@ define void @test_sink_tile0_offset_operand(<vscale x 4 x i1> %pg, ptr %src, i32
 ; CHECK-LABEL: test_sink_tile0_offset_operand:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    mov w12, w1
-; CHECK-NEXT:  .LBB14_1: // %for.body
+; CHECK-NEXT:  .LBB17_1: // %for.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    st1w {za0h.s[w12, 0]}, p0, [x0]
 ; CHECK-NEXT:    subs w2, w2, #1
 ; CHECK-NEXT:    st1w {za0h.s[w12, 1]}, p0, [x0]
 ; CHECK-NEXT:    st1w {za0h.s[w12, 2]}, p0, [x0]
-; CHECK-NEXT:    b.ne .LBB14_1
+; CHECK-NEXT:    b.ne .LBB17_1
 ; CHECK-NEXT:  // %bb.2: // %exit
 ; CHECK-NEXT:    ret
 entry:
@@ -340,5 +380,5 @@ declare void @llvm.aarch64.sme.st1w.vert(<vscale x 4 x i1>, ptr, i32, i32)
 declare void @llvm.aarch64.sme.st1d.vert(<vscale x 2 x i1>, ptr, i32, i32)
 declare void @llvm.aarch64.sme.st1q.vert(<vscale x 1 x i1>, ptr, i32, i32)
 
-declare void @llvm.aarch64.sme.str(i32, ptr)
+declare void @llvm.aarch64.sme.str(i32, ptr, i32)
 declare i64 @llvm.vscale.i64()

>From 5e03b2e06c02010913d6e954cd7840b214a16e7a Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Fri, 3 Nov 2023 09:47:50 +0000
Subject: [PATCH 04/12] fixup! lower in ISelLowering instead

---
 .../Target/AArch64/AArch64ISelDAGToDAG.cpp    | 54 --------------
 .../Target/AArch64/AArch64ISelLowering.cpp    | 72 +++++++++++++++++++
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  4 ++
 llvm/lib/Target/AArch64/SMEInstrFormats.td    | 21 +++---
 .../CodeGen/AArch64/sme-intrinsics-loads.ll   | 61 +++++++++++++---
 .../CodeGen/AArch64/sme-intrinsics-stores.ll  | 63 +++++++++++++---
 6 files changed, 193 insertions(+), 82 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 6efa62a38def8d5..abfe14e52509d58 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -383,7 +383,6 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
   void SelectPExtPair(SDNode *N, unsigned Opc);
   void SelectWhilePair(SDNode *N, unsigned Opc);
   void SelectCVTIntrinsic(SDNode *N, unsigned NumVecs, unsigned Opcode);
-  void SelectSMELdrStrZA(SDNode *N, bool IsLoad);
   void SelectClamp(SDNode *N, unsigned NumVecs, unsigned Opcode);
   void SelectUnaryMultiIntrinsic(SDNode *N, unsigned NumOutVecs,
                                  bool IsTupleInput, unsigned Opc);
@@ -1750,54 +1749,6 @@ void AArch64DAGToDAGISel::SelectCVTIntrinsic(SDNode *N, unsigned NumVecs,
   CurDAG->RemoveDeadNode(N);
 }
 
-void AArch64DAGToDAGISel::SelectSMELdrStrZA(SDNode *N, bool IsLoad) {
-  // Lower an SME LDR/STR ZA intrinsic to LDR_ZA_PSEUDO or STR_ZA.
-  // If the vector select parameter is an immediate in the range 0-15 then we
-  // can emit it directly into the instruction as it's a legal operand.
-  // Otherwise we must emit 0 as the vector select operand and modify the base
-  // register instead.
-  SDLoc DL(N);
-
-  SDValue VecNum = N->getOperand(4), Base = N->getOperand(3),
-          TileSlice = N->getOperand(2);
-  int Imm = -1;
-  if (auto ImmNode = dyn_cast<ConstantSDNode>(VecNum))
-    Imm = ImmNode->getZExtValue();
-
-  if (Imm >= 0 && Imm <= 15) {
-    // 0-15 is a legal immediate so just pass it directly as a TargetConstant
-    VecNum = CurDAG->getTargetConstant(Imm, DL, MVT::i32);
-  } else {
-    // Get the vector length that will be multiplied by vnum
-    auto SVL = SDValue(
-        CurDAG->getMachineNode(AArch64::RDSVLI_XI, DL, MVT::i64,
-                               CurDAG->getTargetConstant(1, DL, MVT::i32)),
-        0);
-
-    // Multiply SVL and vnum then add it to the base register
-    if (VecNum.getValueType() == MVT::i32)
-      VecNum = Widen(CurDAG, VecNum);
-    SDValue AddOps[] = {SVL, VecNum, Base};
-    auto Add = SDValue(
-        CurDAG->getMachineNode(AArch64::MADDXrrr, DL, MVT::i64, AddOps), 0);
-
-    // The base register has been modified to take vnum into account so just
-    // pass 0
-    VecNum = CurDAG->getTargetConstant(0, DL, MVT::i32);
-    Base = Add;
-  }
-
-  SmallVector<SDValue, 6> Ops = {TileSlice, VecNum, Base};
-  if (!IsLoad) {
-    Ops.insert(Ops.begin(), CurDAG->getRegister(AArch64::ZA, MVT::Other));
-    Ops.push_back(VecNum);
-  }
-  auto LdrStr =
-      CurDAG->getMachineNode(IsLoad ? AArch64::LDR_ZA_PSEUDO : AArch64::STR_ZA,
-                             DL, N->getValueType(0), Ops);
-  ReplaceNode(N, LdrStr);
-}
-
 void AArch64DAGToDAGISel::SelectDestructiveMultiIntrinsic(SDNode *N,
                                                           unsigned NumVecs,
                                                           bool IsZmMulti,
@@ -5720,11 +5671,6 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
     switch (IntNo) {
     default:
       break;
-    case Intrinsic::aarch64_sme_str:
-    case Intrinsic::aarch64_sme_ldr: {
-      SelectSMELdrStrZA(Node, IntNo == Intrinsic::aarch64_sme_ldr);
-      return;
-    }
     case Intrinsic::aarch64_neon_st1x2: {
       if (VT == MVT::v8i8) {
         SelectStore(Node, 2, AArch64::ST1Twov8b);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9ff6d6f0f565edb..642707054ec0863 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2406,6 +2406,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::FCMP)
     MAKE_CASE(AArch64ISD::STRICT_FCMP)
     MAKE_CASE(AArch64ISD::STRICT_FCMPE)
+    MAKE_CASE(AArch64ISD::SME_ZA_LDR)
+    MAKE_CASE(AArch64ISD::SME_ZA_STR)
     MAKE_CASE(AArch64ISD::DUP)
     MAKE_CASE(AArch64ISD::DUPLANE8)
     MAKE_CASE(AArch64ISD::DUPLANE16)
@@ -4850,6 +4852,72 @@ SDValue AArch64TargetLowering::getPStateSM(SelectionDAG &DAG, SDValue Chain,
                      Mask);
 }
 
+SDValue LowerSMELdrStr(SDValue N, SelectionDAG &DAG, bool IsLoad) {
+  // Lower an SME LDR/STR ZA intrinsic to an SME_ZA_LDR or SME_ZA_STR node.
+  // If the vector number is an immediate between 0 and 15 inclusive then we can
+  // put that directly into the immediate field of the instruction. If it's
+  // outside of that range then we modify the base and slice by the greatest
+  // multiple of 15 smaller than that number and put the remainder in the
+  // instruction field. If it's not an immediate then we modify the base and
+  // slice registers by that number and put 0 in the instruction.
+  SDLoc DL(N);
+
+  SDValue TileSlice = N->getOperand(2);
+  SDValue Base = N->getOperand(3);
+  SDValue VecNum = N->getOperand(4);
+  SDValue Remainder = DAG.getTargetConstant(0, DL, MVT::i32);
+
+  // true if the base and slice registers need to be modified
+  bool NeedsAdd = true;
+  if (auto ImmNode = dyn_cast<ConstantSDNode>(VecNum)) {
+    int Imm = ImmNode->getSExtValue();
+    if (Imm >= 0 && Imm <= 15) {
+      Remainder = DAG.getTargetConstant(Imm, DL, MVT::i32);
+      NeedsAdd = false;
+    } else {
+      Remainder = DAG.getTargetConstant(Imm % 15, DL, MVT::i32);
+      NeedsAdd = true;
+      VecNum = DAG.getConstant(Imm - (Imm % 15), DL, MVT::i32);
+    }
+  } else if (VecNum.getOpcode() == ISD::ADD) {
+    // If the vnum is an add, we can fold that add into the instruction if the
+    // operand is an immediate in range
+    if (auto ImmNode = dyn_cast<ConstantSDNode>(VecNum.getOperand(1))) {
+      int Imm = ImmNode->getSExtValue();
+      if (Imm >= 0 && Imm <= 15) {
+        VecNum = VecNum.getOperand(0);
+        Remainder = DAG.getTargetConstant(Imm, DL, MVT::i32);
+        NeedsAdd = true;
+      }
+    }
+  }
+  if (NeedsAdd) {
+    // Get the vector length that will be multiplied by vnum
+    auto SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
+                           DAG.getConstant(1, DL, MVT::i32));
+
+    // Multiply SVL by vnum and add the product to the base.
+    // The tile slice only needs vnum itself added to it.
+    SDValue BaseMulOps[] = {
+        SVL, VecNum.getValueType() == MVT::i32
+                 ? DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, VecNum)
+                 : VecNum};
+    SDValue Mul = DAG.getNode(ISD::MUL, DL, MVT::i64, BaseMulOps);
+
+    SDValue BaseAddOps[] = {Base, Mul};
+    Base = DAG.getNode(ISD::ADD, DL, MVT::i64, BaseAddOps);
+
+    SDValue SliceAddOps[] = {TileSlice, VecNum};
+    TileSlice = DAG.getNode(ISD::ADD, DL, MVT::i32, SliceAddOps);
+  }
+
+  SmallVector<SDValue, 4> Ops = {N.getOperand(0), TileSlice, Base, Remainder};
+  auto LdrStr =
+      DAG.getNode(IsLoad ? AArch64ISD::SME_ZA_LDR : AArch64ISD::SME_ZA_STR, DL,
+                  MVT::Other, Ops);
+  return LdrStr;
+}
+
 SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
                                                    SelectionDAG &DAG) const {
   unsigned IntNo = Op.getConstantOperandVal(1);
@@ -4873,6 +4941,10 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
     return DAG.getNode(AArch64ISD::PREFETCH, DL, MVT::Other, Chain,
                        DAG.getTargetConstant(PrfOp, DL, MVT::i32), Addr);
   }
+  case Intrinsic::aarch64_sme_str:
+  case Intrinsic::aarch64_sme_ldr: {
+    return LowerSMELdrStr(Op, DAG, IntNo == Intrinsic::aarch64_sme_ldr);
+  }
   case Intrinsic::aarch64_sme_za_enable:
     return DAG.getNode(
         AArch64ISD::SMSTART, DL, MVT::Other,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index b638084f98dadb8..b690b996d8cddff 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -442,6 +442,10 @@ enum NodeType : unsigned {
   STRICT_FCMP = ISD::FIRST_TARGET_STRICTFP_OPCODE,
   STRICT_FCMPE,
 
+  // SME ZA loads and stores
+  SME_ZA_LDR,
+  SME_ZA_STR,
+
   // NEON Load/Store with post-increment base updates
   LD2post = ISD::FIRST_TARGET_MEMORY_OPCODE,
   LD3post,
diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td
index 5e2b37ce8d37db8..635f5e20cef2e5e 100644
--- a/llvm/lib/Target/AArch64/SMEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -34,6 +34,12 @@ def tileslicerange0s4 : ComplexPattern<i32, 2, "SelectSMETileSlice<0,  4>", []>;
 
 def am_sme_indexed_b4 :ComplexPattern<iPTR, 2, "SelectAddrModeIndexedSVE<0,15>", [], [SDNPWantRoot]>;
 
+def SDTZALoadStore : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisInt<2>]>;
+def AArch64SMELdr : SDNode<"AArch64ISD::SME_ZA_LDR", SDTZALoadStore,
+                             [SDNPHasChain, SDNPSideEffect, SDNPMayLoad]>;
+def AArch64SMEStr : SDNode<"AArch64ISD::SME_ZA_STR", SDTZALoadStore,
+                             [SDNPHasChain, SDNPSideEffect, SDNPMayStore]>;
+
 //===----------------------------------------------------------------------===//
 // SME Pseudo Classes
 //===----------------------------------------------------------------------===//
@@ -780,23 +786,23 @@ class sme_spill_inst<string opcodestr>
     : sme_spill_fill_base<0b1, (outs),
                           (ins MatrixOp:$ZAt, MatrixIndexGPR32Op12_15:$Rv,
                                sme_elm_idx0_15:$imm4, GPR64sp:$Rn,
-                               imm0_15:$offset),
+                               imm32_0_15:$offset),
                           opcodestr>;
 let mayLoad = 1 in
 class sme_fill_inst<string opcodestr>
     : sme_spill_fill_base<0b0, (outs MatrixOp:$ZAt),
                           (ins MatrixIndexGPR32Op12_15:$Rv,
                                sme_elm_idx0_15:$imm4, GPR64sp:$Rn,
-                               imm0_15:$offset),
+                               imm32_0_15:$offset),
                           opcodestr>;
 multiclass sme_spill<string opcodestr> {
   def NAME : sme_spill_inst<opcodestr>;
   def : InstAlias<opcodestr # "\t$ZAt[$Rv, $imm4], [$Rn]",
                   (!cast<Instruction>(NAME) MatrixOp:$ZAt,
                    MatrixIndexGPR32Op12_15:$Rv, sme_elm_idx0_15:$imm4, GPR64sp:$Rn, 0), 1>;
-  // base
-  def : Pat<(int_aarch64_sme_str MatrixIndexGPR32Op12_15:$idx, GPR64sp:$base, sme_elm_idx0_15:$imm),
-            (!cast<Instruction>(NAME) ZA, $idx, $imm, $base, 0)>;
+
+  def : Pat<(AArch64SMEStr (i32 MatrixIndexGPR32Op12_15:$slice), (i64 GPR64sp:$base), (i32 sme_elm_idx0_15:$imm)),
+          (!cast<Instruction>(NAME) ZA, MatrixIndexGPR32Op12_15:$slice, sme_elm_idx0_15:$imm, GPR64sp:$base, imm32_0_15:$imm)>;
 }
 
 multiclass sme_fill<string opcodestr> {
@@ -813,9 +819,8 @@ multiclass sme_fill<string opcodestr> {
     let usesCustomInserter = 1;
     let mayLoad = 1;
   }
-  // base
-  def : Pat<(int_aarch64_sme_ldr MatrixIndexGPR32Op12_15:$idx, GPR64sp:$base, sme_elm_idx0_15:$imm),
-            (!cast<Instruction>(NAME # _PSEUDO) $idx, $imm, $base)>;
+  def : Pat<(AArch64SMELdr MatrixIndexGPR32Op12_15:$slice, GPR64sp:$base, sme_elm_idx0_15:$imm),
+          (!cast<Instruction>(NAME # _PSEUDO) MatrixIndexGPR32Op12_15:$slice, sme_elm_idx0_15:$imm, GPR64sp:$base)>;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll
index 340b54cc0d2731f..bcca2133984a6c8 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll
@@ -299,10 +299,11 @@ define void @ldr_with_off_16mulvl(ptr %ptr) {
 define void @ldr_with_off_var(ptr %base, i32 %off) {
 ; CHECK-LABEL: ldr_with_off_var:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    rdsvl x8, #1
-; CHECK-NEXT:    // kill: def $w1 killed $w1 def $x1
-; CHECK-NEXT:    mov w12, #16 // =0x10
-; CHECK-NEXT:    madd x8, x8, x1, x0
+; CHECK-NEXT:    // kill: def $w2 killed $w2 def $x2
+; CHECK-NEXT:    sxtw x8, w2
+; CHECK-NEXT:    rdsvl x9, #1
+; CHECK-NEXT:    add w12, w0, w2
+; CHECK-NEXT:    madd x8, x9, x8, x1
 ; CHECK-NEXT:    ldr za[w12, 0], [x8]
 ; CHECK-NEXT:    ret
   call void @llvm.aarch64.sme.ldr(i32 16, ptr %base, i32 %off)
@@ -323,27 +324,69 @@ define void @ldr_with_off_16imm(ptr %base) {
 ; CHECK-LABEL: ldr_with_off_16imm:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    rdsvl x8, #1
-; CHECK-NEXT:    mov w12, #16 // =0x10
-; CHECK-NEXT:    madd x8, x8, x12, x0
-; CHECK-NEXT:    ldr za[w12, 0], [x8]
+; CHECK-NEXT:    add w12, w0, #15
+; CHECK-NEXT:    sub x9, x1, x8
+; CHECK-NEXT:    add x8, x9, x8, lsl #4
+; CHECK-NEXT:    ldr za[w12, 1], [x8, #1, mul vl]
 ; CHECK-NEXT:    ret
   call void @llvm.aarch64.sme.ldr(i32 16, ptr %base, i32 16)
   ret void;
 }
 
+define void @ldr_with_off_many_imm(i32 %tile_slice, ptr %ptr) {
+; CHECK-LABEL: ldr_with_off_many_imm:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov w12, w0
+; CHECK-NEXT:    ldr za[w12, 1], [x1, #1, mul vl]
+; CHECK-NEXT:    ldr za[w12, 2], [x1, #2, mul vl]
+; CHECK-NEXT:    ldr za[w12, 3], [x1, #3, mul vl]
+; CHECK-NEXT:    ldr za[w12, 4], [x1, #4, mul vl]
+; CHECK-NEXT:    ret
+entry:
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 1)
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 2)
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 3)
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 4)
+  ret void
+}
+
+define void @ldr_with_off_many_var(i32 %tile_slice, ptr %ptr, i64 %vnum) {
+; CHECK-LABEL: ldr_with_off_many_var:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sxtw x8, w2
+; CHECK-NEXT:    rdsvl x9, #1
+; CHECK-NEXT:    add w12, w0, w2
+; CHECK-NEXT:    madd x8, x9, x8, x1
+; CHECK-NEXT:    ldr za[w12, 0], [x8]
+; CHECK-NEXT:    ldr za[w12, 1], [x8, #1, mul vl]
+; CHECK-NEXT:    ldr za[w12, 2], [x8, #2, mul vl]
+; CHECK-NEXT:    ldr za[w12, 3], [x8, #3, mul vl]
+; CHECK-NEXT:    ret
+entry:
+  %0 = trunc i64 %vnum to i32
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 %0)
+  %1 = add i32 %0, 1
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 %1)
+  %2 = add i32 %0, 2
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 %2)
+  %3 = add i32 %0, 3
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 %3)
+  ret void
+}
+
 ; Ensure that the tile offset is sunk, given that this is likely to be an 'add'
 ; that's decomposed into a base + offset in ISel.
 define void @test_ld1_sink_tile0_offset_operand(<vscale x 4 x i1> %pg, ptr %src, i32 %base, i32 %N) {
 ; CHECK-LABEL: test_ld1_sink_tile0_offset_operand:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    mov w12, w1
-; CHECK-NEXT:  .LBB17_1: // %for.body
+; CHECK-NEXT:  .LBB19_1: // %for.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    ld1w {za0h.s[w12, 0]}, p0/z, [x0]
 ; CHECK-NEXT:    subs w2, w2, #1
 ; CHECK-NEXT:    ld1w {za0h.s[w12, 1]}, p0/z, [x0]
 ; CHECK-NEXT:    ld1w {za0h.s[w12, 2]}, p0/z, [x0]
-; CHECK-NEXT:    b.ne .LBB17_1
+; CHECK-NEXT:    b.ne .LBB19_1
 ; CHECK-NEXT:  // %bb.2: // %exit
 ; CHECK-NEXT:    ret
 entry:
diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll
index b55c2bc78b0fcf0..f0239aacccada21 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll
@@ -299,10 +299,11 @@ define void @str_with_off_16mulvl(ptr %ptr) {
 define void @str_with_off_var(ptr %base, i32 %off) {
 ; CHECK-LABEL: str_with_off_var:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    rdsvl x8, #1
-; CHECK-NEXT:    // kill: def $w1 killed $w1 def $x1
-; CHECK-NEXT:    mov w12, #16 // =0x10
-; CHECK-NEXT:    madd x8, x8, x1, x0
+; CHECK-NEXT:    // kill: def $w2 killed $w2 def $x2
+; CHECK-NEXT:    sxtw x8, w2
+; CHECK-NEXT:    rdsvl x9, #1
+; CHECK-NEXT:    add w12, w0, w2
+; CHECK-NEXT:    madd x8, x9, x8, x1
 ; CHECK-NEXT:    str za[w12, 0], [x8]
 ; CHECK-NEXT:    ret
   call void @llvm.aarch64.sme.str(i32 16, ptr %base, i32 %off)
@@ -325,30 +326,70 @@ define void @str_with_off_16imm(ptr %ptr) {
 ; CHECK-LABEL: str_with_off_16imm:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    rdsvl x8, #1
-; CHECK-NEXT:    mov w9, #16 // =0x10
-; CHECK-NEXT:    add x10, x0, #15
-; CHECK-NEXT:    madd x8, x8, x9, x10
-; CHECK-NEXT:    mov w12, #15 // =0xf
-; CHECK-NEXT:    str za[w12, 0], [x8]
+; CHECK-NEXT:    add w12, w0, #15
+; CHECK-NEXT:    sub x9, x1, x8
+; CHECK-NEXT:    add x8, x9, x8, lsl #4
+; CHECK-NEXT:    str za[w12, 1], [x8, #1, mul vl]
 ; CHECK-NEXT:    ret
   %base = getelementptr i8, ptr %ptr, i64 15
   call void @llvm.aarch64.sme.str(i32 15, ptr %base, i32 16)
   ret void;
 }
 
+define void @str_with_off_many_imm(i32 %tile_slice, ptr %ptr, i64 %vnum) {
+; CHECK-LABEL: str_with_off_many_imm:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov w12, w0
+; CHECK-NEXT:    str za[w12, 1], [x1, #1, mul vl]
+; CHECK-NEXT:    str za[w12, 2], [x1, #2, mul vl]
+; CHECK-NEXT:    str za[w12, 3], [x1, #3, mul vl]
+; CHECK-NEXT:    str za[w12, 4], [x1, #4, mul vl]
+; CHECK-NEXT:    ret
+entry:
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 1)
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 2)
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 3)
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 4)
+  ret void
+}
+
+define void @str_with_off_many_var(i32 %tile_slice, ptr %ptr, i64 %vnum) {
+; CHECK-LABEL: str_with_off_many_var:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sxtw x8, w2
+; CHECK-NEXT:    rdsvl x9, #1
+; CHECK-NEXT:    add w12, w0, w2
+; CHECK-NEXT:    madd x8, x9, x8, x1
+; CHECK-NEXT:    str za[w12, 0], [x8]
+; CHECK-NEXT:    str za[w12, 1], [x8, #1, mul vl]
+; CHECK-NEXT:    str za[w12, 2], [x8, #2, mul vl]
+; CHECK-NEXT:    str za[w12, 3], [x8, #3, mul vl]
+; CHECK-NEXT:    ret
+entry:
+  %0 = trunc i64 %vnum to i32
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 %0)
+  %1 = add i32 %0, 1
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 %1)
+  %2 = add i32 %0, 2
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 %2)
+  %3 = add i32 %0, 3
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 %3)
+  ret void
+}
+
 ; Ensure that the tile offset is sunk, given that this is likely to be an 'add'
 ; that's decomposed into a base + offset in ISel.
 define void @test_sink_tile0_offset_operand(<vscale x 4 x i1> %pg, ptr %src, i32 %base, i32 %N) {
 ; CHECK-LABEL: test_sink_tile0_offset_operand:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    mov w12, w1
-; CHECK-NEXT:  .LBB17_1: // %for.body
+; CHECK-NEXT:  .LBB19_1: // %for.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    st1w {za0h.s[w12, 0]}, p0, [x0]
 ; CHECK-NEXT:    subs w2, w2, #1
 ; CHECK-NEXT:    st1w {za0h.s[w12, 1]}, p0, [x0]
 ; CHECK-NEXT:    st1w {za0h.s[w12, 2]}, p0, [x0]
-; CHECK-NEXT:    b.ne .LBB17_1
+; CHECK-NEXT:    b.ne .LBB19_1
 ; CHECK-NEXT:  // %bb.2: // %exit
 ; CHECK-NEXT:    ret
 entry:

>From a1e51b6ecc6543fe8641db916db5d6fdddd900f7 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 6 Nov 2023 10:34:27 +0000
Subject: [PATCH 05/12] fixup! Update check lines

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp    |  2 +-
 llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll  | 12 ++++++------
 llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll | 13 +++++++------
 3 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 642707054ec0863..929b1e44f9af590 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4867,7 +4867,7 @@ SDValue LowerSMELdrStr(SDValue N, SelectionDAG &DAG, bool IsLoad) {
   SDValue VecNum = N->getOperand(4);
   SDValue Remainder = DAG.getTargetConstant(0, DL, MVT::i32);
 
-  // true if the base and slice registers need to me modified
+  // true if the base and slice registers need to be modified
   bool NeedsAdd = true;
   if (auto ImmNode = dyn_cast<ConstantSDNode>(VecNum)) {
     int Imm = ImmNode->getSExtValue();
diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll
index bcca2133984a6c8..09e7d7b4068ce17 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll
@@ -299,11 +299,11 @@ define void @ldr_with_off_16mulvl(ptr %ptr) {
 define void @ldr_with_off_var(ptr %base, i32 %off) {
 ; CHECK-LABEL: ldr_with_off_var:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    // kill: def $w2 killed $w2 def $x2
-; CHECK-NEXT:    sxtw x8, w2
+; CHECK-NEXT:    // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT:    sxtw x8, w1
 ; CHECK-NEXT:    rdsvl x9, #1
-; CHECK-NEXT:    add w12, w0, w2
-; CHECK-NEXT:    madd x8, x9, x8, x1
+; CHECK-NEXT:    add w12, w1, #16
+; CHECK-NEXT:    madd x8, x9, x8, x0
 ; CHECK-NEXT:    ldr za[w12, 0], [x8]
 ; CHECK-NEXT:    ret
   call void @llvm.aarch64.sme.ldr(i32 16, ptr %base, i32 %off)
@@ -324,8 +324,8 @@ define void @ldr_with_off_16imm(ptr %base) {
 ; CHECK-LABEL: ldr_with_off_16imm:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    rdsvl x8, #1
-; CHECK-NEXT:    add w12, w0, #15
-; CHECK-NEXT:    sub x9, x1, x8
+; CHECK-NEXT:    mov w12, #31 // =0x1f
+; CHECK-NEXT:    sub x9, x0, x8
 ; CHECK-NEXT:    add x8, x9, x8, lsl #4
 ; CHECK-NEXT:    ldr za[w12, 1], [x8, #1, mul vl]
 ; CHECK-NEXT:    ret
diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll
index f0239aacccada21..40327b80a1b96d7 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll
@@ -299,11 +299,11 @@ define void @str_with_off_16mulvl(ptr %ptr) {
 define void @str_with_off_var(ptr %base, i32 %off) {
 ; CHECK-LABEL: str_with_off_var:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    // kill: def $w2 killed $w2 def $x2
-; CHECK-NEXT:    sxtw x8, w2
+; CHECK-NEXT:    // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT:    sxtw x8, w1
 ; CHECK-NEXT:    rdsvl x9, #1
-; CHECK-NEXT:    add w12, w0, w2
-; CHECK-NEXT:    madd x8, x9, x8, x1
+; CHECK-NEXT:    add w12, w1, #16
+; CHECK-NEXT:    madd x8, x9, x8, x0
 ; CHECK-NEXT:    str za[w12, 0], [x8]
 ; CHECK-NEXT:    ret
   call void @llvm.aarch64.sme.str(i32 16, ptr %base, i32 %off)
@@ -326,9 +326,10 @@ define void @str_with_off_16imm(ptr %ptr) {
 ; CHECK-LABEL: str_with_off_16imm:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    rdsvl x8, #1
-; CHECK-NEXT:    add w12, w0, #15
-; CHECK-NEXT:    sub x9, x1, x8
+; CHECK-NEXT:    mov w12, #30 // =0x1e
+; CHECK-NEXT:    sub x9, x0, x8
 ; CHECK-NEXT:    add x8, x9, x8, lsl #4
+; CHECK-NEXT:    add x8, x8, #15
 ; CHECK-NEXT:    str za[w12, 1], [x8, #1, mul vl]
 ; CHECK-NEXT:    ret
   %base = getelementptr i8, ptr %ptr, i64 15

>From 76b0c96be507833713e73383658930424c33173b Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 6 Nov 2023 11:46:34 +0000
Subject: [PATCH 06/12] fixup! Clean up node creation

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 929b1e44f9af590..1e8558352cfd4df 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4897,21 +4897,16 @@ SDValue LowerSMELdrStr(SDValue N, SelectionDAG &DAG, bool IsLoad) {
                            DAG.getConstant(1, DL, MVT::i32));
 
     // Multiply SVL and vnum then add it to the base
+    SDValue Mul =
+        DAG.getNode(ISD::MUL, DL, MVT::i64,
+                    {SVL, DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, VecNum)});
+    Base = DAG.getNode(ISD::ADD, DL, MVT::i64, {Base, Mul});
     // Just add vnum to the tileslice
-    SDValue BaseMulOps[] = {
-        SVL, VecNum.getValueType() == MVT::i32
-                 ? DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, VecNum)
-                 : VecNum};
-    SDValue Mul = DAG.getNode(ISD::MUL, DL, MVT::i64, BaseMulOps);
-
-    SDValue BaseAddOps[] = {Base, Mul};
-    Base = DAG.getNode(ISD::ADD, DL, MVT::i64, BaseAddOps);
-
-    SDValue SliceAddOps[] = {TileSlice, VecNum};
-    TileSlice = DAG.getNode(ISD::ADD, DL, MVT::i32, SliceAddOps);
+    TileSlice = DAG.getNode(ISD::ADD, DL, MVT::i32, {TileSlice, VecNum});
   }
 
-  SmallVector<SDValue, 4> Ops = {N.getOperand(0), TileSlice, Base, Remainder};
+  SmallVector<SDValue, 4> Ops = {/*Chain=*/N.getOperand(0), TileSlice, Base,
+                                 Remainder};
   auto LdrStr =
       DAG.getNode(IsLoad ? AArch64ISD::SME_ZA_LDR : AArch64ISD::SME_ZA_STR, DL,
                   MVT::Other, Ops);

>From 13cff2ed225e46c653ec5e7fe3ceb1e86f6f5a58 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Tue, 7 Nov 2023 10:12:23 +0000
Subject: [PATCH 07/12] fixup! modulo 16 instead of 15

---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  4 +-
 .../CodeGen/AArch64/sme-intrinsics-loads.ll   | 90 ++++++++++++++++--
 .../CodeGen/AArch64/sme-intrinsics-stores.ll  | 92 +++++++++++++++++--
 3 files changed, 171 insertions(+), 15 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 1e8558352cfd4df..515ab3f61c34686 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4875,9 +4875,9 @@ SDValue LowerSMELdrStr(SDValue N, SelectionDAG &DAG, bool IsLoad) {
       Remainder = DAG.getTargetConstant(Imm, DL, MVT::i32);
       NeedsAdd = false;
     } else {
-      Remainder = DAG.getTargetConstant(Imm % 15, DL, MVT::i32);
+      Remainder = DAG.getTargetConstant(Imm % 16, DL, MVT::i32);
       NeedsAdd = true;
-      VecNum = DAG.getConstant(Imm - (Imm % 15), DL, MVT::i32);
+      VecNum = DAG.getConstant(Imm - (Imm % 16), DL, MVT::i32);
     }
   } else if (VecNum.getOpcode() == ISD::ADD) {
     // If the vnum is an add, we can fold that add into the instruction if the
diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll
index 09e7d7b4068ce17..e32d1a170defc41 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll
@@ -324,10 +324,9 @@ define void @ldr_with_off_16imm(ptr %base) {
 ; CHECK-LABEL: ldr_with_off_16imm:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    rdsvl x8, #1
-; CHECK-NEXT:    mov w12, #31 // =0x1f
-; CHECK-NEXT:    sub x9, x0, x8
-; CHECK-NEXT:    add x8, x9, x8, lsl #4
-; CHECK-NEXT:    ldr za[w12, 1], [x8, #1, mul vl]
+; CHECK-NEXT:    mov w12, #32 // =0x20
+; CHECK-NEXT:    add x8, x0, x8, lsl #4
+; CHECK-NEXT:    ldr za[w12, 0], [x8]
 ; CHECK-NEXT:    ret
   call void @llvm.aarch64.sme.ldr(i32 16, ptr %base, i32 16)
   ret void;
@@ -350,6 +349,85 @@ entry:
   ret void
 }
 
+define void @ldr_with_off_many_imm_15_18(i32 %tile_slice, ptr %ptr) {
+; CHECK-LABEL: ldr_with_off_many_imm_15_18:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov w12, w0
+; CHECK-NEXT:    add x8, x1, x8, lsl #4
+; CHECK-NEXT:    ldr za[w12, 15], [x1, #15, mul vl]
+; CHECK-NEXT:    add w12, w0, #16
+; CHECK-NEXT:    ldr za[w12, 0], [x8]
+; CHECK-NEXT:    ldr za[w12, 1], [x8, #1, mul vl]
+; CHECK-NEXT:    ldr za[w12, 2], [x8, #2, mul vl]
+; CHECK-NEXT:    ret
+entry:
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 15)
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 16)
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 17)
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 18)
+  ret void
+}
+
+define void @ldr_with_off_many_imm_16_19(i32 %tile_slice, ptr %ptr) {
+; CHECK-LABEL: ldr_with_off_many_imm_16_19:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    add w12, w0, #16
+; CHECK-NEXT:    add x8, x1, x8, lsl #4
+; CHECK-NEXT:    ldr za[w12, 0], [x8]
+; CHECK-NEXT:    ldr za[w12, 1], [x8, #1, mul vl]
+; CHECK-NEXT:    ldr za[w12, 2], [x8, #2, mul vl]
+; CHECK-NEXT:    ldr za[w12, 3], [x8, #3, mul vl]
+; CHECK-NEXT:    ret
+entry:
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 16)
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 17)
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 18)
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 19)
+  ret void
+}
+
+define void @ldr_with_off_many_imm_31_34(i32 %tile_slice, ptr %ptr) {
+; CHECK-LABEL: ldr_with_off_many_imm_31_34:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    add w12, w0, #16
+; CHECK-NEXT:    add x9, x1, x8, lsl #4
+; CHECK-NEXT:    add x8, x1, x8, lsl #5
+; CHECK-NEXT:    ldr za[w12, 15], [x9, #15, mul vl]
+; CHECK-NEXT:    add w12, w0, #32
+; CHECK-NEXT:    ldr za[w12, 0], [x8]
+; CHECK-NEXT:    ldr za[w12, 1], [x8, #1, mul vl]
+; CHECK-NEXT:    ldr za[w12, 2], [x8, #2, mul vl]
+; CHECK-NEXT:    ret
+entry:
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 31)
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 32)
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 33)
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 34)
+  ret void
+}
+
+define void @ldr_with_off_many_imm_32_35(i32 %tile_slice, ptr %ptr, i64 %vnum) {
+; CHECK-LABEL: ldr_with_off_many_imm_32_35:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    add w12, w0, #32
+; CHECK-NEXT:    add x8, x1, x8, lsl #5
+; CHECK-NEXT:    ldr za[w12, 0], [x8]
+; CHECK-NEXT:    ldr za[w12, 1], [x8, #1, mul vl]
+; CHECK-NEXT:    ldr za[w12, 2], [x8, #2, mul vl]
+; CHECK-NEXT:    ldr za[w12, 3], [x8, #3, mul vl]
+; CHECK-NEXT:    ret
+entry:
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 32)
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 33)
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 34)
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 35)
+  ret void
+}
+
 define void @ldr_with_off_many_var(i32 %tile_slice, ptr %ptr, i64 %vnum) {
 ; CHECK-LABEL: ldr_with_off_many_var:
 ; CHECK:       // %bb.0: // %entry
@@ -380,13 +458,13 @@ define void @test_ld1_sink_tile0_offset_operand(<vscale x 4 x i1> %pg, ptr %src,
 ; CHECK-LABEL: test_ld1_sink_tile0_offset_operand:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    mov w12, w1
-; CHECK-NEXT:  .LBB19_1: // %for.body
+; CHECK-NEXT:  .LBB23_1: // %for.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    ld1w {za0h.s[w12, 0]}, p0/z, [x0]
 ; CHECK-NEXT:    subs w2, w2, #1
 ; CHECK-NEXT:    ld1w {za0h.s[w12, 1]}, p0/z, [x0]
 ; CHECK-NEXT:    ld1w {za0h.s[w12, 2]}, p0/z, [x0]
-; CHECK-NEXT:    b.ne .LBB19_1
+; CHECK-NEXT:    b.ne .LBB23_1
 ; CHECK-NEXT:  // %bb.2: // %exit
 ; CHECK-NEXT:    ret
 entry:
diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll
index 40327b80a1b96d7..4843f9388fa2f77 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll
@@ -326,18 +326,17 @@ define void @str_with_off_16imm(ptr %ptr) {
 ; CHECK-LABEL: str_with_off_16imm:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    rdsvl x8, #1
-; CHECK-NEXT:    mov w12, #30 // =0x1e
-; CHECK-NEXT:    sub x9, x0, x8
-; CHECK-NEXT:    add x8, x9, x8, lsl #4
+; CHECK-NEXT:    mov w12, #31 // =0x1f
+; CHECK-NEXT:    add x8, x0, x8, lsl #4
 ; CHECK-NEXT:    add x8, x8, #15
-; CHECK-NEXT:    str za[w12, 1], [x8, #1, mul vl]
+; CHECK-NEXT:    str za[w12, 0], [x8]
 ; CHECK-NEXT:    ret
   %base = getelementptr i8, ptr %ptr, i64 15
   call void @llvm.aarch64.sme.str(i32 15, ptr %base, i32 16)
   ret void;
 }
 
-define void @str_with_off_many_imm(i32 %tile_slice, ptr %ptr, i64 %vnum) {
+define void @str_with_off_many_imm(i32 %tile_slice, ptr %ptr) {
 ; CHECK-LABEL: str_with_off_many_imm:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    mov w12, w0
@@ -354,6 +353,85 @@ entry:
   ret void
 }
 
+define void @str_with_off_many_imm_15_18(i32 %tile_slice, ptr %ptr) {
+; CHECK-LABEL: str_with_off_many_imm_15_18:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    mov w12, w0
+; CHECK-NEXT:    add x8, x1, x8, lsl #4
+; CHECK-NEXT:    str za[w12, 15], [x1, #15, mul vl]
+; CHECK-NEXT:    add w12, w0, #16
+; CHECK-NEXT:    str za[w12, 0], [x8]
+; CHECK-NEXT:    str za[w12, 1], [x8, #1, mul vl]
+; CHECK-NEXT:    str za[w12, 2], [x8, #2, mul vl]
+; CHECK-NEXT:    ret
+entry:
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 15)
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 16)
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 17)
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 18)
+  ret void
+}
+
+define void @str_with_off_many_imm_16_19(i32 %tile_slice, ptr %ptr) {
+; CHECK-LABEL: str_with_off_many_imm_16_19:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    add w12, w0, #16
+; CHECK-NEXT:    add x8, x1, x8, lsl #4
+; CHECK-NEXT:    str za[w12, 0], [x8]
+; CHECK-NEXT:    str za[w12, 1], [x8, #1, mul vl]
+; CHECK-NEXT:    str za[w12, 2], [x8, #2, mul vl]
+; CHECK-NEXT:    str za[w12, 3], [x8, #3, mul vl]
+; CHECK-NEXT:    ret
+entry:
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 16)
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 17)
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 18)
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 19)
+  ret void
+}
+
+define void @str_with_off_many_imm_31_34(i32 %tile_slice, ptr %ptr) {
+; CHECK-LABEL: str_with_off_many_imm_31_34:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    add w12, w0, #16
+; CHECK-NEXT:    add w13, w0, #32
+; CHECK-NEXT:    add x9, x1, x8, lsl #4
+; CHECK-NEXT:    add x8, x1, x8, lsl #5
+; CHECK-NEXT:    str za[w12, 15], [x9, #15, mul vl]
+; CHECK-NEXT:    str za[w13, 0], [x8]
+; CHECK-NEXT:    str za[w13, 1], [x8, #1, mul vl]
+; CHECK-NEXT:    str za[w13, 2], [x8, #2, mul vl]
+; CHECK-NEXT:    ret
+entry:
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 31)
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 32)
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 33)
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 34)
+  ret void
+}
+
+define void @str_with_off_many_imm_32_35(i32 %tile_slice, ptr %ptr) {
+; CHECK-LABEL: str_with_off_many_imm_32_35:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    rdsvl x8, #1
+; CHECK-NEXT:    add w12, w0, #32
+; CHECK-NEXT:    add x8, x1, x8, lsl #5
+; CHECK-NEXT:    str za[w12, 0], [x8]
+; CHECK-NEXT:    str za[w12, 1], [x8, #1, mul vl]
+; CHECK-NEXT:    str za[w12, 2], [x8, #2, mul vl]
+; CHECK-NEXT:    str za[w12, 3], [x8, #3, mul vl]
+; CHECK-NEXT:    ret
+entry:
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 32)
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 33)
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 34)
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 35)
+  ret void
+}
+
 define void @str_with_off_many_var(i32 %tile_slice, ptr %ptr, i64 %vnum) {
 ; CHECK-LABEL: str_with_off_many_var:
 ; CHECK:       // %bb.0: // %entry
@@ -384,13 +462,13 @@ define void @test_sink_tile0_offset_operand(<vscale x 4 x i1> %pg, ptr %src, i32
 ; CHECK-LABEL: test_sink_tile0_offset_operand:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    mov w12, w1
-; CHECK-NEXT:  .LBB19_1: // %for.body
+; CHECK-NEXT:  .LBB23_1: // %for.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    st1w {za0h.s[w12, 0]}, p0, [x0]
 ; CHECK-NEXT:    subs w2, w2, #1
 ; CHECK-NEXT:    st1w {za0h.s[w12, 1]}, p0, [x0]
 ; CHECK-NEXT:    st1w {za0h.s[w12, 2]}, p0, [x0]
-; CHECK-NEXT:    b.ne .LBB19_1
+; CHECK-NEXT:    b.ne .LBB23_1
 ; CHECK-NEXT:  // %bb.2: // %exit
 ; CHECK-NEXT:    ret
 entry:

>From df71090d072ac5ecd1848b89ee1c694ef4abcce8 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Tue, 7 Nov 2023 13:59:37 +0000
Subject: [PATCH 08/12] fixup! move add check before range check

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 52 ++++++++++++-------
 .../CodeGen/AArch64/sme-intrinsics-loads.ll   | 30 ++++++++++-
 .../CodeGen/AArch64/sme-intrinsics-stores.ll  | 31 ++++++++++-
 3 files changed, 91 insertions(+), 22 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 515ab3f61c34686..e3dd6a52fcadb63 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4865,32 +4865,48 @@ SDValue LowerSMELdrStr(SDValue N, SelectionDAG &DAG, bool IsLoad) {
   SDValue TileSlice = N->getOperand(2);
   SDValue Base = N->getOperand(3);
   SDValue VecNum = N->getOperand(4);
-  SDValue Remainder = DAG.getTargetConstant(0, DL, MVT::i32);
+  int Addend = 0;
+
+  // If the vnum is an add, we can fold that add into the instruction if the
+  // operand is an immediate. The range check is performed below.
+  if (VecNum.getOpcode() == ISD::ADD) {
+    if (auto ImmNode = dyn_cast<ConstantSDNode>(VecNum.getOperand(1))) {
+      Addend = ImmNode->getSExtValue();
+      VecNum = VecNum.getOperand(0);
+    }
+  }
+
+  SDValue Remainder = DAG.getTargetConstant(Addend, DL, MVT::i32);
 
   // true if the base and slice registers need to be modified
   bool NeedsAdd = true;
-  if (auto ImmNode = dyn_cast<ConstantSDNode>(VecNum)) {
-    int Imm = ImmNode->getSExtValue();
+  auto ImmNode = dyn_cast<ConstantSDNode>(VecNum);
+  if (ImmNode || Addend != 0) {
+    int Imm = ImmNode ? ImmNode->getSExtValue() + Addend : Addend;
+    Remainder = DAG.getTargetConstant(Imm % 16, DL, MVT::i32);
     if (Imm >= 0 && Imm <= 15) {
-      Remainder = DAG.getTargetConstant(Imm, DL, MVT::i32);
-      NeedsAdd = false;
+      // If vnum is an immediate in range then we don't need to modify the tile
+      // slice and base register. We could also get here because Addend != 0 but
+      // vecnum is not an immediate, in which case we still want the base and
+      // slice register to be modified
+      NeedsAdd = !ImmNode;
     } else {
-      Remainder = DAG.getTargetConstant(Imm % 16, DL, MVT::i32);
+      // If it isn't in range then we strip off the remainder and add the result
+      // to the base register and tile slice
       NeedsAdd = true;
-      VecNum = DAG.getConstant(Imm - (Imm % 16), DL, MVT::i32);
-    }
-  } else if (VecNum.getOpcode() == ISD::ADD) {
-    // If the vnum is an add, we can fold that add into the instruction if the
-    // operand is an immediate in range
-    if (auto ImmNode = dyn_cast<ConstantSDNode>(VecNum.getOperand(1))) {
-      int Imm = ImmNode->getSExtValue();
-      if (Imm >= 0 && Imm <= 15) {
-        VecNum = VecNum.getOperand(0);
-        Remainder = DAG.getTargetConstant(Imm, DL, MVT::i32);
-        NeedsAdd = true;
-      }
+      Imm -= Imm % 16;
+      // If the operand isn't an immediate and instead came from an ADD then we
+      // reconstruct the add but with a smaller operand. This means that
+      // successive loads and stores offset from each other can share the same
+      // ADD and have their own remainder in the instruction.
+      if (ImmNode)
+        VecNum = DAG.getConstant(Imm, DL, MVT::i32);
+      else
+        VecNum = DAG.getNode(ISD::ADD, DL, MVT::i32, VecNum,
+                             DAG.getConstant(Imm, DL, MVT::i32));
     }
   }
+
   if (NeedsAdd) {
     // Get the vector length that will be multiplied by vnum
     auto SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll
index e32d1a170defc41..da764cf52445beb 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-loads.ll
@@ -452,19 +452,45 @@ entry:
   ret void
 }
 
+define void @ldr_with_off_many_var_high(i32 %tile_slice, ptr %ptr, i64 %vnum) {
+; CHECK-LABEL: ldr_with_off_many_var_high:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    add w8, w2, #32
+; CHECK-NEXT:    rdsvl x10, #1
+; CHECK-NEXT:    sxtw x9, w8
+; CHECK-NEXT:    add w12, w0, w8
+; CHECK-NEXT:    madd x9, x10, x9, x1
+; CHECK-NEXT:    ldr za[w12, 1], [x9, #1, mul vl]
+; CHECK-NEXT:    ldr za[w12, 2], [x9, #2, mul vl]
+; CHECK-NEXT:    ldr za[w12, 3], [x9, #3, mul vl]
+; CHECK-NEXT:    ldr za[w12, 4], [x9, #4, mul vl]
+; CHECK-NEXT:    ret
+entry:
+  %0 = trunc i64 %vnum to i32
+  %1 = add i32 %0, 33
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 %1)
+  %2 = add i32 %0, 34
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 %2)
+  %3 = add i32 %0, 35
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 %3)
+  %4 = add i32 %0, 36
+  tail call void @llvm.aarch64.sme.ldr(i32 %tile_slice, ptr %ptr, i32 %4)
+  ret void
+}
+
 ; Ensure that the tile offset is sunk, given that this is likely to be an 'add'
 ; that's decomposed into a base + offset in ISel.
 define void @test_ld1_sink_tile0_offset_operand(<vscale x 4 x i1> %pg, ptr %src, i32 %base, i32 %N) {
 ; CHECK-LABEL: test_ld1_sink_tile0_offset_operand:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    mov w12, w1
-; CHECK-NEXT:  .LBB23_1: // %for.body
+; CHECK-NEXT:  .LBB24_1: // %for.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    ld1w {za0h.s[w12, 0]}, p0/z, [x0]
 ; CHECK-NEXT:    subs w2, w2, #1
 ; CHECK-NEXT:    ld1w {za0h.s[w12, 1]}, p0/z, [x0]
 ; CHECK-NEXT:    ld1w {za0h.s[w12, 2]}, p0/z, [x0]
-; CHECK-NEXT:    b.ne .LBB23_1
+; CHECK-NEXT:    b.ne .LBB24_1
 ; CHECK-NEXT:  // %bb.2: // %exit
 ; CHECK-NEXT:    ret
 entry:
diff --git a/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll b/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll
index 4843f9388fa2f77..53e9b6300951c29 100644
--- a/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll
+++ b/llvm/test/CodeGen/AArch64/sme-intrinsics-stores.ll
@@ -456,19 +456,46 @@ entry:
   ret void
 }
 
+define void @str_with_off_many_var_high(i32 %tile_slice, ptr %ptr, i64 %vnum) {
+; CHECK-LABEL: str_with_off_many_var_high:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    add w8, w2, #32
+; CHECK-NEXT:    rdsvl x10, #1
+; CHECK-NEXT:    sxtw x9, w8
+; CHECK-NEXT:    add w12, w0, w8
+; CHECK-NEXT:    madd x9, x10, x9, x1
+; CHECK-NEXT:    str za[w12, 1], [x9, #1, mul vl]
+; CHECK-NEXT:    str za[w12, 2], [x9, #2, mul vl]
+; CHECK-NEXT:    str za[w12, 3], [x9, #3, mul vl]
+; CHECK-NEXT:    str za[w12, 4], [x9, #4, mul vl]
+; CHECK-NEXT:    ret
+entry:
+  %0 = trunc i64 %vnum to i32
+  %1 = add i32 %0, 33
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 %1)
+  %2 = add i32 %0, 34
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 %2)
+  %3 = add i32 %0, 35
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 %3)
+  %4 = add i32 %0, 36
+  tail call void @llvm.aarch64.sme.str(i32 %tile_slice, ptr %ptr, i32 %4)
+  ret void
+}
+
+
 ; Ensure that the tile offset is sunk, given that this is likely to be an 'add'
 ; that's decomposed into a base + offset in ISel.
 define void @test_sink_tile0_offset_operand(<vscale x 4 x i1> %pg, ptr %src, i32 %base, i32 %N) {
 ; CHECK-LABEL: test_sink_tile0_offset_operand:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    mov w12, w1
-; CHECK-NEXT:  .LBB23_1: // %for.body
+; CHECK-NEXT:  .LBB24_1: // %for.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    st1w {za0h.s[w12, 0]}, p0, [x0]
 ; CHECK-NEXT:    subs w2, w2, #1
 ; CHECK-NEXT:    st1w {za0h.s[w12, 1]}, p0, [x0]
 ; CHECK-NEXT:    st1w {za0h.s[w12, 2]}, p0, [x0]
-; CHECK-NEXT:    b.ne .LBB23_1
+; CHECK-NEXT:    b.ne .LBB24_1
 ; CHECK-NEXT:  // %bb.2: // %exit
 ; CHECK-NEXT:    ret
 entry:

>From 7b4cba1e2e8f185ae81544ce3818c2726a94a0a6 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Tue, 7 Nov 2023 14:26:07 +0000
Subject: [PATCH 09/12] fixup! add some examples above the function

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 44 ++++++++++++++++---
 1 file changed, 37 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e3dd6a52fcadb63..b0210375e4f6311 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4852,14 +4852,44 @@ SDValue AArch64TargetLowering::getPStateSM(SelectionDAG &DAG, SDValue Chain,
                      Mask);
 }
 
+// Lower an SME LDR/STR ZA intrinsic to LDR_ZA_PSEUDO or STR_ZA.
+// Case 1: If the vector number (vecnum) is an immediate in range, it gets
+// folded into the instruction
+//    ldr(%tileslice, %ptr, 11) -> ldr [%tileslice, 11], [%ptr, 11]
+// Case 2: If the vecnum is not an immediate, then it is used to modify the base
+// and tile slice registers
+//    ldr(%tileslice, %ptr, %vecnum)
+//    ->
+//    %svl = rdsvl
+//    %ptr2 = %ptr + %svl * %vecnum
+//    %tileslice2 = %tileslice + %vecnum
+//    ldr [%tileslice2, 0], [%ptr2, 0]
+// Case 3: If the vecnum is an immediate out of range, then the same is done as
+// case 2, but the base and slice registers are modified by the greatest
+// multiple of 16 not exceeding the vecnum and the remainder is folded into the
+// instruction. This means that successive loads and stores that are offset from
+// each other can share the same base and slice register updates.
+//    ldr(%tileslice, %ptr, 22)
+//    ldr(%tileslice, %ptr, 23)
+//    ->
+//    %svl = rdsvl
+//    %ptr2 = %ptr + %svl * 16
+//    %tileslice2 = %tileslice + 16
+//    ldr [%tileslice2, 6], [%ptr2, 6]
+//    ldr [%tileslice2, 7], [%ptr2, 7]
+// Case 4: If the vecnum is an add of an immediate, the non-immediate operand
+// modifies the registers like case 2 and the immediate goes in the instruction.
+//    ldr(%tileslice, %ptr, %vecnum + 7)
+//    ldr(%tileslice, %ptr, %vecnum + 8)
+//    ->
+//    %svl = rdsvl
+//    %ptr2 = %ptr + %svl * %vecnum
+//    %tileslice2 = %tileslice + %vecnum
+//    ldr [%tileslice2, 7], [%ptr2, 7]
+//    ldr [%tileslice2, 8], [%ptr2, 8]
+// Case 5: The vecnum being an add of an immediate out of range is also handled,
+// in which case the same remainder logic as case 3 is used.
 SDValue LowerSMELdrStr(SDValue N, SelectionDAG &DAG, bool IsLoad) {
-  // Lower an SME LDR/STR ZA intrinsic to LDR_ZA_PSEUDO or STR_ZA.
-  // If the vector number is an immediate between 0 and 15 inclusive then we can
-  // put that directly into the immediate field of the instruction. If it's
-  // outside of that range then we modify the base and slice by the greatest
-  // multiple of 15 smaller than that number and put the remainder in the
-  // instruction field. If it's not an immediate then we modify the base and
-  // slice registers by that number and put 0 in the instruction.
   SDLoc DL(N);
 
   SDValue TileSlice = N->getOperand(2);

>From 6c53d8f30205a2c6ff8f55c5db29c597d7e25af8 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 8 Nov 2023 11:33:59 +0000
Subject: [PATCH 10/12] fixup! fix mlir test

---
 llvm/include/llvm/IR/IntrinsicsAArch64.td           |  2 +-
 .../mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td    |  3 ++-
 mlir/test/Target/LLVMIR/arm-sme.mlir                | 13 ++++++++++++-
 3 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
index 543b5b6fa94d40d..d68bac0409d82ee 100644
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -2680,7 +2680,7 @@ let TargetPrefix = "aarch64" in {
 
   // Spill + fill
   class SME_LDR_STR_Intrinsic
-    : DefaultAttrsIntrinsic<[], [llvm_i32_ty, llvm_ptr_ty, llvm_i32_ty], [ImmArg<ArgIndex<2>>]>;
+    : DefaultAttrsIntrinsic<[], [llvm_i32_ty, llvm_ptr_ty, llvm_i32_ty]>;
   def int_aarch64_sme_ldr : SME_LDR_STR_Intrinsic;
   def int_aarch64_sme_str : SME_LDR_STR_Intrinsic;
 
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index bcf2466b13a739f..b75918ebf2f6d9c 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -133,7 +133,8 @@ def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">;
 def LLVM_aarch64_sme_str
     : ArmSME_IntrOp<"str">,
       Arguments<(ins Arg<I32, "Index">:$index,
-                 Arg<LLVM_AnyPointer, "Store address", [MemWrite]>:$store_address)>;
+                 Arg<LLVM_AnyPointer, "Store address", [MemWrite]>:$store_address,
+                 Arg<I32, "Offset">:$offset)>;
 
 // Vector to tile slice
 class LLVM_aarch64_sme_write<string direction>
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index aa0389e888b60d6..e718595f2cf7dbe 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -214,7 +214,18 @@ llvm.func @arm_sme_store(%nxv1i1  : vector<[1]xi1>,
   "arm_sme.intr.st1b.vert"(%nxv16i1, %ptr, %c0, %c0) :
               (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.str
-  "arm_sme.intr.str"(%c0, %ptr) : (i32, !llvm.ptr) -> ()
+  "arm_sme.intr.str"(%c0, %p8, %c0) : (i32, !llvm.ptr<i8>, i32) -> ()
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_toggle_za
+llvm.func @arm_sme_toggle_za() {
+  // CHECK: call void @llvm.aarch64.sme.za.enable()
+  "arm_sme.intr.za.enable"() : () -> ()
+  // CHECK: call void @llvm.aarch64.sme.za.disable()
+  "arm_sme.intr.za.disable"() : () -> ()
   llvm.return
 }
 

>From 5544c02854a11814f0276ccebb910952936be0d8 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Tue, 14 Nov 2023 17:14:42 +0000
Subject: [PATCH 11/12] simplify code

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 74 +++++++------------
 1 file changed, 27 insertions(+), 47 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b0210375e4f6311..cc433bd0930caa0 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4895,64 +4895,44 @@ SDValue LowerSMELdrStr(SDValue N, SelectionDAG &DAG, bool IsLoad) {
   SDValue TileSlice = N->getOperand(2);
   SDValue Base = N->getOperand(3);
   SDValue VecNum = N->getOperand(4);
-  int Addend = 0;
-
-  // If the vnum is an add, we can fold that add into the instruction if the
-  // operand is an immediate. The range check is performed below.
-  if (VecNum.getOpcode() == ISD::ADD) {
-    if (auto ImmNode = dyn_cast<ConstantSDNode>(VecNum.getOperand(1))) {
-      Addend = ImmNode->getSExtValue();
-      VecNum = VecNum.getOperand(0);
-    }
-  }
-
-  SDValue Remainder = DAG.getTargetConstant(Addend, DL, MVT::i32);
-
-  // true if the base and slice registers need to be modified
-  bool NeedsAdd = true;
-  auto ImmNode = dyn_cast<ConstantSDNode>(VecNum);
-  if (ImmNode || Addend != 0) {
-    int Imm = ImmNode ? ImmNode->getSExtValue() + Addend : Addend;
-    Remainder = DAG.getTargetConstant(Imm % 16, DL, MVT::i32);
-    if (Imm >= 0 && Imm <= 15) {
-      // If vnum is an immediate in range then we don't need to modify the tile
-      // slice and base register. We could also get here because Addend != 0 but
-      // vecnum is not an immediate, in which case we still want the base and
-      // slice register to be modified
-      NeedsAdd = !ImmNode;
-    } else {
-      // If it isn't in range then we strip off the remainder and add the result
-      // to the base register and tile slice
-      NeedsAdd = true;
-      Imm -= Imm % 16;
-      // If the operand isn't an immediate and instead came from an ADD then we
-      // reconstruct the add but with a smaller operand. This means that
-      // successive loads and stores offset from each other can share the same
-      // ADD and have their own remainder in the instruction.
-      if (ImmNode)
-        VecNum = DAG.getConstant(Imm, DL, MVT::i32);
-      else
-        VecNum = DAG.getNode(ISD::ADD, DL, MVT::i32, VecNum,
-                             DAG.getConstant(Imm, DL, MVT::i32));
-    }
+  int32_t ConstAddend = 0;
+  SDValue VarAddend = VecNum;
+
+  // If the vnum is an add of an immediate, we can fold it into the instruction
+  if (VecNum.getOpcode() == ISD::ADD &&
+      isa<ConstantSDNode>(VecNum.getOperand(1))) {
+    ConstAddend = cast<ConstantSDNode>(VecNum.getOperand(1))->getSExtValue();
+    VarAddend = VecNum.getOperand(0);
+  } else if (auto ImmNode = dyn_cast<ConstantSDNode>(VecNum)) {
+    ConstAddend = ImmNode->getSExtValue();
+    VarAddend = SDValue();
+  }
+
+  int32_t ImmAddend = ConstAddend % 16;
+  if (int32_t C = (ConstAddend - ImmAddend)) {
+    SDValue CVal = DAG.getTargetConstant(C, DL, MVT::i32);
+    VarAddend = VarAddend
+                    ? DAG.getNode(ISD::ADD, DL, MVT::i32, {VarAddend, CVal})
+                    : CVal;
   }
 
-  if (NeedsAdd) {
+  if (VarAddend) {
     // Get the vector length that will be multiplied by vnum
     auto SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
                            DAG.getConstant(1, DL, MVT::i32));
 
     // Multiply SVL and vnum then add it to the base
-    SDValue Mul =
-        DAG.getNode(ISD::MUL, DL, MVT::i64,
-                    {SVL, DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, VecNum)});
+    SDValue Mul = DAG.getNode(
+        ISD::MUL, DL, MVT::i64,
+        {SVL, DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, VarAddend)});
     Base = DAG.getNode(ISD::ADD, DL, MVT::i64, {Base, Mul});
     // Just add vnum to the tileslice
-    TileSlice = DAG.getNode(ISD::ADD, DL, MVT::i32, {TileSlice, VecNum});
+    TileSlice = DAG.getNode(ISD::ADD, DL, MVT::i32, {TileSlice, VarAddend});
   }
 
-  SmallVector<SDValue, 4> Ops = {/*Chain=*/N.getOperand(0), TileSlice, Base,
-                                 Remainder};
+  SmallVector<SDValue, 4> Ops = {
+      /*Chain=*/N.getOperand(0), TileSlice, Base,
+      DAG.getTargetConstant(ImmAddend, DL, MVT::i32)};
   auto LdrStr =
       DAG.getNode(IsLoad ? AArch64ISD::SME_ZA_LDR : AArch64ISD::SME_ZA_STR, DL,
                   MVT::Other, Ops);

>From b03e02a7e1bb8c98de3f3c0a066a376a9b45fd8c Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 15 Nov 2023 13:12:23 +0000
Subject: [PATCH 12/12] separate out SME_LDR_STR_Intrinsic as it's now used by
 zt ldr/str

---
 llvm/include/llvm/IR/IntrinsicsAArch64.td | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
index d68bac0409d82ee..1b701a91455c946 100644
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -2679,10 +2679,10 @@ let TargetPrefix = "aarch64" in {
   def int_aarch64_sme_st1q_vert  : SME_Load_Store_Intrinsic<llvm_nxv1i1_ty>;
 
   // Spill + fill
-  class SME_LDR_STR_Intrinsic
+  class SME_LDR_STR_ZA_Intrinsic
     : DefaultAttrsIntrinsic<[], [llvm_i32_ty, llvm_ptr_ty, llvm_i32_ty]>;
-  def int_aarch64_sme_ldr : SME_LDR_STR_Intrinsic;
-  def int_aarch64_sme_str : SME_LDR_STR_Intrinsic;
+  def int_aarch64_sme_ldr : SME_LDR_STR_ZA_Intrinsic;
+  def int_aarch64_sme_str : SME_LDR_STR_ZA_Intrinsic;
 
   class SME_TileToVector_Intrinsic
       : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
@@ -3454,7 +3454,9 @@ let TargetPrefix = "aarch64" in {
   def int_aarch64_sve_sel_x2  : SVE2_VG2_Sel_Intrinsic;
   def int_aarch64_sve_sel_x4  : SVE2_VG4_Sel_Intrinsic;
 
-  def int_aarch64_sme_ldr_zt : SME_LDR_STR_Intrinsic;
-  def int_aarch64_sme_str_zt : SME_LDR_STR_Intrinsic;
+  class SME_LDR_STR_ZT_Intrinsic
+    : DefaultAttrsIntrinsic<[], [llvm_i32_ty, llvm_ptr_ty]>;
+  def int_aarch64_sme_ldr_zt : SME_LDR_STR_ZT_Intrinsic;
+  def int_aarch64_sme_str_zt : SME_LDR_STR_ZT_Intrinsic;
 
 }



More information about the cfe-commits mailing list