[llvm] 2a4a229 - [WebAssembly] Custom optimization for truncate

Thomas Lively via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 14 08:42:45 PST 2021


Author: Jing Bao
Date: 2021-12-14T08:42:39-08:00
New Revision: 2a4a229d6dcceecbb8bab094b6880e2445a6e465

URL: https://github.com/llvm/llvm-project/commit/2a4a229d6dcceecbb8bab094b6880e2445a6e465
DIFF: https://github.com/llvm/llvm-project/commit/2a4a229d6dcceecbb8bab094b6880e2445a6e465.diff

LOG: [WebAssembly] Custom optimization for truncate

When possible, optimize TRUNCATE to generate Wasm SIMD narrow
instructions (i16x8.narrow_i32x4_u, i8x16.narrow_i16x8_u) rather than
generating lots of extract_lane and replace_lane instructions.

Closes #50350.

Added: 
    

Modified: 
    llvm/lib/Target/WebAssembly/WebAssemblyISD.def
    llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
    llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
    llvm/test/CodeGen/WebAssembly/fpclamptosat_vec.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
index 1fa0ea3867c7f..a3a33f4a5b3a3 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
@@ -31,6 +31,7 @@ HANDLE_NODETYPE(SWIZZLE)
 HANDLE_NODETYPE(VEC_SHL)
 HANDLE_NODETYPE(VEC_SHR_S)
 HANDLE_NODETYPE(VEC_SHR_U)
+HANDLE_NODETYPE(NARROW_U)
 HANDLE_NODETYPE(EXTEND_LOW_S)
 HANDLE_NODETYPE(EXTEND_LOW_U)
 HANDLE_NODETYPE(EXTEND_HIGH_S)

diff  --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 0c3ee545f8c55..38ed4c73fb935 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -176,6 +176,8 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
     setTargetDAGCombine(ISD::FP_ROUND);
     setTargetDAGCombine(ISD::CONCAT_VECTORS);
 
+    setTargetDAGCombine(ISD::TRUNCATE);
+
     // Support saturating add for i8x16 and i16x8
     for (auto Op : {ISD::SADDSAT, ISD::UADDSAT})
       for (auto T : {MVT::v16i8, MVT::v8i16})
@@ -2609,6 +2611,114 @@ performVectorTruncZeroCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   return DAG.getNode(Op, SDLoc(N), ResVT, Source);
 }
 
+// Returns the VectorWidth-bit slice of Vec that contains element IdxVal.
+// IdxVal is rounded down to the first element of its slice. When Vec is a
+// BUILD_VECTOR, a narrower BUILD_VECTOR is emitted from the matching
+// operands; otherwise the slice is taken with an EXTRACT_SUBVECTOR node.
+static SDValue extractSubVector(SDValue Vec, unsigned IdxVal, SelectionDAG &DAG,
+                                const SDLoc &DL, unsigned VectorWidth) {
+  const EVT SrcVT = Vec.getValueType();
+  const EVT EltVT = SrcVT.getVectorElementType();
+
+  // How many VectorWidth-bit slices the source splits into, and the vector
+  // type of a single slice.
+  const unsigned NumSlices = SrcVT.getSizeInBits() / VectorWidth;
+  const unsigned SliceElts = SrcVT.getVectorNumElements() / NumSlices;
+  const EVT SliceVT = EVT::getVectorVT(*DAG.getContext(), EltVT, SliceElts);
+
+  const unsigned EltsPerSlice = VectorWidth / EltVT.getSizeInBits();
+  assert(isPowerOf2_32(EltsPerSlice) && "Elements per chunk not power of 2");
+
+  // Slice length is a power of two, so masking off the low bits rounds the
+  // requested index down to the start of its slice.
+  const unsigned FirstElt = IdxVal & ~(EltsPerSlice - 1);
+
+  if (Vec.getOpcode() != ISD::BUILD_VECTOR) {
+    SDValue StartIdx = DAG.getIntPtrConstant(FirstElt, DL);
+    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SliceVT, Vec, StartIdx);
+  }
+
+  // BUILD_VECTOR input: rebuild a smaller one from the selected operands.
+  return DAG.getBuildVector(SliceVT, DL,
+                            Vec->ops().slice(FirstElt, EltsPerSlice));
+}
+
+// Helper to recursively truncate vector elements in half with NARROW_U. DstVT
+// is the expected destination value type after recursion. In is the initial
+// input. Note that the input should have enough leading zero bits to prevent
+// NARROW_U from saturating results.
+//
+// Returns In unchanged when its type already equals DstVT, and an empty
+// SDValue() when the source element count is not a power of two. Otherwise
+// it returns a node of type DstVT built from NARROW_U, bitcast and
+// CONCAT_VECTORS nodes.
+static SDValue truncateVectorWithNARROW(EVT DstVT, SDValue In, const SDLoc &DL,
+                                        SelectionDAG &DAG) {
+  EVT SrcVT = In.getValueType();
+
+  // No truncation required, we might get here due to recursive calls.
+  if (SrcVT == DstVT)
+    return In;
+
+  unsigned SrcSizeInBits = SrcVT.getSizeInBits();
+  unsigned NumElems = SrcVT.getVectorNumElements();
+  if (!isPowerOf2_32(NumElems))
+    return SDValue();
+  assert(DstVT.getVectorNumElements() == NumElems && "Illegal truncation");
+  assert(SrcSizeInBits > DstVT.getSizeInBits() && "Illegal truncation");
+
+  LLVMContext &Ctx = *DAG.getContext();
+  // Scalar type with half the source lane width; used for the intermediate
+  // results in the recursive path below.
+  EVT PackedSVT = EVT::getIntegerVT(Ctx, SrcVT.getScalarSizeInBits() / 2);
+
+  // Narrow to the largest type possible:
+  // vXi64/vXi32 -> i16x8.narrow_i32x4_u and vXi16 -> i8x16.narrow_i16x8_u.
+  // InVT/OutVT start as scalar lane types and are widened just below into
+  // the SubSizeInBits-wide vector types that NARROW_U consumes/produces.
+  EVT InVT = MVT::i16, OutVT = MVT::i8;
+  if (SrcVT.getScalarSizeInBits() > 16) {
+    InVT = MVT::i32;
+    OutVT = MVT::i16;
+  }
+  unsigned SubSizeInBits = SrcSizeInBits / 2;
+  InVT = EVT::getVectorVT(Ctx, InVT, SubSizeInBits / InVT.getSizeInBits());
+  OutVT = EVT::getVectorVT(Ctx, OutVT, SubSizeInBits / OutVT.getSizeInBits());
+
+  // Split lower/upper subvectors.
+  SDValue Lo = extractSubVector(In, 0, DAG, DL, SubSizeInBits);
+  SDValue Hi = extractSubVector(In, NumElems / 2, DAG, DL, SubSizeInBits);
+
+  // 256bit -> 128bit truncate - Narrow lower/upper 128-bit subvectors.
+  // Lo/Hi are reinterpreted as i32 (or i16) lanes so that a single NARROW_U
+  // can serve wider source lanes too (e.g. i64 sources go through
+  // i16x8.narrow_i32x4_u).
+  // NOTE(review): for i64 lanes this relies on the upper bits having been
+  // zeroed by the caller and on little-endian lane order, so the final
+  // bitcast to DstVT reassembles each truncated value. Confirm this if the
+  // function is reused elsewhere.
+  if (SrcVT.is256BitVector() && DstVT.is128BitVector()) {
+    Lo = DAG.getBitcast(InVT, Lo);
+    Hi = DAG.getBitcast(InVT, Hi);
+    SDValue Res = DAG.getNode(WebAssemblyISD::NARROW_U, DL, OutVT, Lo, Hi);
+    return DAG.getBitcast(DstVT, Res);
+  }
+
+  // Recursively narrow lower/upper subvectors, concat result and narrow again.
+  // Each half is truncated to half-width lanes (PackedSVT). The two halves are
+  // concatenated back to NumElems lanes, and the recursion continues until
+  // the result matches DstVT or hits the 256->128 case above.
+  // NOTE(review): Lo/Hi are not checked for an empty SDValue. NumElems / 2
+  // stays a power of two for NumElems >= 2, but NumElems == 1 would produce
+  // a zero-element type. The only caller visible here (performTruncateCombine)
+  // restricts DstVT to 128-bit vectors.
+  EVT PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumElems / 2);
+  Lo = truncateVectorWithNARROW(PackedVT, Lo, DL, DAG);
+  Hi = truncateVectorWithNARROW(PackedVT, Hi, DL, DAG);
+
+  PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumElems);
+  SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, PackedVT, Lo, Hi);
+  return truncateVectorWithNARROW(DstVT, Res, DL, DAG);
+}
+
+// DAG combine for ISD::TRUNCATE. A vector truncate whose result is a 128-bit
+// vector of i8 or i16 lanes (from i16, i32 or i64 lanes) is rewritten as a
+// tree of NARROW_U nodes. The source lanes are masked down to the destination
+// lane width first, so the unsigned-saturating narrow acts as a plain
+// truncation. Returns an empty SDValue() when the node is not handled.
+static SDValue performTruncateCombine(SDNode *N,
+                                      TargetLowering::DAGCombinerInfo &DCI) {
+  SDValue Src = N->getOperand(0);
+  const EVT SrcVT = Src.getValueType();
+  const EVT DstVT = N->getValueType(0);
+
+  if (!SrcVT.isSimple() || !DstVT.isVector())
+    return SDValue();
+
+  // Only truncates producing v16i8 or v8i16 are handled for now.
+  const EVT SrcEltVT = SrcVT.getVectorElementType();
+  const EVT DstEltVT = DstVT.getVectorElementType();
+  const bool SrcEltOK = SrcEltVT == MVT::i16 || SrcEltVT == MVT::i32 ||
+                        SrcEltVT == MVT::i64;
+  const bool DstEltOK = DstEltVT == MVT::i8 || DstEltVT == MVT::i16;
+  if (!SrcEltOK || !DstEltOK || !DstVT.is128BitVector())
+    return SDValue();
+
+  SelectionDAG &DAG = DCI.DAG;
+  SDLoc DL(N);
+  // Keep only the low DstElt-width bits of every source lane, so NARROW_U
+  // never saturates.
+  const APInt LowBits = APInt::getLowBitsSet(SrcVT.getScalarSizeInBits(),
+                                             DstVT.getScalarSizeInBits());
+  SDValue Masked = DAG.getNode(ISD::AND, DL, SrcVT, Src,
+                               DAG.getConstant(LowBits, DL, SrcVT));
+  return truncateVectorWithNARROW(DstVT, Masked, DL, DAG);
+}
+
 SDValue
 WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
                                              DAGCombinerInfo &DCI) const {
@@ -2625,5 +2735,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::FP_ROUND:
   case ISD::CONCAT_VECTORS:
     return performVectorTruncZeroCombine(N, DCI);
+  case ISD::TRUNCATE:
+    return performTruncateCombine(N, DCI);
   }
 }

diff  --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index 30b99c3a69a9a..5bb12c7fbdc71 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -1278,6 +1278,14 @@ multiclass SIMDNarrow<Vec vec, bits<32> baseInst> {
 defm "" : SIMDNarrow<I16x8, 101>;
 defm "" : SIMDNarrow<I32x4, 133>;
 
+// WebAssemblyISD::NARROW_U: target node emitted by the TRUNCATE DAG combine.
+// Type profile: one result and two operands with no type constraints; the
+// patterns below pin down the supported type combinations.
+def wasm_narrow_t : SDTypeProfile<1, 2, []>;
+def wasm_narrow_u : SDNode<"WebAssemblyISD::NARROW_U", wasm_narrow_t>;
+// Two v8i16 operands producing v16i8 select NARROW_U_I8x16.
+def : Pat<(v16i8 (wasm_narrow_u (v8i16 V128:$left), (v8i16 V128:$right))),
+          (NARROW_U_I8x16 $left, $right)>;
+// Two v4i32 operands producing v8i16 select NARROW_U_I16x8.
+def : Pat<(v8i16 (wasm_narrow_u (v4i32 V128:$left), (v4i32 V128:$right))),
+          (NARROW_U_I16x8 $left, $right)>;
+
 // Bitcasts are nops
 // Matching bitcast t1 to t1 causes strange errors, so avoid repeating types
 foreach t1 = AllVecs in

diff  --git a/llvm/test/CodeGen/WebAssembly/fpclamptosat_vec.ll b/llvm/test/CodeGen/WebAssembly/fpclamptosat_vec.ll
index c1fd8ef01e38c..a595ffe51e2ed 100644
--- a/llvm/test/CodeGen/WebAssembly/fpclamptosat_vec.ll
+++ b/llvm/test/CodeGen/WebAssembly/fpclamptosat_vec.ll
@@ -532,7 +532,7 @@ entry:
 define <8 x i16> @stest_f16i16(<8 x half> %x) {
 ; CHECK-LABEL: stest_f16i16:
 ; CHECK:         .functype stest_f16i16 (f32, f32, f32, f32, f32, f32, f32, f32) -> (v128)
-; CHECK-NEXT:    .local v128, v128
+; CHECK-NEXT:    .local v128, v128, v128
 ; CHECK-NEXT:  # %bb.0: # %entry
 ; CHECK-NEXT:    local.get 5
 ; CHECK-NEXT:    call __truncsfhf2
@@ -578,6 +578,9 @@ define <8 x i16> @stest_f16i16(<8 x half> %x) {
 ; CHECK-NEXT:    v128.const -32768, -32768, -32768, -32768
 ; CHECK-NEXT:    local.tee 9
 ; CHECK-NEXT:    i32x4.max_s
+; CHECK-NEXT:    v128.const 65535, 65535, 65535, 65535
+; CHECK-NEXT:    local.tee 10
+; CHECK-NEXT:    v128.and
 ; CHECK-NEXT:    local.get 4
 ; CHECK-NEXT:    i32.trunc_sat_f32_s
 ; CHECK-NEXT:    i32x4.splat
@@ -594,7 +597,9 @@ define <8 x i16> @stest_f16i16(<8 x half> %x) {
 ; CHECK-NEXT:    i32x4.min_s
 ; CHECK-NEXT:    local.get 9
 ; CHECK-NEXT:    i32x4.max_s
-; CHECK-NEXT:    i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
+; CHECK-NEXT:    local.get 10
+; CHECK-NEXT:    v128.and
+; CHECK-NEXT:    i16x8.narrow_i32x4_u
 ; CHECK-NEXT:    # fallthrough-return
 entry:
   %conv = fptosi <8 x half> %x to <8 x i32>
@@ -666,7 +671,7 @@ define <8 x i16> @utesth_f16i16(<8 x half> %x) {
 ; CHECK-NEXT:    i32x4.replace_lane 3
 ; CHECK-NEXT:    local.get 8
 ; CHECK-NEXT:    i32x4.min_u
-; CHECK-NEXT:    i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
+; CHECK-NEXT:    i16x8.narrow_i32x4_u
 ; CHECK-NEXT:    # fallthrough-return
 entry:
   %conv = fptoui <8 x half> %x to <8 x i32>
@@ -741,7 +746,7 @@ define <8 x i16> @ustest_f16i16(<8 x half> %x) {
 ; CHECK-NEXT:    i32x4.min_s
 ; CHECK-NEXT:    local.get 9
 ; CHECK-NEXT:    i32x4.max_s
-; CHECK-NEXT:    i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
+; CHECK-NEXT:    i16x8.narrow_i32x4_u
 ; CHECK-NEXT:    # fallthrough-return
 entry:
   %conv = fptosi <8 x half> %x to <8 x i32>
@@ -2106,7 +2111,7 @@ entry:
 define <8 x i16> @stest_f16i16_mm(<8 x half> %x) {
 ; CHECK-LABEL: stest_f16i16_mm:
 ; CHECK:         .functype stest_f16i16_mm (f32, f32, f32, f32, f32, f32, f32, f32) -> (v128)
-; CHECK-NEXT:    .local v128, v128
+; CHECK-NEXT:    .local v128, v128, v128
 ; CHECK-NEXT:  # %bb.0: # %entry
 ; CHECK-NEXT:    local.get 5
 ; CHECK-NEXT:    call __truncsfhf2
@@ -2152,6 +2157,9 @@ define <8 x i16> @stest_f16i16_mm(<8 x half> %x) {
 ; CHECK-NEXT:    v128.const -32768, -32768, -32768, -32768
 ; CHECK-NEXT:    local.tee 9
 ; CHECK-NEXT:    i32x4.max_s
+; CHECK-NEXT:    v128.const 65535, 65535, 65535, 65535
+; CHECK-NEXT:    local.tee 10
+; CHECK-NEXT:    v128.and
 ; CHECK-NEXT:    local.get 4
 ; CHECK-NEXT:    i32.trunc_sat_f32_s
 ; CHECK-NEXT:    i32x4.splat
@@ -2168,7 +2176,9 @@ define <8 x i16> @stest_f16i16_mm(<8 x half> %x) {
 ; CHECK-NEXT:    i32x4.min_s
 ; CHECK-NEXT:    local.get 9
 ; CHECK-NEXT:    i32x4.max_s
-; CHECK-NEXT:    i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
+; CHECK-NEXT:    local.get 10
+; CHECK-NEXT:    v128.and
+; CHECK-NEXT:    i16x8.narrow_i32x4_u
 ; CHECK-NEXT:    # fallthrough-return
 entry:
   %conv = fptosi <8 x half> %x to <8 x i32>
@@ -2238,7 +2248,7 @@ define <8 x i16> @utesth_f16i16_mm(<8 x half> %x) {
 ; CHECK-NEXT:    i32x4.replace_lane 3
 ; CHECK-NEXT:    local.get 8
 ; CHECK-NEXT:    i32x4.min_u
-; CHECK-NEXT:    i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
+; CHECK-NEXT:    i16x8.narrow_i32x4_u
 ; CHECK-NEXT:    # fallthrough-return
 entry:
   %conv = fptoui <8 x half> %x to <8 x i32>
@@ -2312,7 +2322,7 @@ define <8 x i16> @ustest_f16i16_mm(<8 x half> %x) {
 ; CHECK-NEXT:    i32x4.min_s
 ; CHECK-NEXT:    local.get 9
 ; CHECK-NEXT:    i32x4.max_s
-; CHECK-NEXT:    i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
+; CHECK-NEXT:    i16x8.narrow_i32x4_u
 ; CHECK-NEXT:    # fallthrough-return
 entry:
   %conv = fptosi <8 x half> %x to <8 x i32>


        


More information about the llvm-commits mailing list