[llvm] [AArch64][GlobalISel] Add custom legalization for v4s8 = G_TRUNC v4s16 (PR #85610)

via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 18 00:52:12 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-llvm-globalisel

Author: Amara Emerson (aemerson)

<details>
<summary>Changes</summary>

We see a *lot* of fallbacks these days due to <4 x s8> types appearing in truncates, and these seem to be commonly generated by the new load/store bitcasting -> s32 rule.

We can keep that load/store rule if we make sure to handle the truncates properly, and we adopt a strategy for this custom action similar to the one used by DAG lowering's LowerTruncateVectorStore(). That is, we first widen the input <4 x s16> to <8 x s16>, so we can generate a legal G_TRUNC to <8 x s8>, and from there extract the final 32-bit value.

---
Full diff: https://github.com/llvm/llvm-project/pull/85610.diff


4 Files Affected:

- (modified) llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp (+36-1) 
- (modified) llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h (+1) 
- (added) llvm/test/CodeGen/AArch64/GlobalISel/legalize-trunc.mir (+24) 
- (modified) llvm/test/CodeGen/AArch64/bitcast.ll (+17-11) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index 36adada2796531..04a228cf522ce7 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -628,7 +628,8 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
         return DstTy.isVector() && SrcTy.getSizeInBits() > 128 &&
                DstTy.getScalarSizeInBits() * 2 <= SrcTy.getScalarSizeInBits();
       })
-
+      .customIf(all(typeInSet(0, {v4s8}),
+                    typeInSet(1, {v4s16})))
       .alwaysLegal();
 
   getActionDefinitionsBuilder(G_SEXT_INREG)
@@ -1262,11 +1263,45 @@ bool AArch64LegalizerInfo::legalizeCustom(
     return legalizeDynStackAlloc(MI, Helper);
   case TargetOpcode::G_PREFETCH:
     return legalizePrefetch(MI, Helper);
+  case TargetOpcode::G_TRUNC:
+    return legalizeTrunc(MI, Helper);
   }
 
   llvm_unreachable("expected switch to return");
 }
 
+bool AArch64LegalizerInfo::legalizeTrunc(MachineInstr &MI,
+                                         LegalizerHelper &Helper) const {
+  assert(MI.getOpcode() == TargetOpcode::G_TRUNC);
+
+  // Handle <4 x s8> = G_TRUNC <4 x s16> by widening to <8 x s16> first.
+  // So the sequence is:
+  // %orig_val(<4 x s16>) = ...
+  // %wide = G_MERGE_VALUES %orig_val, %undef:_(<4 x s16>)
+  // %wide_trunc:_(<8 x s8>) = G_TRUNC %wide
+  // %bc:_(<2 x s32>) = G_BITCAST %wide_trunc
+  // %eve:_(s32) = G_EXTRACT_VECTOR_ELT %bc, 0
+  // %final:_(<4 x s8>) = G_BITCAST %eve
+
+  MachineIRBuilder &MIB = Helper.MIRBuilder;
+
+  auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
+  assert(DstTy == LLT::fixed_vector(4, LLT::scalar(8)) &&
+         SrcTy == LLT::fixed_vector(4, LLT::scalar(16)));
+
+  auto WideTy = LLT::fixed_vector(8, LLT::scalar(16));
+  auto Undef = MIB.buildUndef(SrcTy);
+  auto Merge = MIB.buildMergeLikeInstr(WideTy, {SrcReg, Undef});
+  auto Trunc = MIB.buildTrunc(LLT::fixed_vector(8, LLT::scalar(8)), Merge);
+  auto BC = MIB.buildBitcast(LLT::fixed_vector(2, LLT::scalar(32)), Trunc);
+  auto Extract = MIB.buildExtractVectorElement(
+      LLT::scalar(32), BC, MIB.buildConstant(LLT::scalar(32), 0));
+  MIB.buildBitcast(DstReg, Extract);
+
+  MI.eraseFromParent();
+  return true;
+}
+
 bool AArch64LegalizerInfo::legalizeFunnelShift(MachineInstr &MI,
                                                MachineRegisterInfo &MRI,
                                                MachineIRBuilder &MIRBuilder,
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h
index b69d9b015bd2b3..e9d8b54de9ef70 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h
@@ -64,6 +64,7 @@ class AArch64LegalizerInfo : public LegalizerInfo {
                                 LegalizerHelper &Helper) const;
   bool legalizeDynStackAlloc(MachineInstr &MI, LegalizerHelper &Helper) const;
   bool legalizePrefetch(MachineInstr &MI, LegalizerHelper &Helper) const;
+  bool legalizeTrunc(MachineInstr &MI, LegalizerHelper &Helper) const;
   const AArch64Subtarget *ST;
 };
 } // End llvm namespace.
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/legalize-trunc.mir b/llvm/test/CodeGen/AArch64/GlobalISel/legalize-trunc.mir
new file mode 100644
index 00000000000000..e3e12558c2242a
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/GlobalISel/legalize-trunc.mir
@@ -0,0 +1,24 @@
+# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py
+# RUN: llc -mtriple=aarch64  -run-pass=legalizer %s -o - | FileCheck %s
+---
+name:            trunc_v4s8_v4s16
+body:             |
+  bb.1:
+    liveins: $x0
+    ; CHECK-LABEL: name: trunc_v4s8_v4s16
+    ; CHECK: liveins: $x0
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: %in:_(<4 x s16>) = COPY $x0
+    ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<4 x s16>) = G_IMPLICIT_DEF
+    ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<8 x s16>) = G_CONCAT_VECTORS %in(<4 x s16>), [[DEF]](<4 x s16>)
+    ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(<8 x s8>) = G_TRUNC [[CONCAT_VECTORS]](<8 x s16>)
+    ; CHECK-NEXT: [[BITCAST:%[0-9]+]]:_(<2 x s32>) = G_BITCAST [[TRUNC]](<8 x s8>)
+    ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 0
+    ; CHECK-NEXT: [[EVEC:%[0-9]+]]:_(s32) = G_EXTRACT_VECTOR_ELT [[BITCAST]](<2 x s32>), [[C]](s64)
+    ; CHECK-NEXT: %trunc:_(<4 x s8>) = G_BITCAST [[EVEC]](s32)
+    ; CHECK-NEXT: $s0 = COPY %trunc(<4 x s8>)
+    %in:_(<4 x s16>) = COPY $x0
+    %trunc:_(<4 x s8>) = G_TRUNC %in
+    $s0 = COPY %trunc
+
+...
diff --git a/llvm/test/CodeGen/AArch64/bitcast.ll b/llvm/test/CodeGen/AArch64/bitcast.ll
index 9ebd570e687a01..2b7065fe450617 100644
--- a/llvm/test/CodeGen/AArch64/bitcast.ll
+++ b/llvm/test/CodeGen/AArch64/bitcast.ll
@@ -4,8 +4,7 @@
 
 ; PR23065: SCALAR_TO_VECTOR implies the top elements 1 to N-1 of the N-element vector are undefined.
 
-; CHECK-GI:         warning: Instruction selection used fallback path for bitcast_v4i8_i32
-; CHECK-GI-NEXT:    warning: Instruction selection used fallback path for bitcast_i32_v4i8
+; CHECK-GI:         warning: Instruction selection used fallback path for bitcast_i32_v4i8
 ; CHECK-GI-NEXT:    warning: Instruction selection used fallback path for bitcast_v2i16_i32
 ; CHECK-GI-NEXT:    warning: Instruction selection used fallback path for bitcast_i32_v2i16
 ; CHECK-GI-NEXT:    warning: Instruction selection used fallback path for bitcast_v2i16_v4i8
@@ -54,15 +53,22 @@ define <4 x i16> @foo2(<2 x i32> %a) {
 ; ===== To and From Scalar Types =====
 
 define i32 @bitcast_v4i8_i32(<4 x i8> %a, <4 x i8> %b){
-; CHECK-LABEL: bitcast_v4i8_i32:
-; CHECK:       // %bb.0:
-; CHECK-NEXT:    sub sp, sp, #16
-; CHECK-NEXT:    .cfi_def_cfa_offset 16
-; CHECK-NEXT:    add v0.4h, v0.4h, v1.4h
-; CHECK-NEXT:    uzp1 v0.8b, v0.8b, v0.8b
-; CHECK-NEXT:    fmov w0, s0
-; CHECK-NEXT:    add sp, sp, #16
-; CHECK-NEXT:    ret
+; CHECK-SD-LABEL: bitcast_v4i8_i32:
+; CHECK-SD:       // %bb.0:
+; CHECK-SD-NEXT:    sub sp, sp, #16
+; CHECK-SD-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-SD-NEXT:    add v0.4h, v0.4h, v1.4h
+; CHECK-SD-NEXT:    uzp1 v0.8b, v0.8b, v0.8b
+; CHECK-SD-NEXT:    fmov w0, s0
+; CHECK-SD-NEXT:    add sp, sp, #16
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: bitcast_v4i8_i32:
+; CHECK-GI:       // %bb.0:
+; CHECK-GI-NEXT:    add v0.4h, v0.4h, v1.4h
+; CHECK-GI-NEXT:    uzp1 v0.8b, v0.8b, v0.8b
+; CHECK-GI-NEXT:    fmov w0, s0
+; CHECK-GI-NEXT:    ret
   %c = add <4 x i8> %a, %b
   %d = bitcast <4 x i8> %c to i32
   ret i32 %d

``````````

</details>


https://github.com/llvm/llvm-project/pull/85610


More information about the llvm-commits mailing list