[llvm] [GISEL] Add G_SPLAT_VECTOR_PARTS to represent 64-bit splat vectors on… (PR #86970)
Michael Maitland via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 28 09:53:06 PDT 2024
https://github.com/michaelmaitland updated https://github.com/llvm/llvm-project/pull/86970
From 7ac03f101567daa91bb6de8bd5392081ee0f968a Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Thu, 28 Mar 2024 09:20:10 -0700
Subject: [PATCH] [GISEL] Add G_SPLAT_VECTOR_PARTS to represent 64-bit splat
vectors on i32 targets
We'd like to be able to represent the construction of a splat vector
with 64-bit elements when the target only has 32-bit integers. This
opcode allows us to represent that. It is the GlobalISel equivalent of
ISD::SPLAT_VECTOR_PARTS. The ISD node takes a list of scalars, whereas
this opcode takes exactly two: a lo part and a hi part.
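To make the lo/hi convention concrete, here is a plain C++ model of what the opcode computes. It is a sketch under stated assumptions, not LLVM code: `Parts`, `splitToParts`, `joinParts`, and `splatVectorParts` are hypothetical helpers, the parts are assumed to be 32 bits each, and a fixed element count stands in for the runtime length of a scalable vector.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of G_SPLAT_VECTOR_PARTS semantics (not an LLVM API).
// A 64-bit element value travels as two 32-bit scalar operands.
struct Parts {
  std::uint32_t Lo; // least significant 32 bits of the element
  std::uint32_t Hi; // most significant 32 bits of the element
};

// Split a 64-bit value into the lo/hi operands the opcode would receive.
Parts splitToParts(std::uint64_t Value) {
  return {static_cast<std::uint32_t>(Value),
          static_cast<std::uint32_t>(Value >> 32)};
}

// Rebuild one element: Hi supplies the upper bits, Lo the lower bits.
std::uint64_t joinParts(Parts P) {
  return (static_cast<std::uint64_t>(P.Hi) << 32) | P.Lo;
}

// Every element of the result is the joined value. A fixed NumElts stands
// in for the runtime element count of a scalable vector.
std::vector<std::uint64_t> splatVectorParts(Parts P, std::size_t NumElts) {
  assert(NumElts > 0 && "a splat needs at least one element");
  return std::vector<std::uint64_t>(NumElts, joinParts(P));
}
```

For example, `splitToParts(0x0123456789ABCDEFULL)` yields `Lo == 0x89ABCDEF` and `Hi == 0x01234567`, and every element of `splatVectorParts(P, 4)` is `0x0123456789ABCDEF` again.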
---
llvm/docs/GlobalISel/GenericOpcode.rst | 9 ++++
.../CodeGen/GlobalISel/MachineIRBuilder.h | 14 +++++
llvm/include/llvm/Support/TargetOpcodes.def | 3 ++
llvm/include/llvm/Target/GenericOpcodes.td | 10 ++++
.../CodeGen/GlobalISel/MachineIRBuilder.cpp | 11 ++++
llvm/lib/CodeGen/MachineVerifier.cpp | 22 ++++++++
.../test_g_splat_vector_parts.mir | 53 +++++++++++++++++++
7 files changed, 122 insertions(+)
create mode 100644 llvm/test/MachineVerifier/test_g_splat_vector_parts.mir
diff --git a/llvm/docs/GlobalISel/GenericOpcode.rst b/llvm/docs/GlobalISel/GenericOpcode.rst
index cae2c21b80d7e7..b0b9bce3fc90b0 100644
--- a/llvm/docs/GlobalISel/GenericOpcode.rst
+++ b/llvm/docs/GlobalISel/GenericOpcode.rst
@@ -690,6 +690,15 @@ G_SPLAT_VECTOR
Create a vector where all elements are the scalar from the source operand.
+G_SPLAT_VECTOR_PARTS
+^^^^^^^^^^^^^^^^^^^^
+
+Create a scalable vector where every element is the scalar formed by joining
+the operands together. This allows representing a 64-bit splat on a target with
+32-bit integers. The total width of the scalars must equal the element width
+exactly. The lo operand contains the least significant bits and the hi operand
+contains the most significant bits.
+
Vector Reduction Operations
---------------------------
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h b/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
index 16a7fc446fbe1d..8000e610447ce9 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
@@ -1107,6 +1107,20 @@ class MachineIRBuilder {
/// \return a MachineInstrBuilder for the newly created instruction.
MachineInstrBuilder buildSplatVector(const DstOp &Res, const SrcOp &Val);
+ /// Build and insert \p Res = G_SPLAT_VECTOR_PARTS \p Lo Hi.
+ ///
+ /// \p Lo contains the least significant bits of the value. \p Hi contains the
+ /// most significant bits of the value.
+ ///
+ /// \pre setBasicBlock or setMI must have been called.
+ /// \pre \p Res must be a generic virtual register with scalable vector type.
+ /// \pre \p Lo must be a generic virtual register with scalar type.
+ /// \pre \p Hi must be a generic virtual register with scalar type.
+ ///
+ /// \return a MachineInstrBuilder for the newly created instruction.
+ MachineInstrBuilder buildSplatVectorParts(const DstOp &Res, const SrcOp &Lo,
+ const SrcOp &Hi);
+
/// Build and insert \p Res = G_CONCAT_VECTORS \p Op0, ...
///
/// G_CONCAT_VECTORS creates a vector from the concatenation of 2 or more
diff --git a/llvm/include/llvm/Support/TargetOpcodes.def b/llvm/include/llvm/Support/TargetOpcodes.def
index 5765926d6d93d3..b01622be31ef09 100644
--- a/llvm/include/llvm/Support/TargetOpcodes.def
+++ b/llvm/include/llvm/Support/TargetOpcodes.def
@@ -748,6 +748,9 @@ HANDLE_TARGET_OPCODE(G_SHUFFLE_VECTOR)
/// Generic splatvector.
HANDLE_TARGET_OPCODE(G_SPLAT_VECTOR)
+/// Generic splatvector parts.
+HANDLE_TARGET_OPCODE(G_SPLAT_VECTOR_PARTS)
+
/// Generic count trailing zeroes.
HANDLE_TARGET_OPCODE(G_CTTZ)
diff --git a/llvm/include/llvm/Target/GenericOpcodes.td b/llvm/include/llvm/Target/GenericOpcodes.td
index d0f471eb29b6fd..6d6eade99ca314 100644
--- a/llvm/include/llvm/Target/GenericOpcodes.td
+++ b/llvm/include/llvm/Target/GenericOpcodes.td
@@ -1480,6 +1480,16 @@ def G_SPLAT_VECTOR: GenericInstruction {
let hasSideEffects = false;
}
+// Generic splatvector parts. This allows representing a 64-bit splat on a
+// target with 32-bit integers. The total width of the scalars must equal the
+// element width exactly. The lo operand contains the least significant bits and
+// the hi operand contains the most significant bits.
+def G_SPLAT_VECTOR_PARTS : GenericInstruction {
+ let OutOperandList = (outs type0:$dst);
+ let InOperandList = (ins type1:$lo, type1:$hi);
+ let hasSideEffects = false;
+}
+
//------------------------------------------------------------------------------
// Vector reductions
//------------------------------------------------------------------------------
diff --git a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
index 07d4cb5eaa23c8..f025e631aa0b2f 100644
--- a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
@@ -749,6 +749,17 @@ MachineInstrBuilder MachineIRBuilder::buildSplatVector(const DstOp &Res,
return buildInstr(TargetOpcode::G_SPLAT_VECTOR, Res, Src);
}
+MachineInstrBuilder MachineIRBuilder::buildSplatVectorParts(const DstOp &Res,
+ const SrcOp &Lo,
+ const SrcOp &Hi) {
+ TypeSize LoSize = Lo.getLLTTy(*getMRI()).getSizeInBits();
+ TypeSize HiSize = Hi.getLLTTy(*getMRI()).getSizeInBits();
+ TypeSize EltSize = Res.getLLTTy(*getMRI()).getElementType().getSizeInBits();
+ assert(LoSize + HiSize == EltSize &&
+ "Expected scalar sizes to cover Dst element size");
+ return buildInstr(TargetOpcode::G_SPLAT_VECTOR_PARTS, {Res}, {Lo, Hi});
+}
+
MachineInstrBuilder MachineIRBuilder::buildShuffleVector(const DstOp &Res,
const SrcOp &Src1,
const SrcOp &Src2,
diff --git a/llvm/lib/CodeGen/MachineVerifier.cpp b/llvm/lib/CodeGen/MachineVerifier.cpp
index e4e05ce9278caf..bb0b7199aa85e6 100644
--- a/llvm/lib/CodeGen/MachineVerifier.cpp
+++ b/llvm/lib/CodeGen/MachineVerifier.cpp
@@ -1781,6 +1781,28 @@ void MachineVerifier::verifyPreISelGenericInstruction(const MachineInstr *MI) {
break;
}
+ case TargetOpcode::G_SPLAT_VECTOR_PARTS: {
+ LLT DstTy = MRI->getType(MI->getOperand(0).getReg());
+ LLT LoTy = MRI->getType(MI->getOperand(1).getReg());
+ LLT HiTy = MRI->getType(MI->getOperand(2).getReg());
+
+ if (!DstTy.isScalableVector()) {
+ report("Destination type must be a scalable vector", MI);
+ break;
+ }
+
+ if (!LoTy.isScalar() || !HiTy.isScalar()) {
+ report("Source types must be scalar", MI);
+ break;
+ }
+
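+ // Lo and Hi together must exactly cover one element of the destination.
+ TypeSize EltSize = DstTy.getElementType().getSizeInBits();
+ if (LoTy.getSizeInBits() + HiTy.getSizeInBits() != EltSize)
+ report("Source types must cover the element type", MI);
+
+ break;
+ }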
case TargetOpcode::G_DYN_STACKALLOC: {
const MachineOperand &DstOp = MI->getOperand(0);
const MachineOperand &AllocOp = MI->getOperand(1);
diff --git a/llvm/test/MachineVerifier/test_g_splat_vector_parts.mir b/llvm/test/MachineVerifier/test_g_splat_vector_parts.mir
new file mode 100644
index 00000000000000..4205acebc67879
--- /dev/null
+++ b/llvm/test/MachineVerifier/test_g_splat_vector_parts.mir
@@ -0,0 +1,53 @@
+# RUN: not --crash llc -o - -mtriple=arm64 -run-pass=none -verify-machineinstrs %s 2>&1 | FileCheck %s
+# REQUIRES: aarch64-registered-target
+---
+name: g_splat_vector_parts
+tracksRegLiveness: true
+liveins:
+body: |
+ bb.0:
+ %0:_(s32) = G_CONSTANT i32 0
+ %1:_(<2 x s32>) = G_IMPLICIT_DEF
+ %2:_(<vscale x 2 x s32>) = G_IMPLICIT_DEF
+
+ ; CHECK: Destination type must be a scalable vector
+ %3:_(s32) = G_SPLAT_VECTOR_PARTS %0, %0
+
+ ; CHECK: Destination type must be a scalable vector
+ %4:_(<2 x s32>) = G_SPLAT_VECTOR_PARTS %0, %0
+
+ ; CHECK: Source types must be scalar
+ %5:_(<vscale x 2 x s32>) = G_SPLAT_VECTOR_PARTS %1, %0
+
+ ; CHECK: Source types must be scalar
+ %6:_(<vscale x 2 x s32>) = G_SPLAT_VECTOR_PARTS %0, %1
+
+ ; CHECK: Source types must be scalar
+ %7:_(<vscale x 2 x s32>) = G_SPLAT_VECTOR_PARTS %1, %1
+
+ ; CHECK: Source types must be scalar
+ %8:_(<vscale x 2 x s32>) = G_SPLAT_VECTOR_PARTS %0, %2
+
+ ; CHECK: Source types must be scalar
+ %9:_(<vscale x 2 x s32>) = G_SPLAT_VECTOR_PARTS %2, %0
+
+ ; CHECK: Source types must be scalar
+ %10:_(<vscale x 2 x s32>) = G_SPLAT_VECTOR_PARTS %2, %2
+
+ %11:_(s16) = G_CONSTANT i16 0
+
+ ; CHECK: Source types must cover the element type
+ %12:_(<vscale x 2 x s64>) = G_SPLAT_VECTOR_PARTS %11, %11
+
+ ; CHECK: Source types must cover the element type
+ %13:_(<vscale x 2 x s64>) = G_SPLAT_VECTOR_PARTS %11, %0
+
+ %14:_(s64) = G_CONSTANT i64 0
+
+ ; CHECK: Source types must cover the element type
+ %15:_(<vscale x 2 x s64>) = G_SPLAT_VECTOR_PARTS %14, %14
+
+ ; CHECK: Source types must cover the element type
+ %16:_(<vscale x 2 x s64>) = G_SPLAT_VECTOR_PARTS %14, %0
+
+...
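The verifier rules and the MIR cases in this patch can be cross-checked with a small standalone model. This is a sketch, not the real `MachineVerifier`: `SimpleTy`, `scalar`, `fixedVec`, `scalableVec`, and `verifySplatVectorParts` are hypothetical stand-ins (not `llvm::LLT`), pointer types are ignored, and the size rule is checked against the destination's element width, as the documentation and the `buildSplatVectorParts` assertion describe.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical, simplified stand-in for llvm::LLT, covering only what the
// G_SPLAT_VECTOR_PARTS checks need. Pointer types are not modeled.
struct SimpleTy {
  bool IsVector;
  bool IsScalable;
  std::uint64_t EltBits; // scalar width, or element width for a vector
  std::uint64_t MinElts; // known-minimum element count (1 for scalars)
};

SimpleTy scalar(std::uint64_t Bits) { return {false, false, Bits, 1}; }
SimpleTy fixedVec(std::uint64_t N, std::uint64_t Bits) {
  return {true, false, Bits, N};
}
SimpleTy scalableVec(std::uint64_t N, std::uint64_t Bits) {
  return {true, true, Bits, N};
}

// Returns the first diagnostic, in the same order as the verifier checks,
// or std::nullopt when the operand types are acceptable.
std::optional<std::string> verifySplatVectorParts(const SimpleTy &Dst,
                                                  const SimpleTy &Lo,
                                                  const SimpleTy &Hi) {
  if (!Dst.IsVector || !Dst.IsScalable)
    return "Destination type must be a scalable vector";
  if (Lo.IsVector || Hi.IsVector)
    return "Source types must be scalar";
  // The parts must cover exactly one element, not the whole vector.
  if (Lo.EltBits + Hi.EltBits != Dst.EltBits)
    return "Source types must cover the element type";
  return std::nullopt;
}
```

Comparing the summed part widths against the size of the whole destination instead would reject even the well-formed `s32`/`s32` to `<vscale x 2 x s64>` case: that type has a known-minimum size of 128 bits and is scalable, so it can never equal a fixed 64 bits.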