[clang] 7da1905 - [AArch32] Armv8.6-a Matrix Mult Assembly + Intrinsics

Luke Geeson via cfe-commits cfe-commits at lists.llvm.org
Fri Apr 24 07:54:30 PDT 2020


Author: Luke Geeson
Date: 2020-04-24T15:54:06+01:00
New Revision: 7da19051253219d4bee2c50fe13f250201f1f7ec

URL: https://github.com/llvm/llvm-project/commit/7da19051253219d4bee2c50fe13f250201f1f7ec
DIFF: https://github.com/llvm/llvm-project/commit/7da19051253219d4bee2c50fe13f250201f1f7ec.diff

LOG: [AArch32] Armv8.6-a Matrix Mult Assembly + Intrinsics

This patch upstreams support for the Armv8.6-a Matrix Multiplication
Extension. A summary of the features can be found here:

https://community.arm.com/developer/ip-products/processors/b/processors-ip-blog/posts/arm-architecture-developments-armv8-6-a

This patch includes:

- Assembly support for AArch32
- Neon intrinsics support for AArch32 Matrix
  Multiplication

Note: these extensions are optional in the 8.6a architecture, so this
patch enables them by default whenever v8.6a is selected

No additional IR types or C Types are needed for this extension.

This is part of a patch series that starts with BFloat16 support and
covers the other components of the Armv8.6-a extension (the previous
patches are linked in Phabricator).

Based on work by:
- Luke Geeson
- Oliver Stannard
- Luke Cheeseman

Reviewers: t.p.northover, miyuki

Reviewed By: miyuki

Subscribers: miyuki, ostannard, kristof.beyls, hiraditya, danielkiss,
cfe-commits

Tags: #clang

Differential Revision: https://reviews.llvm.org/D77872

Added: 
    clang/test/CodeGen/arm-v8.6a-neon-intrinsics.c
    llvm/test/CodeGen/ARM/arm-matmul.ll

Modified: 
    clang/lib/Basic/Targets/ARM.cpp
    clang/lib/Basic/Targets/ARM.h
    clang/lib/CodeGen/CGBuiltin.cpp
    llvm/include/llvm/IR/IntrinsicsARM.td
    llvm/lib/Target/ARM/ARM.td
    llvm/lib/Target/ARM/ARMInstrNEON.td
    llvm/lib/Target/ARM/ARMPredicates.td
    llvm/lib/Target/ARM/ARMSubtarget.h

Removed: 
    


################################################################################
diff  --git a/clang/lib/Basic/Targets/ARM.cpp b/clang/lib/Basic/Targets/ARM.cpp
index 881516181538..f02a9d373609 100644
--- a/clang/lib/Basic/Targets/ARM.cpp
+++ b/clang/lib/Basic/Targets/ARM.cpp
@@ -425,6 +425,7 @@ bool ARMTargetInfo::handleTargetFeatures(std::vector<std::string> &Features,
   // Note that SoftFloatABI is initialized in our constructor.
   HWDiv = 0;
   DotProd = 0;
+  HasMatMul = 0;
   HasFloat16 = true;
   ARMCDECoprocMask = 0;
 
@@ -491,6 +492,8 @@ bool ARMTargetInfo::handleTargetFeatures(std::vector<std::string> &Features,
       FPU |= FPARMV8;
       MVE |= MVE_INT | MVE_FP;
       HW_FP |= HW_FP_SP | HW_FP_HP;
+    } else if (Feature == "+i8mm") {
+      HasMatMul = 1;
     } else if (Feature.size() == strlen("+cdecp0") && Feature >= "+cdecp0" &&
                Feature <= "+cdecp7") {
       unsigned Coproc = Feature.back() - '0';
@@ -820,6 +823,9 @@ void ARMTargetInfo::getTargetDefines(const LangOptions &Opts,
   if (DotProd)
     Builder.defineMacro("__ARM_FEATURE_DOTPROD", "1");
 
+  if (HasMatMul)
+    Builder.defineMacro("__ARM_FEATURE_MATMUL_INT8", "1");
+
   switch (ArchKind) {
   default:
     break;

diff  --git a/clang/lib/Basic/Targets/ARM.h b/clang/lib/Basic/Targets/ARM.h
index 725954038602..48d9db2ba166 100644
--- a/clang/lib/Basic/Targets/ARM.h
+++ b/clang/lib/Basic/Targets/ARM.h
@@ -75,6 +75,7 @@ class LLVM_LIBRARY_VISIBILITY ARMTargetInfo : public TargetInfo {
   unsigned DSP : 1;
   unsigned Unaligned : 1;
   unsigned DotProd : 1;
+  unsigned HasMatMul : 1;
 
   enum {
     LDREX_B = (1 << 0), /// byte (8-bit)

diff  --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 80b27b629d1e..3a66583e20e8 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -4807,6 +4807,7 @@ static const ARMVectorIntrinsicInfo ARMSIMDIntrinsicMap [] = {
   NEONMAP1(vminnm_v, arm_neon_vminnm, Add1ArgType),
   NEONMAP1(vminnmq_v, arm_neon_vminnm, Add1ArgType),
   NEONMAP2(vminq_v, arm_neon_vminu, arm_neon_vmins, Add1ArgType | UnsignedAlts),
+  NEONMAP2(vmmlaq_v, arm_neon_ummla, arm_neon_smmla, 0),
   NEONMAP0(vmovl_v),
   NEONMAP0(vmovn_v),
   NEONMAP1(vmul_v, arm_neon_vmulp, Add1ArgType),
@@ -4914,6 +4915,9 @@ static const ARMVectorIntrinsicInfo ARMSIMDIntrinsicMap [] = {
   NEONMAP0(vtrnq_v),
   NEONMAP0(vtst_v),
   NEONMAP0(vtstq_v),
+  NEONMAP1(vusdot_v, arm_neon_usdot, 0),
+  NEONMAP1(vusdotq_v, arm_neon_usdot, 0),
+  NEONMAP1(vusmmlaq_v, arm_neon_usmmla, 0),
   NEONMAP0(vuzp_v),
   NEONMAP0(vuzpq_v),
   NEONMAP0(vzip_v),

diff  --git a/clang/test/CodeGen/arm-v8.6a-neon-intrinsics.c b/clang/test/CodeGen/arm-v8.6a-neon-intrinsics.c
new file mode 100644
index 000000000000..a641197b118a
--- /dev/null
+++ b/clang/test/CodeGen/arm-v8.6a-neon-intrinsics.c
@@ -0,0 +1,87 @@
+// RUN: %clang_cc1 -triple armv8.6a-arm-none-eabi -target-feature +neon -target-feature +fullfp16 -target-feature +i8mm \
+// RUN: -fallow-half-arguments-and-returns -S -disable-O0-optnone -emit-llvm -o - %s \
+// RUN: | opt -S -mem2reg -sroa \
+// RUN: | FileCheck %s
+
+// REQUIRES: arm-registered-target
+
+#include <arm_neon.h>
+
+// CHECK-LABEL: test_vmmlaq_s32
+// CHECK: [[VAL:%.*]] = call <4 x i32> @llvm.arm.neon.smmla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b)
+// CHECK: ret <4 x i32> [[VAL]]
+int32x4_t test_vmmlaq_s32(int32x4_t r, int8x16_t a, int8x16_t b) {
+  return vmmlaq_s32(r, a, b);
+}
+
+// CHECK-LABEL: test_vmmlaq_u32
+// CHECK: [[VAL:%.*]] = call <4 x i32> @llvm.arm.neon.ummla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b)
+// CHECK: ret <4 x i32> [[VAL]]
+uint32x4_t test_vmmlaq_u32(uint32x4_t r, uint8x16_t a, uint8x16_t b) {
+  return vmmlaq_u32(r, a, b);
+}
+
+// CHECK-LABEL: test_vusmmlaq_s32
+// CHECK: [[VAL:%.*]] = call <4 x i32> @llvm.arm.neon.usmmla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b)
+// CHECK: ret <4 x i32> [[VAL]]
+int32x4_t test_vusmmlaq_s32(int32x4_t r, uint8x16_t a, int8x16_t b) {
+  return vusmmlaq_s32(r, a, b);
+}
+
+// CHECK-LABEL: test_vusdot_s32
+// CHECK: [[VAL:%.*]] = call <2 x i32> @llvm.arm.neon.usdot.v2i32.v8i8(<2 x i32> %r, <8 x i8> %a, <8 x i8> %b)
+// CHECK: ret <2 x i32> [[VAL]]
+int32x2_t test_vusdot_s32(int32x2_t r, uint8x8_t a, int8x8_t b) {
+  return vusdot_s32(r, a, b);
+}
+
+// CHECK-LABEL: test_vusdot_lane_s32
+// CHECK: [[TMP0:%.*]] = bitcast <8 x i8> %b to <2 x i32>
+// CHECK: [[TMP1:%.*]] = bitcast <2 x i32> [[TMP0]] to <8 x i8>
+// CHECK: [[TMP2:%.*]] = bitcast <8 x i8> [[TMP1]] to <2 x i32>
+// CHECK: [[LANE:%.*]] = shufflevector <2 x i32> [[TMP2]], <2 x i32> [[TMP2]], <2 x i32> zeroinitializer
+// CHECK: [[TMP3:%.*]] = bitcast <2 x i32> [[LANE]] to <8 x i8>
+// CHECK: [[TMP4:%.*]] = bitcast <2 x i32> %r to <8 x i8>
+// CHECK: [[OP:%.*]] = call <2 x i32> @llvm.arm.neon.usdot.v2i32.v8i8(<2 x i32> %r, <8 x i8> %a, <8 x i8> [[TMP3]])
+// CHECK: ret <2 x i32> [[OP]]
+int32x2_t test_vusdot_lane_s32(int32x2_t r, uint8x8_t a, int8x8_t b) {
+  return vusdot_lane_s32(r, a, b, 0);
+}
+
+// CHECK-LABEL: test_vsudot_lane_s32
+// CHECK: [[TMP0:%.*]] = bitcast <8 x i8> %b to <2 x i32>
+// CHECK: [[TMP1:%.*]] = bitcast <2 x i32> [[TMP0]] to <8 x i8>
+// CHECK: [[TMP2:%.*]] = bitcast <8 x i8> [[TMP1]] to <2 x i32>
+// CHECK: [[LANE:%.*]] = shufflevector <2 x i32> [[TMP2]], <2 x i32> [[TMP2]], <2 x i32> zeroinitializer
+// CHECK: [[TMP3:%.*]] = bitcast <2 x i32> [[LANE]] to <8 x i8>
+// CHECK: [[TMP4:%.*]] = bitcast <2 x i32> %r to <8 x i8>
+// CHECK: [[OP:%.*]] = call <2 x i32> @llvm.arm.neon.usdot.v2i32.v8i8(<2 x i32> %r, <8 x i8> [[TMP3]], <8 x i8> %a)
+// CHECK: ret <2 x i32> [[OP]]
+int32x2_t test_vsudot_lane_s32(int32x2_t r, int8x8_t a, uint8x8_t b) {
+  return vsudot_lane_s32(r, a, b, 0);
+}
+
+// CHECK-LABEL: test_vusdotq_lane_s32
+// CHECK: [[TMP0:%.*]] = bitcast <8 x i8> %b to <2 x i32>
+// CHECK: [[TMP1:%.*]] = bitcast <2 x i32> [[TMP0]] to <8 x i8>
+// CHECK: [[TMP2:%.*]] = bitcast <8 x i8> [[TMP1]] to <2 x i32>
+// CHECK: [[LANE:%.*]] = shufflevector <2 x i32> [[TMP2]], <2 x i32> [[TMP2]], <4 x i32> zeroinitializer
+// CHECK: [[TMP4:%.*]] = bitcast <4 x i32> [[LANE]] to <16 x i8>
+// CHECK: [[TMP5:%.*]] = bitcast <4 x i32> %r to <16 x i8>
+// CHECK: [[OP:%.*]] = call <4 x i32> @llvm.arm.neon.usdot.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> [[TMP4]])
+// CHECK: ret <4 x i32> [[OP]]
+int32x4_t test_vusdotq_lane_s32(int32x4_t r, uint8x16_t a, int8x8_t b) {
+  return vusdotq_lane_s32(r, a, b, 0);
+}
+
+// CHECK-LABEL: test_vsudotq_lane_s32
+// CHECK: [[TMP0:%.*]] = bitcast <8 x i8> %b to <2 x i32>
+// CHECK: [[TMP1:%.*]] = bitcast <2 x i32> [[TMP0]] to <8 x i8>
+// CHECK: [[TMP2:%.*]] = bitcast <8 x i8> [[TMP1]] to <2 x i32>
+// CHECK: [[LANE:%.*]] = shufflevector <2 x i32> [[TMP2]], <2 x i32> [[TMP2]], <4 x i32> zeroinitializer
+// CHECK: [[TMP4:%.*]] = bitcast <4 x i32> %r to <16 x i8>
+// CHECK: [[OP:%.*]] = call <4 x i32> @llvm.arm.neon.usdot.v4i32.v16i8(<4 x i32> %r, <16 x i8> %3, <16 x i8> %a)
+// CHECK: ret <4 x i32> [[OP]]
+int32x4_t test_vsudotq_lane_s32(int32x4_t r, int8x16_t a, uint8x8_t b) {
+  return vsudotq_lane_s32(r, a, b, 0);
+}

diff  --git a/llvm/include/llvm/IR/IntrinsicsARM.td b/llvm/include/llvm/IR/IntrinsicsARM.td
index a87ab5f596aa..1606d666fa6a 100644
--- a/llvm/include/llvm/IR/IntrinsicsARM.td
+++ b/llvm/include/llvm/IR/IntrinsicsARM.td
@@ -773,6 +773,19 @@ class Neon_Dot_Intrinsic
 def int_arm_neon_udot : Neon_Dot_Intrinsic;
 def int_arm_neon_sdot : Neon_Dot_Intrinsic;
 
+// v8.6-A Matrix Multiply Intrinsics
+class Neon_MatMul_Intrinsic
+  : Intrinsic<[llvm_anyvector_ty],
+              [LLVMMatchType<0>, llvm_anyvector_ty,
+               LLVMMatchType<1>],
+              [IntrNoMem]>;
+def int_arm_neon_ummla  : Neon_MatMul_Intrinsic;
+def int_arm_neon_smmla  : Neon_MatMul_Intrinsic;
+def int_arm_neon_usmmla : Neon_MatMul_Intrinsic;
+def int_arm_neon_usdot  : Neon_Dot_Intrinsic;
+
+// v8.6-A Bfloat Intrinsics
+
 def int_arm_cls: Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem]>;
 def int_arm_cls64: Intrinsic<[llvm_i32_ty], [llvm_i64_ty], [IntrNoMem]>;
 

diff  --git a/llvm/lib/Target/ARM/ARM.td b/llvm/lib/Target/ARM/ARM.td
index 72cfd64ac566..2b4c3e13014c 100644
--- a/llvm/lib/Target/ARM/ARM.td
+++ b/llvm/lib/Target/ARM/ARM.td
@@ -428,6 +428,9 @@ def FeatureSB       : SubtargetFeature<"sb", "HasSB", "true",
 def FeatureBF16     : SubtargetFeature<"bf16", "HasBF16", "true",
   "Enable support for BFloat16 instructions",  [FeatureNEON]>;
 
+def FeatureMatMulInt8 : SubtargetFeature<"i8mm", "HasMatMulInt8",
+    "true", "Enable Matrix Multiply Int8 Extension", [FeatureNEON]>;
+
 // Armv8.1-M extensions
 
 def FeatureLOB            : SubtargetFeature<"lob", "HasLOB", "true",
@@ -529,7 +532,8 @@ def HasV8_5aOps   : SubtargetFeature<"v8.5a", "HasV8_5aOps", "true",
 
 def HasV8_6aOps   : SubtargetFeature<"v8.6a", "HasV8_6aOps", "true",
                                    "Support ARM v8.6a instructions",
-                                   [HasV8_5aOps, FeatureBF16]>;
+                                   [HasV8_5aOps, FeatureBF16,
+                                    FeatureMatMulInt8]>;
 
 def HasV8_1MMainlineOps : SubtargetFeature<
                "v8.1m.main", "HasV8_1MMainlineOps", "true",

diff  --git a/llvm/lib/Target/ARM/ARMInstrNEON.td b/llvm/lib/Target/ARM/ARMInstrNEON.td
index 44338a2d4dbb..7dbbd4796a03 100644
--- a/llvm/lib/Target/ARM/ARMInstrNEON.td
+++ b/llvm/lib/Target/ARM/ARMInstrNEON.td
@@ -4823,10 +4823,10 @@ def : Pat<(v4f32 (fma (fneg QPR:$Vn), QPR:$Vm, QPR:$src1)),
 // We put them in the VFPV8 decoder namespace because the ARM and Thumb
 // encodings are the same and thus no further bit twiddling is necessary
 // in the disassembler.
-class VDOT<bit op6, bit op4, RegisterClass RegTy, string Asm, string AsmTy,
-           ValueType AccumTy, ValueType InputTy,
+class VDOT<bit op6, bit op4, bit op23, RegisterClass RegTy, string Asm,
+           string AsmTy, ValueType AccumTy, ValueType InputTy,
            SDPatternOperator OpNode> :
-      N3Vnp<0b11000, 0b10, 0b1101, op6, op4, (outs RegTy:$dst),
+      N3Vnp<{0b1100, op23}, 0b10, 0b1101, op6, op4, (outs RegTy:$dst),
             (ins RegTy:$Vd, RegTy:$Vn, RegTy:$Vm), N3RegFrm, IIC_VDOTPROD,
             Asm, AsmTy,
             [(set (AccumTy RegTy:$dst),
@@ -4838,10 +4838,19 @@ class VDOT<bit op6, bit op4, RegisterClass RegTy, string Asm, string AsmTy,
   let Constraints = "$dst = $Vd";
 }
 
-def VUDOTD : VDOT<0, 1, DPR, "vudot", "u8", v2i32, v8i8,  int_arm_neon_udot>;
-def VSDOTD : VDOT<0, 0, DPR, "vsdot", "s8", v2i32, v8i8,  int_arm_neon_sdot>;
-def VUDOTQ : VDOT<1, 1, QPR, "vudot", "u8", v4i32, v16i8, int_arm_neon_udot>;
-def VSDOTQ : VDOT<1, 0, QPR, "vsdot", "s8", v4i32, v16i8, int_arm_neon_sdot>;
+
+class VUSDOT<bit op6, bit op4, bit op23, RegisterClass RegTy, string Asm,
+           string AsmTy, ValueType AccumTy, ValueType InputTy,
+           SDPatternOperator OpNode> :
+      VDOT<op6, op4, op23, RegTy, Asm, AsmTy, AccumTy, InputTy, OpNode> {
+  let hasNoSchedulingInfo = 1;
+
+}
+
+def VUDOTD : VDOT<0, 1, 0, DPR, "vudot", "u8", v2i32, v8i8,  int_arm_neon_udot>;
+def VSDOTD : VDOT<0, 0, 0, DPR, "vsdot", "s8", v2i32, v8i8,  int_arm_neon_sdot>;
+def VUDOTQ : VDOT<1, 1, 0, QPR, "vudot", "u8", v4i32, v16i8, int_arm_neon_udot>;
+def VSDOTQ : VDOT<1, 0, 0, QPR, "vsdot", "s8", v4i32, v16i8, int_arm_neon_sdot>;
 
 // Indexed dot product instructions:
 multiclass DOTI<string opc, string dt, bit Q, bit U, RegisterClass Ty,
@@ -4876,6 +4885,70 @@ defm VUDOTQI : DOTI<"vudot", "u8", 0b1, 0b1, QPR, v4i32, v16i8,
 defm VSDOTQI : DOTI<"vsdot", "s8", 0b1, 0b0, QPR, v4i32, v16i8,
                     int_arm_neon_sdot, (EXTRACT_SUBREG QPR:$Vm, dsub_0)>;
 
+// v8.6A matrix multiplication extension
+let Predicates = [HasMatMulInt8] in {
+  class N3VMatMul<bit B, bit U, string Asm, string AsmTy,
+                  SDPatternOperator OpNode>
+        : N3Vnp<{0b1100, B}, 0b10, 0b1100, 1, U, (outs QPR:$dst),
+                (ins QPR:$Vd, QPR:$Vn, QPR:$Vm), N3RegFrm, NoItinerary,
+                Asm, AsmTy,
+                [(set (v4i32 QPR:$dst), (OpNode (v4i32 QPR:$Vd),
+                                                (v16i8 QPR:$Vn),
+                                                (v16i8 QPR:$Vm)))]> {
+    let DecoderNamespace = "VFPV8";
+    let Constraints = "$dst = $Vd";
+    let hasNoSchedulingInfo = 1;
+  }
+
+  multiclass N3VMixedDotLane<bit Q, bit U, string Asm, string AsmTy, RegisterClass RegTy,
+                        ValueType AccumTy, ValueType InputTy, SDPatternOperator OpNode,
+                        dag RHS> {
+
+    def "" : N3Vnp<0b11101, 0b00, 0b1101, Q, U, (outs RegTy:$dst),
+                (ins RegTy:$Vd, RegTy:$Vn, DPR_VFP2:$Vm, VectorIndex32:$lane), N3RegFrm,
+                 NoItinerary, Asm, AsmTy, []> {
+      bit lane;
+      let hasNoSchedulingInfo = 1;
+      let Inst{5} = lane;
+      let AsmString = !strconcat(Asm, ".", AsmTy, "\t$Vd, $Vn, $Vm$lane");
+      let DecoderNamespace = "VFPV8";
+      let Constraints = "$dst = $Vd";
+    }
+
+    def : Pat<
+      (AccumTy (OpNode (AccumTy RegTy:$Vd),
+                       (InputTy RegTy:$Vn),
+                       (InputTy (bitconvert (AccumTy
+                                (ARMvduplane (AccumTy RegTy:$Vm),
+                                              VectorIndex32:$lane)))))),
+      (!cast<Instruction>(NAME) RegTy:$Vd, RegTy:$Vn, RHS, VectorIndex32:$lane)>;
+
+  }
+
+  multiclass SUDOTLane<bit Q, RegisterClass RegTy, ValueType AccumTy, ValueType InputTy, dag RHS>
+        : N3VMixedDotLane<Q, 1, "vsudot", "u8", RegTy, AccumTy, InputTy, null_frag, null_frag> {
+    def : Pat<
+      (AccumTy (int_arm_neon_usdot (AccumTy RegTy:$Vd),
+                                   (InputTy (bitconvert (AccumTy
+                                            (ARMvduplane (AccumTy RegTy:$Vm),
+                                                          VectorIndex32:$lane)))),
+                                   (InputTy RegTy:$Vn))),
+      (!cast<Instruction>(NAME) RegTy:$Vd, RegTy:$Vn, RHS, VectorIndex32:$lane)>;
+  }
+
+  def VSMMLA  : N3VMatMul<0, 0, "vsmmla",  "s8", int_arm_neon_smmla>;
+  def VUMMLA  : N3VMatMul<0, 1, "vummla",  "u8", int_arm_neon_ummla>;
+  def VUSMMLA : N3VMatMul<1, 0, "vusmmla", "s8", int_arm_neon_usmmla>;
+  def VUSDOTD : VUSDOT<0, 0, 1, DPR, "vusdot", "s8", v2i32, v8i8,  int_arm_neon_usdot>;
+  def VUSDOTQ : VUSDOT<1, 0, 1, QPR, "vusdot", "s8", v4i32, v16i8, int_arm_neon_usdot>;
+
+  defm VUSDOTDI : N3VMixedDotLane<0, 0, "vusdot", "s8", DPR, v2i32, v8i8,
+                                  int_arm_neon_usdot, (v2i32 DPR_VFP2:$Vm)>;
+  defm VUSDOTQI : N3VMixedDotLane<1, 0, "vusdot", "s8", QPR, v4i32, v16i8,
+                                  int_arm_neon_usdot, (EXTRACT_SUBREG QPR:$Vm, dsub_0)>;
+  defm VSUDOTDI : SUDOTLane<0, DPR, v2i32, v8i8, (v2i32 DPR_VFP2:$Vm)>;
+  defm VSUDOTQI : SUDOTLane<1, QPR, v4i32, v16i8, (EXTRACT_SUBREG QPR:$Vm, dsub_0)>;
+}
 
 // ARMv8.3 complex operations
 class BaseN3VCP8ComplexTied<bit op21, bit op4, bit s, bit q,

diff  --git a/llvm/lib/Target/ARM/ARMPredicates.td b/llvm/lib/Target/ARM/ARMPredicates.td
index f2cb2cc7e309..19106265c6d1 100644
--- a/llvm/lib/Target/ARM/ARMPredicates.td
+++ b/llvm/lib/Target/ARM/ARMPredicates.td
@@ -110,6 +110,8 @@ def HasFP16FML       : Predicate<"Subtarget->hasFP16FML()">,
                                  AssemblerPredicate<(all_of FeatureFP16FML),"full half-float fml">;
 def HasBF16          : Predicate<"Subtarget->hasBF16()">,
                                  AssemblerPredicate<(all_of FeatureBF16),"BFloat16 floating point extension">;
+def HasMatMulInt8    : Predicate<"Subtarget->hasMatMulInt8()">,
+                                 AssemblerPredicate<(all_of FeatureMatMulInt8),"8-bit integer matrix multiply">;
 def HasDivideInThumb : Predicate<"Subtarget->hasDivideInThumbMode()">,
                                  AssemblerPredicate<(all_of FeatureHWDivThumb), "divide in THUMB">;
 def HasDivideInARM   : Predicate<"Subtarget->hasDivideInARMMode()">,

diff  --git a/llvm/lib/Target/ARM/ARMSubtarget.h b/llvm/lib/Target/ARM/ARMSubtarget.h
index 66c50fb590d9..536352190c11 100644
--- a/llvm/lib/Target/ARM/ARMSubtarget.h
+++ b/llvm/lib/Target/ARM/ARMSubtarget.h
@@ -260,6 +260,9 @@ class ARMSubtarget : public ARMGenSubtargetInfo {
   /// HasBF16 - True if subtarget supports BFloat16 floating point operations
   bool HasBF16 = false;
 
+  /// HasMatMulInt8 - True if subtarget supports 8-bit integer matrix multiply
+  bool HasMatMulInt8 = false;
+
   /// HasD32 - True if subtarget has the full 32 double precision
   /// FP registers for VFPv3.
   bool HasD32 = false;
@@ -704,6 +707,8 @@ class ARMSubtarget : public ARMGenSubtargetInfo {
   /// Return true if the CPU supports any kind of instruction fusion.
   bool hasFusion() const { return hasFuseAES() || hasFuseLiterals(); }
 
+  bool hasMatMulInt8() const { return HasMatMulInt8; }
+
   const Triple &getTargetTriple() const { return TargetTriple; }
 
   bool isTargetDarwin() const { return TargetTriple.isOSDarwin(); }

diff  --git a/llvm/test/CodeGen/ARM/arm-matmul.ll b/llvm/test/CodeGen/ARM/arm-matmul.ll
new file mode 100644
index 000000000000..6938fbea49e3
--- /dev/null
+++ b/llvm/test/CodeGen/ARM/arm-matmul.ll
@@ -0,0 +1,83 @@
+; RUN: llc -mtriple=arm-none-linux-gnu -mattr=+neon,+i8mm -float-abi=hard < %s -o -| FileCheck %s
+
+define <4 x i32> @smmla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b) {
+entry:
+; CHECK-LABEL: smmla.v4i32.v16i8
+; CHECK:        vsmmla.s8       q0, q1, q2
+  %vmmla1.i = tail call <4 x i32> @llvm.arm.neon.smmla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b) #3
+  ret <4 x i32> %vmmla1.i
+}
+
+define <4 x i32> @ummla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b) {
+entry:
+; CHECK-LABEL: ummla.v4i32.v16i8
+; CHECK:        vummla.u8       q0, q1, q2
+  %vmmla1.i = tail call <4 x i32> @llvm.arm.neon.ummla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b) #3
+  ret <4 x i32> %vmmla1.i
+}
+
+define <4 x i32> @usmmla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b) {
+entry:
+; CHECK-LABEL: usmmla.v4i32.v16i8
+; CHECK:        vusmmla.s8       q0, q1, q2
+  %vusmmla1.i = tail call <4 x i32> @llvm.arm.neon.usmmla.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %b) #3
+  ret <4 x i32> %vusmmla1.i
+}
+
+define <2 x i32> @usdot.v2i32.v8i8(<2 x i32> %r, <8 x i8> %a, <8 x i8> %b) {
+entry:
+; CHECK-LABEL: usdot.v2i32.v8i8
+; CHECK:        vusdot.s8       d0, d1, d2
+  %vusdot1.i = tail call <2 x i32> @llvm.arm.neon.usdot.v2i32.v8i8(<2 x i32> %r, <8 x i8> %a, <8 x i8> %b) #3
+  ret <2 x i32> %vusdot1.i
+}
+
+define <2 x i32> @usdot_lane.v2i32.v8i8(<2 x i32> %r, <8 x i8> %a, <8 x i8> %b) {
+entry:
+; CHECK-LABEL: usdot_lane.v2i32.v8i8
+; CHECK:        vusdot.s8       d0, d1, d2[0]
+  %0 = bitcast <8 x i8> %b to <2 x i32>
+  %shuffle = shufflevector <2 x i32> %0, <2 x i32> undef, <2 x i32> zeroinitializer
+  %1 = bitcast <2 x i32> %shuffle to <8 x i8>
+  %vusdot1.i = tail call <2 x i32> @llvm.arm.neon.usdot.v2i32.v8i8(<2 x i32> %r, <8 x i8> %a, <8 x i8> %1) #3
+  ret <2 x i32> %vusdot1.i
+}
+
+define <2 x i32> @sudot_lane.v2i32.v8i8(<2 x i32> %r, <8 x i8> %a, <8 x i8> %b) {
+entry:
+; CHECK-LABEL: sudot_lane.v2i32.v8i8
+; CHECK:        vsudot.u8       d0, d1, d2[0]
+  %0 = bitcast <8 x i8> %b to <2 x i32>
+  %shuffle = shufflevector <2 x i32> %0, <2 x i32> undef, <2 x i32> zeroinitializer
+  %1 = bitcast <2 x i32> %shuffle to <8 x i8>
+  %vusdot1.i = tail call <2 x i32> @llvm.arm.neon.usdot.v2i32.v8i8(<2 x i32> %r, <8 x i8> %1, <8 x i8> %a) #3
+  ret <2 x i32> %vusdot1.i
+}
+
+define <4 x i32> @usdotq_lane.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <8 x i8> %b) {
+entry:
+; CHECK-LABEL: usdotq_lane.v4i32.v16i8
+; CHECK:        vusdot.s8       q0, q1, d4[0]
+  %0 = bitcast <8 x i8> %b to <2 x i32>
+  %shuffle = shufflevector <2 x i32> %0, <2 x i32> undef, <4 x i32> zeroinitializer
+  %1 = bitcast <4 x i32> %shuffle to <16 x i8>
+  %vusdot1.i = tail call <4 x i32> @llvm.arm.neon.usdot.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <16 x i8> %1) #3
+  ret <4 x i32> %vusdot1.i
+}
+
+define <4 x i32> @sudotq_lane.v4i32.v16i8(<4 x i32> %r, <16 x i8> %a, <8 x i8> %b) {
+entry:
+; CHECK-LABEL: sudotq_lane.v4i32.v16i8
+; CHECK:        vsudot.u8       q0, q1, d4[0]
+  %0 = bitcast <8 x i8> %b to <2 x i32>
+  %shuffle = shufflevector <2 x i32> %0, <2 x i32> undef, <4 x i32> zeroinitializer
+  %1 = bitcast <4 x i32> %shuffle to <16 x i8>
+  %vusdot1.i = tail call <4 x i32> @llvm.arm.neon.usdot.v4i32.v16i8(<4 x i32> %r, <16 x i8> %1, <16 x i8> %a) #3
+  ret <4 x i32> %vusdot1.i
+}
+
+declare <4 x i32> @llvm.arm.neon.smmla.v4i32.v16i8(<4 x i32>, <16 x i8>, <16 x i8>) #2
+declare <4 x i32> @llvm.arm.neon.ummla.v4i32.v16i8(<4 x i32>, <16 x i8>, <16 x i8>) #2
+declare <4 x i32> @llvm.arm.neon.usmmla.v4i32.v16i8(<4 x i32>, <16 x i8>, <16 x i8>) #2
+declare <2 x i32> @llvm.arm.neon.usdot.v2i32.v8i8(<2 x i32>, <8 x i8>, <8 x i8>) #2
+declare <4 x i32> @llvm.arm.neon.usdot.v4i32.v16i8(<4 x i32>, <16 x i8>, <16 x i8>) #2


        


More information about the cfe-commits mailing list