[llvm] [GISEL] Add G_SPLAT_VECTOR_PARTS to represent 64-bit splat vectors on… (PR #86970)

Michael Maitland via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 28 09:53:06 PDT 2024


https://github.com/michaelmaitland updated https://github.com/llvm/llvm-project/pull/86970

>From 7ac03f101567daa91bb6de8bd5392081ee0f968a Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Thu, 28 Mar 2024 09:20:10 -0700
Subject: [PATCH] [GISEL] Add G_SPLAT_VECTOR_PARTS to represent 64-bit splat
 vectors on i32 targets

We'd like to be able to represent the construction of a splat vector
when the target has 32-bit integers but supports vectors with 64-bit
elements. This opcode allows us to represent that. It is the equivalent of
ISD::SPLAT_VECTOR_PARTS. The ISD version takes a list of scalars, but
this opcode accepts exactly two scalars (lo and hi).
---
 llvm/docs/GlobalISel/GenericOpcode.rst        |  9 ++++
 .../CodeGen/GlobalISel/MachineIRBuilder.h     | 14 +++++
 llvm/include/llvm/Support/TargetOpcodes.def   |  3 ++
 llvm/include/llvm/Target/GenericOpcodes.td    | 10 ++++
 .../CodeGen/GlobalISel/MachineIRBuilder.cpp   | 11 ++++
 llvm/lib/CodeGen/MachineVerifier.cpp          | 22 ++++++++
 .../test_g_splat_vector_parts.mir             | 53 +++++++++++++++++++
 7 files changed, 122 insertions(+)
 create mode 100644 llvm/test/MachineVerifier/test_g_splat_vector_parts.mir

diff --git a/llvm/docs/GlobalISel/GenericOpcode.rst b/llvm/docs/GlobalISel/GenericOpcode.rst
index cae2c21b80d7e7..b0b9bce3fc90b0 100644
--- a/llvm/docs/GlobalISel/GenericOpcode.rst
+++ b/llvm/docs/GlobalISel/GenericOpcode.rst
@@ -690,6 +690,15 @@ G_SPLAT_VECTOR
 
 Create a vector where all elements are the scalar from the source operand.
 
+G_SPLAT_VECTOR_PARTS
+^^^^^^^^^^^^^^^^^^^^
+
+Create a vector where all elements are the scalar created by joining the
+operands together. This allows representing a 64-bit splat on a target with
+32-bit integers. The combined width of the two scalars must equal the element
+width. The lo operand contains the least significant bits and the hi operand
+contains the most significant bits.
+
 Vector Reduction Operations
 ---------------------------
 
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h b/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
index 16a7fc446fbe1d..8000e610447ce9 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
@@ -1107,6 +1107,20 @@ class MachineIRBuilder {
   /// \return a MachineInstrBuilder for the newly created instruction.
   MachineInstrBuilder buildSplatVector(const DstOp &Res, const SrcOp &Val);
 
+  /// Build and insert \p Res = G_SPLAT_VECTOR_PARTS \p Lo, \p Hi.
+  ///
+  /// \p Lo contains the least significant bits of the value. \p Hi contains the
+  /// most significant bits. Their widths must sum to the element width of Res.
+  ///
+  /// \pre setBasicBlock or setMI must have been called.
+  /// \pre \p Res must be a generic virtual register with scalable vector type.
+  /// \pre \p Lo must be a generic virtual register with scalar type.
+  /// \pre \p Hi must be a generic virtual register with scalar type.
+  ///
+  /// \return a MachineInstrBuilder for the newly created instruction.
+  MachineInstrBuilder buildSplatVectorParts(const DstOp &Res, const SrcOp &Lo,
+                                            const SrcOp &Hi);
+
   /// Build and insert \p Res = G_CONCAT_VECTORS \p Op0, ...
   ///
   /// G_CONCAT_VECTORS creates a vector from the concatenation of 2 or more
diff --git a/llvm/include/llvm/Support/TargetOpcodes.def b/llvm/include/llvm/Support/TargetOpcodes.def
index 5765926d6d93d3..b01622be31ef09 100644
--- a/llvm/include/llvm/Support/TargetOpcodes.def
+++ b/llvm/include/llvm/Support/TargetOpcodes.def
@@ -748,6 +748,9 @@ HANDLE_TARGET_OPCODE(G_SHUFFLE_VECTOR)
 /// Generic splatvector.
 HANDLE_TARGET_OPCODE(G_SPLAT_VECTOR)
 
+/// Generic splat vector whose element value is assembled from scalar parts.
+HANDLE_TARGET_OPCODE(G_SPLAT_VECTOR_PARTS)
+
 /// Generic count trailing zeroes.
 HANDLE_TARGET_OPCODE(G_CTTZ)
 
diff --git a/llvm/include/llvm/Target/GenericOpcodes.td b/llvm/include/llvm/Target/GenericOpcodes.td
index d0f471eb29b6fd..6d6eade99ca314 100644
--- a/llvm/include/llvm/Target/GenericOpcodes.td
+++ b/llvm/include/llvm/Target/GenericOpcodes.td
@@ -1480,6 +1480,16 @@ def G_SPLAT_VECTOR: GenericInstruction {
   let hasSideEffects = false;
 }
 
+// Generic splatvector parts. Splats the scalar formed by joining $lo and $hi
+// into every element of $dst, e.g. a 64-bit splat on a target with 32-bit
+// integers. $lo and $hi share type1; their combined width must equal the
+// element width. $lo holds the least significant bits, $hi the most.
+def G_SPLAT_VECTOR_PARTS : GenericInstruction {
+  let OutOperandList = (outs type0:$dst);
+  let InOperandList = (ins type1:$lo, type1:$hi);
+  let hasSideEffects = false;
+}
+
 //------------------------------------------------------------------------------
 // Vector reductions
 //------------------------------------------------------------------------------
diff --git a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
index 07d4cb5eaa23c8..f025e631aa0b2f 100644
--- a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
@@ -749,6 +749,17 @@ MachineInstrBuilder MachineIRBuilder::buildSplatVector(const DstOp &Res,
   return buildInstr(TargetOpcode::G_SPLAT_VECTOR, Res, Src);
 }
 
+MachineInstrBuilder MachineIRBuilder::buildSplatVectorParts(const DstOp &Res,
+                                                            const SrcOp &Lo,
+                                                            const SrcOp &Hi) {
+  // Sizes are computed inside the assert so NDEBUG builds have no unused vars.
+  assert(Lo.getLLTTy(*getMRI()).getSizeInBits() +
+                 Hi.getLLTTy(*getMRI()).getSizeInBits() ==
+             Res.getLLTTy(*getMRI()).getElementType().getSizeInBits() &&
+         "Expected scalar sizes to cover Dst element size");
+  return buildInstr(TargetOpcode::G_SPLAT_VECTOR_PARTS, {Res}, {Lo, Hi});
+}
+
 MachineInstrBuilder MachineIRBuilder::buildShuffleVector(const DstOp &Res,
                                                          const SrcOp &Src1,
                                                          const SrcOp &Src2,
diff --git a/llvm/lib/CodeGen/MachineVerifier.cpp b/llvm/lib/CodeGen/MachineVerifier.cpp
index e4e05ce9278caf..bb0b7199aa85e6 100644
--- a/llvm/lib/CodeGen/MachineVerifier.cpp
+++ b/llvm/lib/CodeGen/MachineVerifier.cpp
@@ -1781,6 +1781,28 @@ void MachineVerifier::verifyPreISelGenericInstruction(const MachineInstr *MI) {
 
     break;
   }
+  case TargetOpcode::G_SPLAT_VECTOR_PARTS: {
+    LLT DstTy = MRI->getType(MI->getOperand(0).getReg());
+    LLT LoTy = MRI->getType(MI->getOperand(1).getReg());
+    LLT HiTy = MRI->getType(MI->getOperand(2).getReg());
+
+    if (!DstTy.isScalableVector()) {
+      report("Destination type must be a scalable vector", MI);
+      break;
+    }
+
+    if (!LoTy.isScalar() || !HiTy.isScalar()) {
+      report("Source types must be scalar", MI);
+      break;
+    }
+
+    // Lo and Hi together form one element, so compare against the element
+    // width; DstTy's total size is scalable and can never match fixed parts.
+    TypeSize EltSize = DstTy.getElementType().getSizeInBits();
+    if (LoTy.getSizeInBits() + HiTy.getSizeInBits() != EltSize)
+      report("Source types must cover the element type", MI);
+    break;
+  }
   case TargetOpcode::G_DYN_STACKALLOC: {
     const MachineOperand &DstOp = MI->getOperand(0);
     const MachineOperand &AllocOp = MI->getOperand(1);
diff --git a/llvm/test/MachineVerifier/test_g_splat_vector_parts.mir b/llvm/test/MachineVerifier/test_g_splat_vector_parts.mir
new file mode 100644
index 00000000000000..4205acebc67879
--- /dev/null
+++ b/llvm/test/MachineVerifier/test_g_splat_vector_parts.mir
@@ -0,0 +1,53 @@
+# RUN: not --crash llc -o - -mtriple=arm64 -run-pass=none -verify-machineinstrs %s 2>&1 | FileCheck %s
+# REQUIRES: aarch64-registered-target
+---
+name:            g_splat_vector_parts
+tracksRegLiveness: true
+liveins:
+body:             |
+  bb.0:
+    %0:_(s32) = G_CONSTANT i32 0
+    %1:_(<2 x s32>) = G_IMPLICIT_DEF
+    %2:_(<vscale x 2 x s32>) = G_IMPLICIT_DEF
+    ; Every splat below is invalid; the line above each one checks its error.
+    ; CHECK: Destination type must be a scalable vector
+    %3:_(s32) = G_SPLAT_VECTOR_PARTS %0, %0
+
+    ; CHECK: Destination type must be a scalable vector
+    %4:_(<2 x s32>) = G_SPLAT_VECTOR_PARTS %0, %0
+
+    ; CHECK: Source types must be scalar
+    %5:_(<vscale x 2 x s32>) = G_SPLAT_VECTOR_PARTS %1, %0
+
+    ; CHECK: Source types must be scalar
+    %6:_(<vscale x 2 x s32>) = G_SPLAT_VECTOR_PARTS %0, %1
+
+    ; CHECK: Source types must be scalar
+    %7:_(<vscale x 2 x s32>) = G_SPLAT_VECTOR_PARTS %1, %1
+
+    ; CHECK: Source types must be scalar
+    %8:_(<vscale x 2 x s32>) = G_SPLAT_VECTOR_PARTS %0, %2
+
+    ; CHECK: Source types must be scalar
+    %9:_(<vscale x 2 x s32>) = G_SPLAT_VECTOR_PARTS %2, %0
+
+    ; CHECK: Source types must be scalar
+    %10:_(<vscale x 2 x s32>) = G_SPLAT_VECTOR_PARTS %2, %2
+    ; s16 + s16 and s16 + s32 fall short of the 64-bit element width.
+    %11:_(s16) = G_CONSTANT i16 0
+
+    ; CHECK: Source types must cover the element type
+    %12:_(<vscale x 2 x s64>) = G_SPLAT_VECTOR_PARTS %11, %11
+
+    ; CHECK: Source types must cover the element type
+    %13:_(<vscale x 2 x s64>) = G_SPLAT_VECTOR_PARTS %11, %0
+    ; s64 + s64 and s64 + s32 exceed the 64-bit element width.
+    %14:_(s64) = G_CONSTANT i64 0
+
+    ; CHECK: Source types must cover the element type
+    %15:_(<vscale x 2 x s64>) = G_SPLAT_VECTOR_PARTS %14, %14
+
+    ; CHECK: Source types must cover the element type
+    %16:_(<vscale x 2 x s64>) = G_SPLAT_VECTOR_PARTS %14, %0
+
+...



More information about the llvm-commits mailing list