[llvm] matrix mul and transpose (PR #172050)

Steven Perron via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 12 09:28:52 PST 2025


https://github.com/s-perron created https://github.com/llvm/llvm-project/pull/172050

- **[SPIR-V] Legalize vector arithmetic and intrinsics for large vectors**
- **Add s64 to all scalars.**
- **Fix alignment in stores generated in legalize pointer cast.**
- **Set all defs to the default type in post legalization.**
- **Update tests, and add tests for 6-element vectors**
- **Handle G_STRICT_FMA**
- **Scalarize rem and div with an illegal vector size that is not a power of 2.**
- **[SPIRV] Implement lowering for llvm.matrix.transpose and llvm.matrix.multiply**


>From d8aa9c91fb7bc6df9617b3251387206c536632c3 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Thu, 30 Oct 2025 13:01:44 -0400
Subject: [PATCH 1/8] [SPIR-V] Legalize vector arithmetic and intrinsics for
 large vectors
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch improves the legalization of vector operations, particularly
focusing on vectors that exceed the maximum supported size (e.g., 4 elements
for shaders). This includes better handling for insert and extract element
operations, which facilitates the legalization of loads and stores for
long vectors—a common pattern when compiling HLSL matrices with Clang.

Key changes include:
- Adding legalization rules for G_FMA, G_INSERT_VECTOR_ELT, and various
  arithmetic operations to handle splitting of large vectors.
- Updating G_CONCAT_VECTORS and G_SPLAT_VECTOR to be legal for allowed
  types.
- Implementing custom legalization for G_INSERT_VECTOR_ELT using the
  spv_insertelt intrinsic.
- Enhancing SPIRVPostLegalizer to deduce types for arithmetic instructions
  and vector element intrinsics (spv_insertelt, spv_extractelt).
- Refactoring legalizeIntrinsic to uniformly handle vector legalization
  requirements.

The strategy for insert and extract operations mirrors that of bitcasts:
incoming intrinsics are converted to generic MIR instructions (G_INSERT_VECTOR_ELT
and G_EXTRACT_VECTOR_ELT) to leverage standard legalization rules (like splitting).
After legalization, they are converted back to their respective SPIR-V intrinsics
(spv_insertelt, spv_extractelt) because later passes in the backend expect these
intrinsics rather than the generic instructions.

This ensures that operations on large vectors (e.g., <16 x float>) are
correctly broken down into legal sub-vectors.
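
As a rough illustration of the size arithmetic behind these rules (a
standalone sketch, not part of the patch; describeLegalization is an
invented name), a non-power-of-2 element count is first widened to the
next power of 2 (moreElementsToNextPow2), and anything above the
MaxVectorSize limit (4 for shaders, 16 otherwise) is then split into
MaxVectorSize-element pieces:

    #include "llvm/Support/MathExtras.h"
    #include <cstdio>

    // Sketch only: mirrors moreElementsToNextPow2 followed by
    // fewerElementsIf/changeElementCountTo.
    static void describeLegalization(unsigned NumElts, unsigned MaxVectorSize) {
      unsigned Widened = llvm::isPowerOf2_32(NumElts)
                             ? NumElts
                             : (unsigned)llvm::NextPowerOf2(NumElts);
      if (Widened <= MaxVectorSize)
        std::printf("<%u x T> -> one <%u x T> op\n", NumElts, Widened);
      else
        std::printf("<%u x T> -> %u x <%u x T> ops (via <%u x T>)\n", NumElts,
                    Widened / MaxVectorSize, MaxVectorSize, Widened);
    }

    int main() {
      describeLegalization(6, 4);  // <6 x T>  -> 2 x <4 x T> ops (via <8 x T>)
      describeLegalization(16, 4); // <16 x T> -> 4 x <4 x T> ops
      return 0;
    }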
---
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp  | 117 ++++++++---
 llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp  |  47 ++++-
 .../SPIRV/legalization/load-store-global.ll   | 194 ++++++++++++++++++
 .../SPIRV/legalization/vector-arithmetic.ll   | 149 ++++++++++++++
 4 files changed, 464 insertions(+), 43 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index b5912c27316c9..4d83649c0f84f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -113,6 +113,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
                            v3s1, v3s8, v3s16, v3s32, v3s64,
                            v4s1, v4s8, v4s16, v4s32, v4s64};
 
+  auto allScalars = {s1, s8, s16, s32};
+
   auto allScalarsAndVectors = {
       s1,   s8,   s16,   s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
       v3s1, v3s8, v3s16, v3s32, v3s64, v4s1,  v4s8,  v4s16,  v4s32,  v4s64,
@@ -172,9 +174,25 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
 
   for (auto Opc : getTypeFoldingSupportedOpcodes()) {
     if (Opc != G_EXTRACT_VECTOR_ELT)
-      getActionDefinitionsBuilder(Opc).custom();
+      getActionDefinitionsBuilder(Opc)
+          .customFor(allScalars)
+          .customFor(allowedVectorTypes)
+          .moreElementsToNextPow2(0)
+          .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+                           LegalizeMutations::changeElementCountTo(
+                               0, ElementCount::getFixed(MaxVectorSize)))
+          .custom();
   }
 
+  getActionDefinitionsBuilder(TargetOpcode::G_FMA)
+      .legalFor(allScalars)
+      .legalFor(allowedVectorTypes)
+      .moreElementsToNextPow2(0)
+      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+                       LegalizeMutations::changeElementCountTo(
+                           0, ElementCount::getFixed(MaxVectorSize)))
+      .alwaysLegal();
+
   getActionDefinitionsBuilder(G_INTRINSIC_W_SIDE_EFFECTS).custom();
 
   getActionDefinitionsBuilder(G_SHUFFLE_VECTOR)
@@ -192,6 +210,13 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
                            1, ElementCount::getFixed(MaxVectorSize)))
       .custom();
 
+  getActionDefinitionsBuilder(G_INSERT_VECTOR_ELT)
+      .moreElementsToNextPow2(0)
+      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+                       LegalizeMutations::changeElementCountTo(
+                           0, ElementCount::getFixed(MaxVectorSize)))
+      .custom();
+
   // Illegal G_UNMERGE_VALUES instructions should be handled
   // during the combine phase.
   getActionDefinitionsBuilder(G_BUILD_VECTOR)
@@ -215,14 +240,13 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
       .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
       .custom();
 
+  // If the result is still illegal, the combiner should be able to remove it.
   getActionDefinitionsBuilder(G_CONCAT_VECTORS)
-      .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
-      .moreElementsToNextPow2(0)
-      .lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
-      .alwaysLegal();
+      .legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes)
+      .moreElementsToNextPow2(0);
 
   getActionDefinitionsBuilder(G_SPLAT_VECTOR)
-      .legalIf(vectorElementCountIsLessThanOrEqualTo(0, MaxVectorSize))
+      .legalFor(allowedVectorTypes)
       .moreElementsToNextPow2(0)
       .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
                        LegalizeMutations::changeElementSizeTo(0, MaxVectorSize))
@@ -458,6 +482,23 @@ static bool legalizeExtractVectorElt(LegalizerHelper &Helper, MachineInstr &MI,
   return true;
 }
 
+static bool legalizeInsertVectorElt(LegalizerHelper &Helper, MachineInstr &MI,
+                                    SPIRVGlobalRegistry *GR) {
+  MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
+  Register DstReg = MI.getOperand(0).getReg();
+  Register SrcReg = MI.getOperand(1).getReg();
+  Register ValReg = MI.getOperand(2).getReg();
+  Register IdxReg = MI.getOperand(3).getReg();
+
+  MIRBuilder
+      .buildIntrinsic(Intrinsic::spv_insertelt, ArrayRef<Register>{DstReg})
+      .addUse(SrcReg)
+      .addUse(ValReg)
+      .addUse(IdxReg);
+  MI.eraseFromParent();
+  return true;
+}
+
 static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
                                 LegalizerHelper &Helper,
                                 MachineRegisterInfo &MRI,
@@ -483,6 +524,8 @@ bool SPIRVLegalizerInfo::legalizeCustom(
     return legalizeBitcast(Helper, MI);
   case TargetOpcode::G_EXTRACT_VECTOR_ELT:
     return legalizeExtractVectorElt(Helper, MI, GR);
+  case TargetOpcode::G_INSERT_VECTOR_ELT:
+    return legalizeInsertVectorElt(Helper, MI, GR);
   case TargetOpcode::G_INTRINSIC:
   case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
     return legalizeIntrinsic(Helper, MI);
@@ -512,6 +555,15 @@ bool SPIRVLegalizerInfo::legalizeCustom(
   }
 }
 
+static bool needsVectorLegalization(const LLT &Ty, const SPIRVSubtarget &ST) {
+  if (!Ty.isVector())
+    return false;
+  unsigned NumElements = Ty.getNumElements();
+  unsigned MaxVectorSize = ST.isShader() ? 4 : 16;
+  return (NumElements > 4 && !isPowerOf2_32(NumElements)) ||
+         NumElements > MaxVectorSize;
+}
+
 bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
                                            MachineInstr &MI) const {
   LLVM_DEBUG(dbgs() << "legalizeIntrinsic: " << MI);
@@ -528,41 +580,38 @@ bool SPIRVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
     LLT DstTy = MRI.getType(DstReg);
     LLT SrcTy = MRI.getType(SrcReg);
 
-    int32_t MaxVectorSize = ST.isShader() ? 4 : 16;
-
-    bool DstNeedsLegalization = false;
-    bool SrcNeedsLegalization = false;
-
-    if (DstTy.isVector()) {
-      if (DstTy.getNumElements() > 4 &&
-          !isPowerOf2_32(DstTy.getNumElements())) {
-        DstNeedsLegalization = true;
-      }
-
-      if (DstTy.getNumElements() > MaxVectorSize) {
-        DstNeedsLegalization = true;
-      }
-    }
-
-    if (SrcTy.isVector()) {
-      if (SrcTy.getNumElements() > 4 &&
-          !isPowerOf2_32(SrcTy.getNumElements())) {
-        SrcNeedsLegalization = true;
-      }
-
-      if (SrcTy.getNumElements() > MaxVectorSize) {
-        SrcNeedsLegalization = true;
-      }
-    }
-
     // If an spv_bitcast needs to be legalized, we convert it to G_BITCAST to
     // allow using the generic legalization rules.
-    if (DstNeedsLegalization || SrcNeedsLegalization) {
+    if (needsVectorLegalization(DstTy, ST) ||
+        needsVectorLegalization(SrcTy, ST)) {
       LLVM_DEBUG(dbgs() << "Replacing with a G_BITCAST\n");
       MIRBuilder.buildBitcast(DstReg, SrcReg);
       MI.eraseFromParent();
     }
     return true;
+  } else if (IntrinsicID == Intrinsic::spv_insertelt) {
+    Register DstReg = MI.getOperand(0).getReg();
+    LLT DstTy = MRI.getType(DstReg);
+
+    if (needsVectorLegalization(DstTy, ST)) {
+      Register SrcReg = MI.getOperand(2).getReg();
+      Register ValReg = MI.getOperand(3).getReg();
+      Register IdxReg = MI.getOperand(4).getReg();
+      MIRBuilder.buildInsertVectorElement(DstReg, SrcReg, ValReg, IdxReg);
+      MI.eraseFromParent();
+    }
+    return true;
+  } else if (IntrinsicID == Intrinsic::spv_extractelt) {
+    Register SrcReg = MI.getOperand(2).getReg();
+    LLT SrcTy = MRI.getType(SrcReg);
+
+    if (needsVectorLegalization(SrcTy, ST)) {
+      Register DstReg = MI.getOperand(0).getReg();
+      Register IdxReg = MI.getOperand(3).getReg();
+      MIRBuilder.buildExtractVectorElement(DstReg, SrcReg, IdxReg);
+      MI.eraseFromParent();
+    }
+    return true;
   }
   return true;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index c90e6d8cfbfb4..d91016a38539b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -16,6 +16,7 @@
 #include "SPIRV.h"
 #include "SPIRVSubtarget.h"
 #include "SPIRVUtils.h"
+#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
 #include "llvm/Support/Debug.h"
 #include <stack>
@@ -66,8 +67,9 @@ static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
     for (unsigned i = 0; i < I->getNumDefs() && !ScalarType; ++i) {
       for (const auto &Use :
            MRI.use_nodbg_instructions(I->getOperand(i).getReg())) {
-        assert(Use.getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
-               "Expected use of G_UNMERGE_VALUES to be a G_BUILD_VECTOR");
+        if (Use.getOpcode() != TargetOpcode::G_BUILD_VECTOR)
+          continue;
+
         if (auto *VecType =
                 GR->getSPIRVTypeForVReg(Use.getOperand(0).getReg())) {
           ScalarType = GR->getScalarOrVectorComponentType(VecType);
@@ -133,10 +135,10 @@ static SPIRVType *deduceTypeFromOperandRange(MachineInstr *I,
   return ResType;
 }
 
-static SPIRVType *deduceTypeForResultRegister(MachineInstr *Use,
-                                              Register UseRegister,
-                                              SPIRVGlobalRegistry *GR,
-                                              MachineIRBuilder &MIB) {
+static SPIRVType *deduceTypeFromResultRegister(MachineInstr *Use,
+                                               Register UseRegister,
+                                               SPIRVGlobalRegistry *GR,
+                                               MachineIRBuilder &MIB) {
   for (const MachineOperand &MO : Use->defs()) {
     if (!MO.isReg())
       continue;
@@ -159,16 +161,43 @@ static SPIRVType *deduceTypeFromUses(Register Reg, MachineFunction &MF,
   MachineRegisterInfo &MRI = MF.getRegInfo();
   for (MachineInstr &Use : MRI.use_nodbg_instructions(Reg)) {
     SPIRVType *ResType = nullptr;
+    LLVM_DEBUG(dbgs() << "Looking at use " << Use);
     switch (Use.getOpcode()) {
     case TargetOpcode::G_BUILD_VECTOR:
     case TargetOpcode::G_EXTRACT_VECTOR_ELT:
     case TargetOpcode::G_UNMERGE_VALUES:
-      LLVM_DEBUG(dbgs() << "Looking at use " << Use << "\n");
-      ResType = deduceTypeForResultRegister(&Use, Reg, GR, MIB);
+    case TargetOpcode::G_ADD:
+    case TargetOpcode::G_SUB:
+    case TargetOpcode::G_MUL:
+    case TargetOpcode::G_SDIV:
+    case TargetOpcode::G_UDIV:
+    case TargetOpcode::G_SREM:
+    case TargetOpcode::G_UREM:
+    case TargetOpcode::G_FADD:
+    case TargetOpcode::G_FSUB:
+    case TargetOpcode::G_FMUL:
+    case TargetOpcode::G_FDIV:
+    case TargetOpcode::G_FREM:
+    case TargetOpcode::G_FMA:
+      ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB);
+      break;
+    case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
+    case TargetOpcode::G_INTRINSIC: {
+      auto IntrinsicID = cast<GIntrinsic>(Use).getIntrinsicID();
+      if (IntrinsicID == Intrinsic::spv_insertelt) {
+        if (Reg == Use.getOperand(2).getReg())
+          ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB);
+      } else if (IntrinsicID == Intrinsic::spv_extractelt) {
+        if (Reg == Use.getOperand(2).getReg())
+          ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB);
+      }
       break;
     }
-    if (ResType)
+    }
+    if (ResType) {
+      LLVM_DEBUG(dbgs() << "Deduced type from use " << *ResType);
       return ResType;
+    }
   }
   return nullptr;
 }
diff --git a/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll b/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll
new file mode 100644
index 0000000000000..468d3ded4c306
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll
@@ -0,0 +1,194 @@
+; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpName %[[#test_int32_double_conversion:]] "test_int32_double_conversion"
+; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#v4i32:]] = OpTypeVector %[[#int]] 4
+; CHECK-DAG: %[[#double:]] = OpTypeFloat 64
+; CHECK-DAG: %[[#v4f64:]] = OpTypeVector %[[#double]] 4
+; CHECK-DAG: %[[#v2i32:]] = OpTypeVector %[[#int]] 2
+; CHECK-DAG: %[[#ptr_private_v4i32:]] = OpTypePointer Private %[[#v4i32]]
+; CHECK-DAG: %[[#ptr_private_v4f64:]] = OpTypePointer Private %[[#v4f64]]
+; CHECK-DAG: %[[#global_double:]] = OpVariable %[[#ptr_private_v4f64]] Private
+; CHECK-DAG: %[[#C15:]] = OpConstant %[[#int]] 15{{$}}
+; CHECK-DAG: %[[#C14:]] = OpConstant %[[#int]] 14{{$}}
+; CHECK-DAG: %[[#C13:]] = OpConstant %[[#int]] 13{{$}}
+; CHECK-DAG: %[[#C12:]] = OpConstant %[[#int]] 12{{$}}
+; CHECK-DAG: %[[#C11:]] = OpConstant %[[#int]] 11{{$}}
+; CHECK-DAG: %[[#C10:]] = OpConstant %[[#int]] 10{{$}}
+; CHECK-DAG: %[[#C9:]] = OpConstant %[[#int]] 9{{$}}
+; CHECK-DAG: %[[#C8:]] = OpConstant %[[#int]] 8{{$}}
+; CHECK-DAG: %[[#C7:]] = OpConstant %[[#int]] 7{{$}}
+; CHECK-DAG: %[[#C6:]] = OpConstant %[[#int]] 6{{$}}
+; CHECK-DAG: %[[#C5:]] = OpConstant %[[#int]] 5{{$}}
+; CHECK-DAG: %[[#C4:]] = OpConstant %[[#int]] 4{{$}}
+; CHECK-DAG: %[[#C3:]] = OpConstant %[[#int]] 3{{$}}
+; CHECK-DAG: %[[#C2:]] = OpConstant %[[#int]] 2{{$}}
+; CHECK-DAG: %[[#C1:]] = OpConstant %[[#int]] 1{{$}}
+; CHECK-DAG: %[[#C0:]] = OpConstant %[[#int]] 0{{$}}
+
+ at G_16 = internal addrspace(10) global [16 x i32] zeroinitializer
+ at G_4_double = internal addrspace(10) global <4 x double> zeroinitializer
+ at G_4_int = internal addrspace(10) global <4 x i32> zeroinitializer
+
+
+; This is the way matrices will be represented in HLSL. The memory type will be
+; an array, but it will be loaded as a vector.
+define spir_func void @test_load_store_global() {
+entry:
+; CHECK-DAG: %[[#PTR0:]] = OpAccessChain %[[#ptr_int:]] %[[#G16:]] %[[#C0]]
+; CHECK-DAG: %[[#VAL0:]] = OpLoad %[[#int]] %[[#PTR0]] Aligned 4
+; CHECK-DAG: %[[#PTR1:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C1]]
+; CHECK-DAG: %[[#VAL1:]] = OpLoad %[[#int]] %[[#PTR1]] Aligned 4
+; CHECK-DAG: %[[#PTR2:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C2]]
+; CHECK-DAG: %[[#VAL2:]] = OpLoad %[[#int]] %[[#PTR2]] Aligned 4
+; CHECK-DAG: %[[#PTR3:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C3]]
+; CHECK-DAG: %[[#VAL3:]] = OpLoad %[[#int]] %[[#PTR3]] Aligned 4
+; CHECK-DAG: %[[#PTR4:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C4]]
+; CHECK-DAG: %[[#VAL4:]] = OpLoad %[[#int]] %[[#PTR4]] Aligned 4
+; CHECK-DAG: %[[#PTR5:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C5]]
+; CHECK-DAG: %[[#VAL5:]] = OpLoad %[[#int]] %[[#PTR5]] Aligned 4
+; CHECK-DAG: %[[#PTR6:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C6]]
+; CHECK-DAG: %[[#VAL6:]] = OpLoad %[[#int]] %[[#PTR6]] Aligned 4
+; CHECK-DAG: %[[#PTR7:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C7]]
+; CHECK-DAG: %[[#VAL7:]] = OpLoad %[[#int]] %[[#PTR7]] Aligned 4
+; CHECK-DAG: %[[#PTR8:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C8]]
+; CHECK-DAG: %[[#VAL8:]] = OpLoad %[[#int]] %[[#PTR8]] Aligned 4
+; CHECK-DAG: %[[#PTR9:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C9]]
+; CHECK-DAG: %[[#VAL9:]] = OpLoad %[[#int]] %[[#PTR9]] Aligned 4
+; CHECK-DAG: %[[#PTR10:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C10]]
+; CHECK-DAG: %[[#VAL10:]] = OpLoad %[[#int]] %[[#PTR10]] Aligned 4
+; CHECK-DAG: %[[#PTR11:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C11]]
+; CHECK-DAG: %[[#VAL11:]] = OpLoad %[[#int]] %[[#PTR11]] Aligned 4
+; CHECK-DAG: %[[#PTR12:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C12]]
+; CHECK-DAG: %[[#VAL12:]] = OpLoad %[[#int]] %[[#PTR12]] Aligned 4
+; CHECK-DAG: %[[#PTR13:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C13]]
+; CHECK-DAG: %[[#VAL13:]] = OpLoad %[[#int]] %[[#PTR13]] Aligned 4
+; CHECK-DAG: %[[#PTR14:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C14]]
+; CHECK-DAG: %[[#VAL14:]] = OpLoad %[[#int]] %[[#PTR14]] Aligned 4
+; CHECK-DAG: %[[#PTR15:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C15]]
+; CHECK-DAG: %[[#VAL15:]] = OpLoad %[[#int]] %[[#PTR15]] Aligned 4
+; CHECK-DAG: %[[#INS0:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL0]] %[[#UNDEF:]] 0
+; CHECK-DAG: %[[#INS1:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL1]] %[[#INS0]] 1
+; CHECK-DAG: %[[#INS2:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL2]] %[[#INS1]] 2
+; CHECK-DAG: %[[#INS3:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL3]] %[[#INS2]] 3
+; CHECK-DAG: %[[#INS4:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL4]] %[[#UNDEF]] 0
+; CHECK-DAG: %[[#INS5:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL5]] %[[#INS4]] 1
+; CHECK-DAG: %[[#INS6:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL6]] %[[#INS5]] 2
+; CHECK-DAG: %[[#INS7:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL7]] %[[#INS6]] 3
+; CHECK-DAG: %[[#INS8:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL8]] %[[#UNDEF]] 0
+; CHECK-DAG: %[[#INS9:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL9]] %[[#INS8]] 1
+; CHECK-DAG: %[[#INS10:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL10]] %[[#INS9]] 2
+; CHECK-DAG: %[[#INS11:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL11]] %[[#INS10]] 3
+; CHECK-DAG: %[[#INS12:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL12]] %[[#UNDEF]] 0
+; CHECK-DAG: %[[#INS13:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL13]] %[[#INS12]] 1
+; CHECK-DAG: %[[#INS14:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL14]] %[[#INS13]] 2
+; CHECK-DAG: %[[#INS15:]] = OpCompositeInsert %[[#v4i32]] %[[#VAL15]] %[[#INS14]] 3
+  %0 = load <16 x i32>, ptr addrspace(10) @G_16, align 64
+ 
+; CHECK-DAG: %[[#PTR0_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C0]]
+; CHECK-DAG: %[[#VAL0_S:]] = OpCompositeExtract %[[#int]] %[[#INS3]] 0
+; CHECK-DAG: OpStore %[[#PTR0_S]] %[[#VAL0_S]] Aligned 64
+; CHECK-DAG: %[[#PTR1_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C1]]
+; CHECK-DAG: %[[#VAL1_S:]] = OpCompositeExtract %[[#int]] %[[#INS3]] 1
+; CHECK-DAG: OpStore %[[#PTR1_S]] %[[#VAL1_S]] Aligned 64
+; CHECK-DAG: %[[#PTR2_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C2]]
+; CHECK-DAG: %[[#VAL2_S:]] = OpCompositeExtract %[[#int]] %[[#INS3]] 2
+; CHECK-DAG: OpStore %[[#PTR2_S]] %[[#VAL2_S]] Aligned 64
+; CHECK-DAG: %[[#PTR3_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C3]]
+; CHECK-DAG: %[[#VAL3_S:]] = OpCompositeExtract %[[#int]] %[[#INS3]] 3
+; CHECK-DAG: OpStore %[[#PTR3_S]] %[[#VAL3_S]] Aligned 64
+; CHECK-DAG: %[[#PTR4_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C4]]
+; CHECK-DAG: %[[#VAL4_S:]] = OpCompositeExtract %[[#int]] %[[#INS7]] 0
+; CHECK-DAG: OpStore %[[#PTR4_S]] %[[#VAL4_S]] Aligned 64
+; CHECK-DAG: %[[#PTR5_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C5]]
+; CHECK-DAG: %[[#VAL5_S:]] = OpCompositeExtract %[[#int]] %[[#INS7]] 1
+; CHECK-DAG: OpStore %[[#PTR5_S]] %[[#VAL5_S]] Aligned 64
+; CHECK-DAG: %[[#PTR6_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C6]]
+; CHECK-DAG: %[[#VAL6_S:]] = OpCompositeExtract %[[#int]] %[[#INS7]] 2
+; CHECK-DAG: OpStore %[[#PTR6_S]] %[[#VAL6_S]] Aligned 64
+; CHECK-DAG: %[[#PTR7_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C7]]
+; CHECK-DAG: %[[#VAL7_S:]] = OpCompositeExtract %[[#int]] %[[#INS7]] 3
+; CHECK-DAG: OpStore %[[#PTR7_S]] %[[#VAL7_S]] Aligned 64
+; CHECK-DAG: %[[#PTR8_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C8]]
+; CHECK-DAG: %[[#VAL8_S:]] = OpCompositeExtract %[[#int]] %[[#INS11]] 0
+; CHECK-DAG: OpStore %[[#PTR8_S]] %[[#VAL8_S]] Aligned 64
+; CHECK-DAG: %[[#PTR9_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C9]]
+; CHECK-DAG: %[[#VAL9_S:]] = OpCompositeExtract %[[#int]] %[[#INS11]] 1
+; CHECK-DAG: OpStore %[[#PTR9_S]] %[[#VAL9_S]] Aligned 64
+; CHECK-DAG: %[[#PTR10_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C10]]
+; CHECK-DAG: %[[#VAL10_S:]] = OpCompositeExtract %[[#int]] %[[#INS11]] 2
+; CHECK-DAG: OpStore %[[#PTR10_S]] %[[#VAL10_S]] Aligned 64
+; CHECK-DAG: %[[#PTR11_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C11]]
+; CHECK-DAG: %[[#VAL11_S:]] = OpCompositeExtract %[[#int]] %[[#INS11]] 3
+; CHECK-DAG: OpStore %[[#PTR11_S]] %[[#VAL11_S]] Aligned 64
+; CHECK-DAG: %[[#PTR12_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C12]]
+; CHECK-DAG: %[[#VAL12_S:]] = OpCompositeExtract %[[#int]] %[[#INS15]] 0
+; CHECK-DAG: OpStore %[[#PTR12_S]] %[[#VAL12_S]] Aligned 64
+; CHECK-DAG: %[[#PTR13_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C13]]
+; CHECK-DAG: %[[#VAL13_S:]] = OpCompositeExtract %[[#int]] %[[#INS15]] 1
+; CHECK-DAG: OpStore %[[#PTR13_S]] %[[#VAL13_S]] Aligned 64
+; CHECK-DAG: %[[#PTR14_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C14]]
+; CHECK-DAG: %[[#VAL14_S:]] = OpCompositeExtract %[[#int]] %[[#INS15]] 2
+; CHECK-DAG: OpStore %[[#PTR14_S]] %[[#VAL14_S]] Aligned 64
+; CHECK-DAG: %[[#PTR15_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C15]]
+; CHECK-DAG: %[[#VAL15_S:]] = OpCompositeExtract %[[#int]] %[[#INS15]] 3
+; CHECK-DAG: OpStore %[[#PTR15_S]] %[[#VAL15_S]] Aligned 64
+  store <16 x i32> %0, ptr addrspace(10) @G_16, align 64
+  ret void
+}
+
+define spir_func void @test_int32_double_conversion() {
+; CHECK: %[[#test_int32_double_conversion]] = OpFunction
+entry:
+  ; CHECK: %[[#LOAD:]] = OpLoad %[[#v4f64]] %[[#global_double]]
+  ; CHECK: %[[#VEC_SHUF1:]] = OpVectorShuffle %{{[a-zA-Z0-9_]+}} %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 0 1
+  ; CHECK: %[[#VEC_SHUF2:]] = OpVectorShuffle %{{[a-zA-Z0-9_]+}} %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 2 3
+  ; CHECK: %[[#BITCAST1:]] = OpBitcast %[[#v4i32]] %[[#VEC_SHUF1]]
+  ; CHECK: %[[#BITCAST2:]] = OpBitcast %[[#v4i32]] %[[#VEC_SHUF2]]
+  %0 = load <8 x i32>, ptr addrspace(10) @G_4_double
+
+  ; CHECK: %[[#EXTRACT1:]] = OpCompositeExtract %[[#int]] %[[#BITCAST1]] 0
+  ; CHECK: %[[#EXTRACT2:]] = OpCompositeExtract %[[#int]] %[[#BITCAST1]] 2
+  ; CHECK: %[[#EXTRACT3:]] = OpCompositeExtract %[[#int]] %[[#BITCAST2]] 0
+  ; CHECK: %[[#EXTRACT4:]] = OpCompositeExtract %[[#int]] %[[#BITCAST2]] 2
+  ; CHECK: %[[#CONSTRUCT1:]] = OpCompositeConstruct %[[#v4i32]] %[[#EXTRACT1]] %[[#EXTRACT2]] %[[#EXTRACT3]] %[[#EXTRACT4]]
+  %1 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+  
+  ; CHECK: %[[#EXTRACT5:]] = OpCompositeExtract %[[#int]] %[[#BITCAST1]] 1
+  ; CHECK: %[[#EXTRACT6:]] = OpCompositeExtract %[[#int]] %[[#BITCAST1]] 3
+  ; CHECK: %[[#EXTRACT7:]] = OpCompositeExtract %[[#int]] %[[#BITCAST2]] 1
+  ; CHECK: %[[#EXTRACT8:]] = OpCompositeExtract %[[#int]] %[[#BITCAST2]] 3
+  ; CHECK: %[[#CONSTRUCT2:]] = OpCompositeConstruct %[[#v4i32]] %[[#EXTRACT5]] %[[#EXTRACT6]] %[[#EXTRACT7]] %[[#EXTRACT8]]
+  %2 = shufflevector <8 x i32> %0, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+
+  ; CHECK: %[[#EXTRACT9:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT1]] 0
+  ; CHECK: %[[#EXTRACT10:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT2]] 0
+  ; CHECK: %[[#EXTRACT11:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT1]] 1
+  ; CHECK: %[[#EXTRACT12:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT2]] 1
+  ; CHECK: %[[#EXTRACT13:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT1]] 2
+  ; CHECK: %[[#EXTRACT14:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT2]] 2
+  ; CHECK: %[[#EXTRACT15:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT1]] 3
+  ; CHECK: %[[#EXTRACT16:]] = OpCompositeExtract %[[#int]] %[[#CONSTRUCT2]] 3
+  ; CHECK: %[[#CONSTRUCT3:]] = OpCompositeConstruct %[[#v2i32]] %[[#EXTRACT9]] %[[#EXTRACT10]]
+  ; CHECK: %[[#CONSTRUCT4:]] = OpCompositeConstruct %[[#v2i32]] %[[#EXTRACT11]] %[[#EXTRACT12]]
+  ; CHECK: %[[#CONSTRUCT5:]] = OpCompositeConstruct %[[#v2i32]] %[[#EXTRACT13]] %[[#EXTRACT14]]
+  ; CHECK: %[[#CONSTRUCT6:]] = OpCompositeConstruct %[[#v2i32]] %[[#EXTRACT15]] %[[#EXTRACT16]]
+  %3 = shufflevector <4 x i32> %1, <4 x i32> %2, <8 x i32> <i32 0, i32 4, i32 1, i32 5, i32 2, i32 6, i32 3, i32 7>
+
+  ; CHECK: %[[#BITCAST3:]] = OpBitcast %[[#double]] %[[#CONSTRUCT3]]
+  ; CHECK: %[[#BITCAST4:]] = OpBitcast %[[#double]] %[[#CONSTRUCT4]]
+  ; CHECK: %[[#BITCAST5:]] = OpBitcast %[[#double]] %[[#CONSTRUCT5]]
+  ; CHECK: %[[#BITCAST6:]] = OpBitcast %[[#double]] %[[#CONSTRUCT6]]
+  ; CHECK: %[[#CONSTRUCT7:]] = OpCompositeConstruct %[[#v4f64]] %[[#BITCAST3]] %[[#BITCAST4]] %[[#BITCAST5]] %[[#BITCAST6]]
+  ; CHECK: OpStore %[[#global_double]] %[[#CONSTRUCT7]] Aligned 32
+  store <8 x i32> %3, ptr addrspace(10) @G_4_double
+  ret void
+}
+
+; Add a main function to make it a valid module for spirv-val
+define void @main() #1 {
+  ret void
+}
+
+attributes #1 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
diff --git a/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic.ll b/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic.ll
new file mode 100644
index 0000000000000..126548a1e9ea2
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic.ll
@@ -0,0 +1,149 @@
+; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpName %[[#main:]] "main"
+; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#v4f32:]] = OpTypeVector %[[#float]] 4
+; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#c16:]] = OpConstant %[[#int]] 16
+; CHECK-DAG: %[[#v16f32:]] = OpTypeArray %[[#float]] %[[#c16]]
+; CHECK-DAG: %[[#v16i32:]] = OpTypeArray %[[#int]] %[[#c16]]
+; CHECK-DAG: %[[#ptr_ssbo_v16i32:]] = OpTypePointer Private %[[#v16i32]]
+; CHECK-DAG: %[[#v4i32:]] = OpTypeVector %[[#int]] 4
+
+ at f1 = internal addrspace(10) global [4 x [16 x float] ] zeroinitializer
+ at f2 = internal addrspace(10) global [4 x [16 x float] ] zeroinitializer
+ at i1 = internal addrspace(10) global [4 x [16 x i32] ] zeroinitializer
+ at i2 = internal addrspace(10) global [4 x [16 x i32] ] zeroinitializer
+
+define void @main() local_unnamed_addr #0 {
+; CHECK: %[[#main]] = OpFunction
+entry:
+  %2 = getelementptr [4 x [16 x float] ], ptr addrspace(10) @f1, i32 0, i32 0
+  %3 = load <16 x float>, ptr addrspace(10) %2, align 4
+  %4 = getelementptr [4 x [16 x float] ], ptr addrspace(10) @f1, i32 0, i32 1
+  %5 = load <16 x float>, ptr addrspace(10) %4, align 4
+  %6 = getelementptr [4 x [16 x float] ], ptr addrspace(10) @f1, i32 0, i32 2
+  %7 = load <16 x float>, ptr addrspace(10) %6, align 4
+  %8 = getelementptr [4 x [16 x float] ], ptr addrspace(10) @f1, i32 0, i32 3
+  %9 = load <16 x float>, ptr addrspace(10) %8, align 4
+  
+  ; We expect the large vectors to be split into size 4, and the operations performed on them.
+  ; CHECK: OpFMul %[[#v4f32]]
+  ; CHECK: OpFMul %[[#v4f32]]
+  ; CHECK: OpFMul %[[#v4f32]]
+  ; CHECK: OpFMul %[[#v4f32]]
+  %10 = fmul reassoc nnan ninf nsz arcp afn <16 x float> %3, splat (float 3.000000e+00)
+
+  ; CHECK: OpFAdd %[[#v4f32]]
+  ; CHECK: OpFAdd %[[#v4f32]]
+  ; CHECK: OpFAdd %[[#v4f32]]
+  ; CHECK: OpFAdd %[[#v4f32]]
+  %11 = fadd reassoc nnan ninf nsz arcp afn <16 x float> %10, %5
+
+  ; CHECK: OpFAdd %[[#v4f32]]
+  ; CHECK: OpFAdd %[[#v4f32]]
+  ; CHECK: OpFAdd %[[#v4f32]]
+  ; CHECK: OpFAdd %[[#v4f32]]
+  %12 = fadd reassoc nnan ninf nsz arcp afn <16 x float> %11, %7
+
+  ; CHECK: OpFSub %[[#v4f32]]
+  ; CHECK: OpFSub %[[#v4f32]]
+  ; CHECK: OpFSub %[[#v4f32]]
+  ; CHECK: OpFSub %[[#v4f32]]
+  %13 = fsub reassoc nnan ninf nsz arcp afn <16 x float> %12, %9
+
+  %14 = getelementptr [4 x [16 x float] ], ptr addrspace(10) @f2, i32 0, i32 0
+  store <16 x float> %13, ptr addrspace(10) %14, align 4
+  ret void
+}
+
+; Test integer vector arithmetic operations
+define void @test_int_vector_arithmetic() local_unnamed_addr #0 {
+; CHECK: OpFunction
+entry:
+  %2 = getelementptr [4 x [16 x i32] ], ptr addrspace(10) @i1, i32 0, i32 0
+  %3 = load <16 x i32>, ptr addrspace(10) %2, align 4
+  %4 = getelementptr [4 x [16 x i32] ], ptr addrspace(10) @i1, i32 0, i32 1
+  %5 = load <16 x i32>, ptr addrspace(10) %4, align 4
+
+  ; CHECK: OpIAdd %[[#v4i32]]
+  ; CHECK: OpIAdd %[[#v4i32]]
+  ; CHECK: OpIAdd %[[#v4i32]]
+  ; CHECK: OpIAdd %[[#v4i32]]
+  %6 = add <16 x i32> %3, %5
+
+  ; CHECK: OpISub %[[#v4i32]]
+  ; CHECK: OpISub %[[#v4i32]]
+  ; CHECK: OpISub %[[#v4i32]]
+  ; CHECK: OpISub %[[#v4i32]]
+  %7 = sub <16 x i32> %6, %5
+
+  ; CHECK: OpIMul %[[#v4i32]]
+  ; CHECK: OpIMul %[[#v4i32]]
+  ; CHECK: OpIMul %[[#v4i32]]
+  ; CHECK: OpIMul %[[#v4i32]]
+  %8 = mul <16 x i32> %7, splat (i32 2)
+
+  ; CHECK: OpSDiv %[[#v4i32]]
+  ; CHECK: OpSDiv %[[#v4i32]]
+  ; CHECK: OpSDiv %[[#v4i32]]
+  ; CHECK: OpSDiv %[[#v4i32]]
+  %9 = sdiv <16 x i32> %8, splat (i32 2)
+
+  ; CHECK: OpUDiv %[[#v4i32]]
+  ; CHECK: OpUDiv %[[#v4i32]]
+  ; CHECK: OpUDiv %[[#v4i32]]
+  ; CHECK: OpUDiv %[[#v4i32]]
+  %10 = udiv <16 x i32> %9, splat (i32 1)
+
+  ; CHECK: OpSRem %[[#v4i32]]
+  ; CHECK: OpSRem %[[#v4i32]]
+  ; CHECK: OpSRem %[[#v4i32]]
+  ; CHECK: OpSRem %[[#v4i32]]
+  %11 = srem <16 x i32> %10, splat (i32 3)
+
+  ; CHECK: OpUMod %[[#v4i32]]
+  ; CHECK: OpUMod %[[#v4i32]]
+  ; CHECK: OpUMod %[[#v4i32]]
+  ; CHECK: OpUMod %[[#v4i32]]
+  %12 = urem <16 x i32> %11, splat (i32 3)
+
+  %13 = getelementptr [4 x [16 x i32] ], ptr addrspace(10) @i2, i32 0, i32 0
+  store <16 x i32> %12, ptr addrspace(10) %13, align 4
+  ret void
+}
+
+; Test remaining float vector arithmetic operations
+define void @test_float_vector_arithmetic_continued() local_unnamed_addr #0 {
+; CHECK: OpFunction
+entry:
+  %2 = getelementptr [4 x [16 x float] ], ptr addrspace(10) @f1, i32 0, i32 0
+  %3 = load <16 x float>, ptr addrspace(10) %2, align 4
+  %4 = getelementptr [4 x [16 x float] ], ptr addrspace(10) @f1, i32 0, i32 1
+  %5 = load <16 x float>, ptr addrspace(10) %4, align 4
+
+  ; CHECK: OpFDiv %[[#v4f32]]
+  ; CHECK: OpFDiv %[[#v4f32]]
+  ; CHECK: OpFDiv %[[#v4f32]]
+  ; CHECK: OpFDiv %[[#v4f32]]
+  %6 = fdiv reassoc nnan ninf nsz arcp afn <16 x float> %3, splat (float 2.000000e+00)
+
+  ; CHECK: OpFRem %[[#v4f32]]
+  ; CHECK: OpFRem %[[#v4f32]]
+  ; CHECK: OpFRem %[[#v4f32]]
+  ; CHECK: OpFRem %[[#v4f32]]
+  %7 = frem reassoc nnan ninf nsz arcp afn <16 x float> %6, splat (float 3.000000e+00)
+
+  ; CHECK: OpExtInst %[[#v4f32]] {{.*}} Fma
+  ; CHECK: OpExtInst %[[#v4f32]] {{.*}} Fma
+  ; CHECK: OpExtInst %[[#v4f32]] {{.*}} Fma
+  ; CHECK: OpExtInst %[[#v4f32]] {{.*}} Fma
+  %8 = call reassoc nnan ninf nsz arcp afn <16 x float> @llvm.fma.v16f32(<16 x float> %5, <16 x float> %6, <16 x float> %7)
+
+  %9 = getelementptr [4 x [16 x float] ], ptr addrspace(10) @f2, i32 0, i32 0
+  store <16 x float> %8, ptr addrspace(10) %9, align 4
+  ret void
+}
+
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
\ No newline at end of file

>From 92f96f7347f880cf2c803d3c39812667277fec17 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Mon, 8 Dec 2025 15:44:01 -0500
Subject: [PATCH 2/8] Add s64 to all scalars.

---
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 4d83649c0f84f..a3ad5d7fd40c0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -113,7 +113,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
                            v3s1, v3s8, v3s16, v3s32, v3s64,
                            v4s1, v4s8, v4s16, v4s32, v4s64};
 
-  auto allScalars = {s1, s8, s16, s32};
+  auto allScalars = {s1, s8, s16, s32, s64};
 
   auto allScalarsAndVectors = {
       s1,   s8,   s16,   s32,   s64,   v2s1,  v2s8,  v2s16,  v2s32,  v2s64,

>From c6c4a277ec1c94f67cd25594970d17036fedbdf0 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Mon, 8 Dec 2025 16:08:01 -0500
Subject: [PATCH 3/8] Fix alignment in stores generated in legalize pointer
 cast.

---
 .../Target/SPIRV/SPIRVLegalizePointerCast.cpp |  6 +++-
 .../SPIRV/legalization/load-store-global.ll   | 30 +++++++++----------
 2 files changed, 20 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 81c7596530ee2..a794f3e9c5363 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -168,6 +168,9 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     assert(VecTy->getElementType() == ArrTy->getElementType() &&
            "Element types of array and vector must be the same.");
 
+    const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
+    uint64_t ElemSize = DL.getTypeAllocSize(ArrTy->getElementType());
+
     for (unsigned i = 0; i < VecTy->getNumElements(); ++i) {
       // Create a GEP to access the i-th element of the array.
       SmallVector<Type *, 2> Types = {DstArrayPtr->getType(),
@@ -190,7 +193,8 @@ class SPIRVLegalizePointerCast : public FunctionPass {
       buildAssignType(B, VecTy->getElementType(), Element);
 
       Types = {Element->getType(), ElementPtr->getType()};
-      Args = {Element, ElementPtr, B.getInt16(2), B.getInt8(Alignment.value())};
+      Align NewAlign = commonAlignment(Alignment, i * ElemSize);
+      Args = {Element, ElementPtr, B.getInt16(2), B.getInt8(NewAlign.value())};
       B.CreateIntrinsic(Intrinsic::spv_store, {Types}, {Args});
     }
   }
diff --git a/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll b/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll
index 468d3ded4c306..39adba954568e 100644
--- a/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll
+++ b/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll
@@ -91,49 +91,49 @@ entry:
 ; CHECK-DAG: OpStore %[[#PTR0_S]] %[[#VAL0_S]] Aligned 64
 ; CHECK-DAG: %[[#PTR1_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C1]]
 ; CHECK-DAG: %[[#VAL1_S:]] = OpCompositeExtract %[[#int]] %[[#INS3]] 1
-; CHECK-DAG: OpStore %[[#PTR1_S]] %[[#VAL1_S]] Aligned 64
+; CHECK-DAG: OpStore %[[#PTR1_S]] %[[#VAL1_S]] Aligned 4
 ; CHECK-DAG: %[[#PTR2_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C2]]
 ; CHECK-DAG: %[[#VAL2_S:]] = OpCompositeExtract %[[#int]] %[[#INS3]] 2
-; CHECK-DAG: OpStore %[[#PTR2_S]] %[[#VAL2_S]] Aligned 64
+; CHECK-DAG: OpStore %[[#PTR2_S]] %[[#VAL2_S]] Aligned 8
 ; CHECK-DAG: %[[#PTR3_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C3]]
 ; CHECK-DAG: %[[#VAL3_S:]] = OpCompositeExtract %[[#int]] %[[#INS3]] 3
-; CHECK-DAG: OpStore %[[#PTR3_S]] %[[#VAL3_S]] Aligned 64
+; CHECK-DAG: OpStore %[[#PTR3_S]] %[[#VAL3_S]] Aligned 4
 ; CHECK-DAG: %[[#PTR4_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C4]]
 ; CHECK-DAG: %[[#VAL4_S:]] = OpCompositeExtract %[[#int]] %[[#INS7]] 0
-; CHECK-DAG: OpStore %[[#PTR4_S]] %[[#VAL4_S]] Aligned 64
+; CHECK-DAG: OpStore %[[#PTR4_S]] %[[#VAL4_S]] Aligned 16
 ; CHECK-DAG: %[[#PTR5_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C5]]
 ; CHECK-DAG: %[[#VAL5_S:]] = OpCompositeExtract %[[#int]] %[[#INS7]] 1
-; CHECK-DAG: OpStore %[[#PTR5_S]] %[[#VAL5_S]] Aligned 64
+; CHECK-DAG: OpStore %[[#PTR5_S]] %[[#VAL5_S]] Aligned 4
 ; CHECK-DAG: %[[#PTR6_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C6]]
 ; CHECK-DAG: %[[#VAL6_S:]] = OpCompositeExtract %[[#int]] %[[#INS7]] 2
-; CHECK-DAG: OpStore %[[#PTR6_S]] %[[#VAL6_S]] Aligned 64
+; CHECK-DAG: OpStore %[[#PTR6_S]] %[[#VAL6_S]] Aligned 8
 ; CHECK-DAG: %[[#PTR7_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C7]]
 ; CHECK-DAG: %[[#VAL7_S:]] = OpCompositeExtract %[[#int]] %[[#INS7]] 3
-; CHECK-DAG: OpStore %[[#PTR7_S]] %[[#VAL7_S]] Aligned 64
+; CHECK-DAG: OpStore %[[#PTR7_S]] %[[#VAL7_S]] Aligned 4
 ; CHECK-DAG: %[[#PTR8_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C8]]
 ; CHECK-DAG: %[[#VAL8_S:]] = OpCompositeExtract %[[#int]] %[[#INS11]] 0
-; CHECK-DAG: OpStore %[[#PTR8_S]] %[[#VAL8_S]] Aligned 64
+; CHECK-DAG: OpStore %[[#PTR8_S]] %[[#VAL8_S]] Aligned 32
 ; CHECK-DAG: %[[#PTR9_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C9]]
 ; CHECK-DAG: %[[#VAL9_S:]] = OpCompositeExtract %[[#int]] %[[#INS11]] 1
-; CHECK-DAG: OpStore %[[#PTR9_S]] %[[#VAL9_S]] Aligned 64
+; CHECK-DAG: OpStore %[[#PTR9_S]] %[[#VAL9_S]] Aligned 4
 ; CHECK-DAG: %[[#PTR10_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C10]]
 ; CHECK-DAG: %[[#VAL10_S:]] = OpCompositeExtract %[[#int]] %[[#INS11]] 2
-; CHECK-DAG: OpStore %[[#PTR10_S]] %[[#VAL10_S]] Aligned 64
+; CHECK-DAG: OpStore %[[#PTR10_S]] %[[#VAL10_S]] Aligned 8
 ; CHECK-DAG: %[[#PTR11_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C11]]
 ; CHECK-DAG: %[[#VAL11_S:]] = OpCompositeExtract %[[#int]] %[[#INS11]] 3
-; CHECK-DAG: OpStore %[[#PTR11_S]] %[[#VAL11_S]] Aligned 64
+; CHECK-DAG: OpStore %[[#PTR11_S]] %[[#VAL11_S]] Aligned 4
 ; CHECK-DAG: %[[#PTR12_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C12]]
 ; CHECK-DAG: %[[#VAL12_S:]] = OpCompositeExtract %[[#int]] %[[#INS15]] 0
-; CHECK-DAG: OpStore %[[#PTR12_S]] %[[#VAL12_S]] Aligned 64
+; CHECK-DAG: OpStore %[[#PTR12_S]] %[[#VAL12_S]] Aligned 16
 ; CHECK-DAG: %[[#PTR13_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C13]]
 ; CHECK-DAG: %[[#VAL13_S:]] = OpCompositeExtract %[[#int]] %[[#INS15]] 1
-; CHECK-DAG: OpStore %[[#PTR13_S]] %[[#VAL13_S]] Aligned 64
+; CHECK-DAG: OpStore %[[#PTR13_S]] %[[#VAL13_S]] Aligned 4
 ; CHECK-DAG: %[[#PTR14_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C14]]
 ; CHECK-DAG: %[[#VAL14_S:]] = OpCompositeExtract %[[#int]] %[[#INS15]] 2
-; CHECK-DAG: OpStore %[[#PTR14_S]] %[[#VAL14_S]] Aligned 64
+; CHECK-DAG: OpStore %[[#PTR14_S]] %[[#VAL14_S]] Aligned 8
 ; CHECK-DAG: %[[#PTR15_S:]] = OpAccessChain %[[#ptr_int]] %[[#G16]] %[[#C15]]
 ; CHECK-DAG: %[[#VAL15_S:]] = OpCompositeExtract %[[#int]] %[[#INS15]] 3
-; CHECK-DAG: OpStore %[[#PTR15_S]] %[[#VAL15_S]] Aligned 64
+; CHECK-DAG: OpStore %[[#PTR15_S]] %[[#VAL15_S]] Aligned 4
   store <16 x i32> %0, ptr addrspace(10) @G_16, align 64
   ret void
 }
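
The updated "Aligned" operands in the test fall out of commonAlignment:
the alignment known at byte offset Off from a base alignment A is the
largest power of 2 dividing both. A standalone sketch (not part of the
patch) reproducing the sequence 64, 4, 8, 4, 16, ... for a
64-byte-aligned <16 x i32> store scalarized into 4-byte elements:

    #include "llvm/Support/Alignment.h"
    #include <cstdio>

    int main() {
      // Base alignment 64, element size 4: commonAlignment(Base, i * 4)
      // gives the guaranteed alignment of the i-th scalarized store.
      llvm::Align Base(64);
      for (unsigned i = 0; i < 16; ++i)
        std::printf("element %2u: Aligned %llu\n", i,
                    (unsigned long long)llvm::commonAlignment(Base, i * 4).value());
      return 0;
    }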

>From 0760562bd5db87b599f1d1a5248174663e9ee694 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Tue, 9 Dec 2025 14:36:42 -0500
Subject: [PATCH 4/8] Set all defs to the default type in post legalization.

---
 llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp  | 31 +++++++++++--------
 .../SPIRV/legalization/load-store-global.ll   | 28 +++++++++++++++--
 2 files changed, 44 insertions(+), 15 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index d91016a38539b..1170a7a2f3270 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -325,20 +325,25 @@ static void registerSpirvTypeForNewInstructions(MachineFunction &MF,
 
   for (auto *I : Worklist) {
     MachineIRBuilder MIB(*I);
-    Register ResVReg = I->getOperand(0).getReg();
-    const LLT &ResLLT = MRI.getType(ResVReg);
-    SPIRVType *ResType = nullptr;
-    if (ResLLT.isVector()) {
-      SPIRVType *CompType = GR->getOrCreateSPIRVIntegerType(
-          ResLLT.getElementType().getSizeInBits(), MIB);
-      ResType = GR->getOrCreateSPIRVVectorType(
-          CompType, ResLLT.getNumElements(), MIB, false);
-    } else {
-      ResType = GR->getOrCreateSPIRVIntegerType(ResLLT.getSizeInBits(), MIB);
+    for (unsigned Idx = 0; Idx < I->getNumDefs(); ++Idx) {
+      Register ResVReg = I->getOperand(Idx).getReg();
+      if (GR->getSPIRVTypeForVReg(ResVReg))
+        continue;
+      const LLT &ResLLT = MRI.getType(ResVReg);
+      SPIRVType *ResType = nullptr;
+      if (ResLLT.isVector()) {
+        SPIRVType *CompType = GR->getOrCreateSPIRVIntegerType(
+            ResLLT.getElementType().getSizeInBits(), MIB);
+        ResType = GR->getOrCreateSPIRVVectorType(
+            CompType, ResLLT.getNumElements(), MIB, false);
+      } else {
+        ResType = GR->getOrCreateSPIRVIntegerType(ResLLT.getSizeInBits(), MIB);
+      }
+      LLVM_DEBUG(dbgs() << "Could not determine type for " << ResVReg
+                        << ", defaulting to " << *ResType << "\n");
+
+      setRegClassType(ResVReg, ResType, GR, &MRI, MF, true);
     }
-    LLVM_DEBUG(dbgs() << "Could not determine type for " << *I
-                      << ", defaulting to " << *ResType << "\n");
-    setRegClassType(ResVReg, ResType, GR, &MRI, MF, true);
   }
 }
 
diff --git a/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll b/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll
index 39adba954568e..19b39ff59809a 100644
--- a/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll
+++ b/llvm/test/CodeGen/SPIRV/legalization/load-store-global.ll
@@ -1,7 +1,6 @@
 ; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
 
-; CHECK-DAG: OpName %[[#test_int32_double_conversion:]] "test_int32_double_conversion"
 ; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0
 ; CHECK-DAG: %[[#v4i32:]] = OpTypeVector %[[#int]] 4
 ; CHECK-DAG: %[[#double:]] = OpTypeFloat 64
@@ -139,7 +138,7 @@ entry:
 }
 
 define spir_func void @test_int32_double_conversion() {
-; CHECK: %[[#test_int32_double_conversion]] = OpFunction
+; CHECK: OpFunction
 entry:
   ; CHECK: %[[#LOAD:]] = OpLoad %[[#v4f64]] %[[#global_double]]
   ; CHECK: %[[#VEC_SHUF1:]] = OpVectorShuffle %{{[a-zA-Z0-9_]+}} %[[#LOAD]] %{{[a-zA-Z0-9_]+}} 0 1
@@ -186,6 +185,31 @@ entry:
   ret void
 }
 
+; CHECK: OpFunction
+define spir_func void @test_double_to_int_implicit_conversion() {
+entry:
+
+; CHECK: %[[#LOAD_V4F64:]] = OpLoad %[[#V4F64_TYPE:]] %[[#GLOBAL_DOUBLE_VAR:]] Aligned 32
+; CHECK: %[[#VEC_SHUF_01:]] = OpVectorShuffle %[[#V2F64_TYPE:]] %[[#LOAD_V4F64]] %[[#UNDEF_V2F64:]] 0 1
+; CHECK: %[[#VEC_SHUF_23:]] = OpVectorShuffle %[[#V2F64_TYPE:]] %[[#LOAD_V4F64]] %[[#UNDEF_V2F64]] 2 3
+; CHECK: %[[#BITCAST_V4I32_01:]] = OpBitcast %[[#V4I32_TYPE:]] %[[#VEC_SHUF_01]]
+; CHECK: %[[#BITCAST_V4I32_23:]] = OpBitcast %[[#V4I32_TYPE]] %[[#VEC_SHUF_23]]
+  %0 = load <8 x i32>, ptr addrspace(10) @G_4_double, align 64
+
+; CHECK: %[[#VEC_SHUF_0_0:]] = OpVectorShuffle %[[#V2I32_TYPE:]] %[[#BITCAST_V4I32_01]] %[[#UNDEF_V2I32:]] 0 1
+; CHECK: %[[#VEC_SHUF_0_1:]] = OpVectorShuffle %[[#V2I32_TYPE]] %[[#BITCAST_V4I32_01]] %[[#UNDEF_V2I32]] 2 3
+; CHECK: %[[#VEC_SHUF_1_0:]] = OpVectorShuffle %[[#V2I32_TYPE]] %[[#BITCAST_V4I32_23]] %[[#UNDEF_V2I32]] 0 1
+; CHECK: %[[#VEC_SHUF_1_1:]] = OpVectorShuffle %[[#V2I32_TYPE]] %[[#BITCAST_V4I32_23]] %[[#UNDEF_V2I32]] 2 3
+; CHECK: %[[#BITCAST_DOUBLE_0_0:]] = OpBitcast %[[#DOUBLE_TYPE:]] %[[#VEC_SHUF_0_0]]
+; CHECK: %[[#BITCAST_DOUBLE_0_1:]] = OpBitcast %[[#DOUBLE_TYPE]] %[[#VEC_SHUF_0_1]]
+; CHECK: %[[#BITCAST_DOUBLE_1_0:]] = OpBitcast %[[#DOUBLE_TYPE]] %[[#VEC_SHUF_1_0]]
+; CHECK: %[[#BITCAST_DOUBLE_1_1:]] = OpBitcast %[[#DOUBLE_TYPE]] %[[#VEC_SHUF_1_1]]
+; CHECK: %[[#COMPOSITE_CONSTRUCT:]] = OpCompositeConstruct %[[#V4F64_TYPE]] %[[#BITCAST_DOUBLE_0_0]] %[[#BITCAST_DOUBLE_0_1]] %[[#BITCAST_DOUBLE_1_0]] %[[#BITCAST_DOUBLE_1_1]]
+; CHECK: OpStore %[[#GLOBAL_DOUBLE_VAR]] %[[#COMPOSITE_CONSTRUCT]] Aligned 64
+  store <8 x i32> %0, ptr addrspace(10) @G_4_double, align 64
+  ret void
+}
+
 ; Add a main function to make it a valid module for spirv-val
 define void @main() #1 {
   ret void
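
For reference, a minimal standalone sketch (not part of the patch;
printDefaultType is an invented name) of the fallback applied above:
every def whose vreg still has no registered SPIR-V type gets an integer
type, or a vector of integers, matching the bit layout of its LLT:

    #include "llvm/CodeGenTypes/LowLevelType.h"
    #include <cstdio>

    // Sketch of the defaulting in registerSpirvTypeForNewInstructions.
    static void printDefaultType(llvm::LLT Ty) {
      if (Ty.isVector())
        std::printf("OpTypeVector (OpTypeInt %u 0) %u\n",
                    (unsigned)Ty.getElementType().getSizeInBits().getFixedValue(),
                    (unsigned)Ty.getNumElements());
      else
        std::printf("OpTypeInt %u 0\n",
                    (unsigned)Ty.getSizeInBits().getFixedValue());
    }

    int main() {
      printDefaultType(llvm::LLT::scalar(32));          // OpTypeInt 32 0
      printDefaultType(llvm::LLT::fixed_vector(4, 32)); // OpTypeVector (OpTypeInt 32 0) 4
      return 0;
    }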

>From 740bf6ece1408c7b12a0a313180d6d9bf83ae64e Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Wed, 10 Dec 2025 14:36:38 -0500
Subject: [PATCH 5/8] Update tests, and add tests for 6-element vectors

---
 .../SPIRV/legalization/vector-arithmetic-6.ll | 156 +++++++++++++
 .../SPIRV/legalization/vector-arithmetic.ll   | 215 +++++++++++++-----
 2 files changed, 310 insertions(+), 61 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic-6.ll

diff --git a/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic-6.ll b/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic-6.ll
new file mode 100644
index 0000000000000..4f889a3043efc
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic-6.ll
@@ -0,0 +1,156 @@
+; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: OpName %[[#main:]] "main"
+; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#v4f32:]] = OpTypeVector %[[#float]] 4
+; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#c6:]] = OpConstant %[[#int]] 6
+; CHECK-DAG: %[[#v6f32:]] = OpTypeArray %[[#float]] %[[#c6]]
+; CHECK-DAG: %[[#v6i32:]] = OpTypeArray %[[#int]] %[[#c6]]
+; CHECK-DAG: %[[#ptr_ssbo_v6i32:]] = OpTypePointer Private %[[#v6i32]]
+; CHECK-DAG: %[[#v4i32:]] = OpTypeVector %[[#int]] 4
+
+ at f1 = internal addrspace(10) global [4 x [6 x float] ] zeroinitializer
+ at f2 = internal addrspace(10) global [4 x [6 x float] ] zeroinitializer
+ at i1 = internal addrspace(10) global [4 x [6 x i32] ] zeroinitializer
+ at i2 = internal addrspace(10) global [4 x [6 x i32] ] zeroinitializer
+
+define void @main() local_unnamed_addr #0 {
+; CHECK: %[[#main]] = OpFunction
+entry:
+  %2 = getelementptr [4 x [6 x float] ], ptr addrspace(10) @f1, i32 0, i32 0
+  %3 = load <6 x float>, ptr addrspace(10) %2, align 4
+  %4 = getelementptr [4 x [6 x float] ], ptr addrspace(10) @f1, i32 0, i32 1
+  %5 = load <6 x float>, ptr addrspace(10) %4, align 4
+  %6 = getelementptr [4 x [6 x float] ], ptr addrspace(10) @f1, i32 0, i32 2
+  %7 = load <6 x float>, ptr addrspace(10) %6, align 4
+  %8 = getelementptr [4 x [6 x float] ], ptr addrspace(10) @f1, i32 0, i32 3
+  %9 = load <6 x float>, ptr addrspace(10) %8, align 4
+  
+  ; We expect the 6-element vectors to be widened to 8, then split into two vectors of size 4.
+  ; CHECK: %[[#Mul1:]] = OpFMul %[[#v4f32]]
+  ; CHECK: %[[#Mul2:]] = OpFMul %[[#v4f32]]
+  %10 = fmul reassoc nnan ninf nsz arcp afn <6 x float> %3, splat (float 3.000000e+00)
+
+  ; CHECK: %[[#Add1:]] = OpFAdd %[[#v4f32]] %[[#Mul1]]
+  ; CHECK: %[[#Add2:]] = OpFAdd %[[#v4f32]] %[[#Mul2]]
+  %11 = fadd reassoc nnan ninf nsz arcp afn <6 x float> %10, %5
+
+  ; CHECK: %[[#Sub1:]] = OpFSub %[[#v4f32]] %[[#Add1]]
+  ; CHECK: %[[#Sub2:]] = OpFSub %[[#v4f32]] %[[#Add2]]
+  %13 = fsub reassoc nnan ninf nsz arcp afn <6 x float> %11, %9
+
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub1]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub1]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub1]] 2
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub1]] 3
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub2]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub2]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  
+  %14 = getelementptr [4 x [6 x float] ], ptr addrspace(10) @f2, i32 0, i32 0
+  store <6 x float> %13, ptr addrspace(10) %14, align 4
+  ret void
+}
+
+; Test integer vector arithmetic operations
+define void @test_int_vector_arithmetic() local_unnamed_addr #0 {
+; CHECK: OpFunction
+entry:
+  %2 = getelementptr [4 x [6 x i32] ], ptr addrspace(10) @i1, i32 0, i32 0
+  %3 = load <6 x i32>, ptr addrspace(10) %2, align 4
+  %4 = getelementptr [4 x [6 x i32] ], ptr addrspace(10) @i1, i32 0, i32 1
+  %5 = load <6 x i32>, ptr addrspace(10) %4, align 4
+
+  ; CHECK: %[[#Add1:]] = OpIAdd %[[#v4i32]]
+  ; CHECK: %[[#Add2:]] = OpIAdd %[[#v4i32]]
+  %6 = add <6 x i32> %3, %5
+
+  ; CHECK: %[[#Sub1:]] = OpISub %[[#v4i32]] %[[#Add1]]
+  ; CHECK: %[[#Sub2:]] = OpISub %[[#v4i32]] %[[#Add2]]
+  %7 = sub <6 x i32> %6, %5
+
+  ; CHECK: %[[#Mul1:]] = OpIMul %[[#v4i32]] %[[#Sub1]]
+  ; CHECK: %[[#Mul2:]] = OpIMul %[[#v4i32]] %[[#Sub2]]
+  %8 = mul <6 x i32> %7, splat (i32 2)
+
+  ; CHECK: %[[#SDiv1:]] = OpSDiv %[[#v4i32]] %[[#Mul1]]
+  ; CHECK: %[[#SDiv2:]] = OpSDiv %[[#v4i32]] %[[#Mul2]]
+  %9 = sdiv <6 x i32> %8, splat (i32 2)
+
+  ; CHECK: %[[#UDiv1:]] = OpUDiv %[[#v4i32]] %[[#SDiv1]]
+  ; CHECK: %[[#UDiv2:]] = OpUDiv %[[#v4i32]] %[[#SDiv2]]
+  %10 = udiv <6 x i32> %9, splat (i32 1)
+
+  ; CHECK: %[[#SRem1:]] = OpSRem %[[#v4i32]] %[[#UDiv1]]
+  ; CHECK: %[[#SRem2:]] = OpSRem %[[#v4i32]] %[[#UDiv2]]
+  %11 = srem <6 x i32> %10, splat (i32 3)
+
+  ; CHECK: %[[#UMod1:]] = OpUMod %[[#v4i32]] %[[#SRem1]]
+  ; CHECK: %[[#UMod2:]] = OpUMod %[[#v4i32]] %[[#SRem2]]
+  %12 = urem <6 x i32> %11, splat (i32 3)
+
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod1]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod1]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod1]] 2
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod1]] 3
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod2]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod2]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+
+  %13 = getelementptr [4 x [6 x i32] ], ptr addrspace(10) @i2, i32 0, i32 0
+  store <6 x i32> %12, ptr addrspace(10) %13, align 4
+  ret void
+}
+
+; Test remaining float vector arithmetic operations
+define void @test_float_vector_arithmetic_continued() local_unnamed_addr #0 {
+; CHECK: OpFunction
+entry:
+  %2 = getelementptr [4 x [6 x float] ], ptr addrspace(10) @f1, i32 0, i32 0
+  %3 = load <6 x float>, ptr addrspace(10) %2, align 4
+  %4 = getelementptr [4 x [6 x float] ], ptr addrspace(10) @f1, i32 0, i32 1
+  %5 = load <6 x float>, ptr addrspace(10) %4, align 4
+
+  ; CHECK: %[[#FDiv1:]] = OpFDiv %[[#v4f32]]
+  ; CHECK: %[[#FDiv2:]] = OpFDiv %[[#v4f32]]
+  %6 = fdiv reassoc nnan ninf nsz arcp afn <6 x float> %3, splat (float 2.000000e+00)
+
+  ; CHECK: %[[#FRem1:]] = OpFRem %[[#v4f32]] %[[#FDiv1]]
+  ; CHECK: %[[#FRem2:]] = OpFRem %[[#v4f32]] %[[#FDiv2]]
+  %7 = frem reassoc nnan ninf nsz arcp afn <6 x float> %6, splat (float 3.000000e+00)
+
+  ; CHECK: %[[#Fma1:]] = OpExtInst %[[#v4f32]] {{.*}} Fma {{.*}} %[[#FDiv1]] %[[#FRem1]]
+  ; CHECK: %[[#Fma2:]] = OpExtInst %[[#v4f32]] {{.*}} Fma {{.*}} %[[#FDiv2]] %[[#FRem2]]
+  %8 = call reassoc nnan ninf nsz arcp afn <6 x float> @llvm.fma.v6f32(<6 x float> %5, <6 x float> %6, <6 x float> %7)
+
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 2
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 3
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma2]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma2]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+
+  %9 = getelementptr [4 x [6 x float] ], ptr addrspace(10) @f2, i32 0, i32 0
+  store <6 x float> %8, ptr addrspace(10) %9, align 4
+  ret void
+}
+
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
\ No newline at end of file
diff --git a/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic.ll b/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic.ll
index 126548a1e9ea2..e5476f5b15055 100644
--- a/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic.ll
+++ b/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic.ll
@@ -29,30 +29,57 @@ entry:
   %9 = load <16 x float>, ptr addrspace(10) %8, align 4
   
   ; We expect the large vectors to be split into size 4, and the operations performed on them.
-  ; CHECK: OpFMul %[[#v4f32]]
-  ; CHECK: OpFMul %[[#v4f32]]
-  ; CHECK: OpFMul %[[#v4f32]]
-  ; CHECK: OpFMul %[[#v4f32]]
+  ; CHECK: %[[#Mul1:]] = OpFMul %[[#v4f32]]
+  ; CHECK: %[[#Mul2:]] = OpFMul %[[#v4f32]]
+  ; CHECK: %[[#Mul3:]] = OpFMul %[[#v4f32]]
+  ; CHECK: %[[#Mul4:]] = OpFMul %[[#v4f32]]
   %10 = fmul reassoc nnan ninf nsz arcp afn <16 x float> %3, splat (float 3.000000e+00)
 
-  ; CHECK: OpFAdd %[[#v4f32]]
-  ; CHECK: OpFAdd %[[#v4f32]]
-  ; CHECK: OpFAdd %[[#v4f32]]
-  ; CHECK: OpFAdd %[[#v4f32]]
+  ; CHECK: %[[#Add1:]] = OpFAdd %[[#v4f32]] %[[#Mul1]]
+  ; CHECK: %[[#Add2:]] = OpFAdd %[[#v4f32]] %[[#Mul2]]
+  ; CHECK: %[[#Add3:]] = OpFAdd %[[#v4f32]] %[[#Mul3]]
+  ; CHECK: %[[#Add4:]] = OpFAdd %[[#v4f32]] %[[#Mul4]]
   %11 = fadd reassoc nnan ninf nsz arcp afn <16 x float> %10, %5
 
-  ; CHECK: OpFAdd %[[#v4f32]]
-  ; CHECK: OpFAdd %[[#v4f32]]
-  ; CHECK: OpFAdd %[[#v4f32]]
-  ; CHECK: OpFAdd %[[#v4f32]]
-  %12 = fadd reassoc nnan ninf nsz arcp afn <16 x float> %11, %7
-
-  ; CHECK: OpFSub %[[#v4f32]]
-  ; CHECK: OpFSub %[[#v4f32]]
-  ; CHECK: OpFSub %[[#v4f32]]
-  ; CHECK: OpFSub %[[#v4f32]]
-  %13 = fsub reassoc nnan ninf nsz arcp afn <16 x float> %12, %9
-
+  ; CHECK: %[[#Sub1:]] = OpFSub %[[#v4f32]] %[[#Add1]]
+  ; CHECK: %[[#Sub2:]] = OpFSub %[[#v4f32]] %[[#Add2]]
+  ; CHECK: %[[#Sub3:]] = OpFSub %[[#v4f32]] %[[#Add3]]
+  ; CHECK: %[[#Sub4:]] = OpFSub %[[#v4f32]] %[[#Add4]]
+  %13 = fsub reassoc nnan ninf nsz arcp afn <16 x float> %11, %9
+
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub1]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub1]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub1]] 2
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub1]] 3
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub2]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub2]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub2]] 2
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub2]] 3
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub3]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub3]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub3]] 2
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub3]] 3
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub4]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub4]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub4]] 2
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Sub4]] 3
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+
   %14 = getelementptr [4 x [16 x float] ], ptr addrspace(10) @f2, i32 0, i32 0
   store <16 x float> %13, ptr addrspace(10) %14, align 4
   ret void
@@ -67,48 +94,81 @@ entry:
   %4 = getelementptr [4 x [16 x i32] ], ptr addrspace(10) @i1, i32 0, i32 1
   %5 = load <16 x i32>, ptr addrspace(10) %4, align 4
 
-  ; CHECK: OpIAdd %[[#v4i32]]
-  ; CHECK: OpIAdd %[[#v4i32]]
-  ; CHECK: OpIAdd %[[#v4i32]]
-  ; CHECK: OpIAdd %[[#v4i32]]
+  ; CHECK: %[[#Add1:]] = OpIAdd %[[#v4i32]]
+  ; CHECK: %[[#Add2:]] = OpIAdd %[[#v4i32]]
+  ; CHECK: %[[#Add3:]] = OpIAdd %[[#v4i32]]
+  ; CHECK: %[[#Add4:]] = OpIAdd %[[#v4i32]]
   %6 = add <16 x i32> %3, %5
 
-  ; CHECK: OpISub %[[#v4i32]]
-  ; CHECK: OpISub %[[#v4i32]]
-  ; CHECK: OpISub %[[#v4i32]]
-  ; CHECK: OpISub %[[#v4i32]]
+  ; CHECK: %[[#Sub1:]] = OpISub %[[#v4i32]] %[[#Add1]]
+  ; CHECK: %[[#Sub2:]] = OpISub %[[#v4i32]] %[[#Add2]]
+  ; CHECK: %[[#Sub3:]] = OpISub %[[#v4i32]] %[[#Add3]]
+  ; CHECK: %[[#Sub4:]] = OpISub %[[#v4i32]] %[[#Add4]]
   %7 = sub <16 x i32> %6, %5
 
-  ; CHECK: OpIMul %[[#v4i32]]
-  ; CHECK: OpIMul %[[#v4i32]]
-  ; CHECK: OpIMul %[[#v4i32]]
-  ; CHECK: OpIMul %[[#v4i32]]
+  ; CHECK: %[[#Mul1:]] = OpIMul %[[#v4i32]] %[[#Sub1]]
+  ; CHECK: %[[#Mul2:]] = OpIMul %[[#v4i32]] %[[#Sub2]]
+  ; CHECK: %[[#Mul3:]] = OpIMul %[[#v4i32]] %[[#Sub3]]
+  ; CHECK: %[[#Mul4:]] = OpIMul %[[#v4i32]] %[[#Sub4]]
   %8 = mul <16 x i32> %7, splat (i32 2)
 
-  ; CHECK: OpSDiv %[[#v4i32]]
-  ; CHECK: OpSDiv %[[#v4i32]]
-  ; CHECK: OpSDiv %[[#v4i32]]
-  ; CHECK: OpSDiv %[[#v4i32]]
+  ; CHECK: %[[#SDiv1:]] = OpSDiv %[[#v4i32]] %[[#Mul1]]
+  ; CHECK: %[[#SDiv2:]] = OpSDiv %[[#v4i32]] %[[#Mul2]]
+  ; CHECK: %[[#SDiv3:]] = OpSDiv %[[#v4i32]] %[[#Mul3]]
+  ; CHECK: %[[#SDiv4:]] = OpSDiv %[[#v4i32]] %[[#Mul4]]
   %9 = sdiv <16 x i32> %8, splat (i32 2)
 
-  ; CHECK: OpUDiv %[[#v4i32]]
-  ; CHECK: OpUDiv %[[#v4i32]]
-  ; CHECK: OpUDiv %[[#v4i32]]
-  ; CHECK: OpUDiv %[[#v4i32]]
+  ; CHECK: %[[#UDiv1:]] = OpUDiv %[[#v4i32]] %[[#SDiv1]]
+  ; CHECK: %[[#UDiv2:]] = OpUDiv %[[#v4i32]] %[[#SDiv2]]
+  ; CHECK: %[[#UDiv3:]] = OpUDiv %[[#v4i32]] %[[#SDiv3]]
+  ; CHECK: %[[#UDiv4:]] = OpUDiv %[[#v4i32]] %[[#SDiv4]]
   %10 = udiv <16 x i32> %9, splat (i32 1)
 
-  ; CHECK: OpSRem %[[#v4i32]]
-  ; CHECK: OpSRem %[[#v4i32]]
-  ; CHECK: OpSRem %[[#v4i32]]
-  ; CHECK: OpSRem %[[#v4i32]]
+  ; CHECK: %[[#SRem1:]] = OpSRem %[[#v4i32]] %[[#UDiv1]]
+  ; CHECK: %[[#SRem2:]] = OpSRem %[[#v4i32]] %[[#UDiv2]]
+  ; CHECK: %[[#SRem3:]] = OpSRem %[[#v4i32]] %[[#UDiv3]]
+  ; CHECK: %[[#SRem4:]] = OpSRem %[[#v4i32]] %[[#UDiv4]]
   %11 = srem <16 x i32> %10, splat (i32 3)
 
-  ; CHECK: OpUMod %[[#v4i32]]
-  ; CHECK: OpUMod %[[#v4i32]]
-  ; CHECK: OpUMod %[[#v4i32]]
-  ; CHECK: OpUMod %[[#v4i32]]
+  ; CHECK: %[[#UMod1:]] = OpUMod %[[#v4i32]] %[[#SRem1]]
+  ; CHECK: %[[#UMod2:]] = OpUMod %[[#v4i32]] %[[#SRem2]]
+  ; CHECK: %[[#UMod3:]] = OpUMod %[[#v4i32]] %[[#SRem3]]
+  ; CHECK: %[[#UMod4:]] = OpUMod %[[#v4i32]] %[[#SRem4]]
   %12 = urem <16 x i32> %11, splat (i32 3)
 
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod1]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod1]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod1]] 2
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod1]] 3
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod2]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod2]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod2]] 2
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod2]] 3
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod3]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod3]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod3]] 2
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod3]] 3
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod4]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod4]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod4]] 2
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod4]] 3
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+
   %13 = getelementptr [4 x [16 x i32] ], ptr addrspace(10) @i2, i32 0, i32 0
   store <16 x i32> %12, ptr addrspace(10) %13, align 4
   ret void
@@ -123,27 +183,60 @@ entry:
   %4 = getelementptr [4 x [16 x float] ], ptr addrspace(10) @f1, i32 0, i32 1
   %5 = load <16 x float>, ptr addrspace(10) %4, align 4
 
-  ; CHECK: OpFDiv %[[#v4f32]]
-  ; CHECK: OpFDiv %[[#v4f32]]
-  ; CHECK: OpFDiv %[[#v4f32]]
-  ; CHECK: OpFDiv %[[#v4f32]]
+  ; CHECK: %[[#FDiv1:]] = OpFDiv %[[#v4f32]]
+  ; CHECK: %[[#FDiv2:]] = OpFDiv %[[#v4f32]]
+  ; CHECK: %[[#FDiv3:]] = OpFDiv %[[#v4f32]]
+  ; CHECK: %[[#FDiv4:]] = OpFDiv %[[#v4f32]]
   %6 = fdiv reassoc nnan ninf nsz arcp afn <16 x float> %3, splat (float 2.000000e+00)
 
-  ; CHECK: OpFRem %[[#v4f32]]
-  ; CHECK: OpFRem %[[#v4f32]]
-  ; CHECK: OpFRem %[[#v4f32]]
-  ; CHECK: OpFRem %[[#v4f32]]
+  ; CHECK: %[[#FRem1:]] = OpFRem %[[#v4f32]] %[[#FDiv1]]
+  ; CHECK: %[[#FRem2:]] = OpFRem %[[#v4f32]] %[[#FDiv2]]
+  ; CHECK: %[[#FRem3:]] = OpFRem %[[#v4f32]] %[[#FDiv3]]
+  ; CHECK: %[[#FRem4:]] = OpFRem %[[#v4f32]] %[[#FDiv4]]
   %7 = frem reassoc nnan ninf nsz arcp afn <16 x float> %6, splat (float 3.000000e+00)
 
-  ; CHECK: OpExtInst %[[#v4f32]] {{.*}} Fma
-  ; CHECK: OpExtInst %[[#v4f32]] {{.*}} Fma
-  ; CHECK: OpExtInst %[[#v4f32]] {{.*}} Fma
-  ; CHECK: OpExtInst %[[#v4f32]] {{.*}} Fma
+  ; CHECK: %[[#Fma1:]] = OpExtInst %[[#v4f32]] {{.*}} Fma {{.*}} %[[#FDiv1]] %[[#FRem1]]
+  ; CHECK: %[[#Fma2:]] = OpExtInst %[[#v4f32]] {{.*}} Fma {{.*}} %[[#FDiv2]] %[[#FRem2]]
+  ; CHECK: %[[#Fma3:]] = OpExtInst %[[#v4f32]] {{.*}} Fma {{.*}} %[[#FDiv3]] %[[#FRem3]]
+  ; CHECK: %[[#Fma4:]] = OpExtInst %[[#v4f32]] {{.*}} Fma {{.*}} %[[#FDiv4]] %[[#FRem4]]
   %8 = call reassoc nnan ninf nsz arcp afn <16 x float> @llvm.fma.v16f32(<16 x float> %5, <16 x float> %6, <16 x float> %7)
 
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 2
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 3
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma2]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma2]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma2]] 2
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma2]] 3
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma3]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma3]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma3]] 2
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma3]] 3
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma4]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma4]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma4]] 2
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma4]] 3
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+
   %9 = getelementptr [4 x [16 x float] ], ptr addrspace(10) @f2, i32 0, i32 0
   store <16 x float> %8, ptr addrspace(10) %9, align 4
   ret void
 }
 
-attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
\ No newline at end of file
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }

>From 64afb6c91d4e01cdc3a570ed6c9be13f1b6e0a3f Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Wed, 10 Dec 2025 14:51:37 -0500
Subject: [PATCH 6/8] Handle G_STRICT_FMA
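
With this change, G_STRICT_FMA shares the G_FMA legalization rules, so
constrained FMA calls on oversized vectors are split the same way as the
plain intrinsic. Previously it was only marked legal for float scalars
and vectors, with no rule for breaking up long vectors. As a sketch
(mirroring the test added below; value names are illustrative):

  %r = call <6 x float> @llvm.experimental.constrained.fma.v6f32(
           <6 x float> %a, <6 x float> %b, <6 x float> %c,
           metadata !"round.dynamic", metadata !"fpexcept.strict")

is widened to the next power of 2 and split, yielding two <4 x float>
Fma ext-instructions in the final module.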

---
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp  |  5 +-
 llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp  |  1 +
 .../SPIRV/legalization/vector-arithmetic-6.ll | 35 ++++++++++++
 .../SPIRV/legalization/vector-arithmetic.ll   | 57 +++++++++++++++++++
 4 files changed, 94 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index a3ad5d7fd40c0..8bcc295208f67 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -184,7 +184,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
           .custom();
   }
 
-  getActionDefinitionsBuilder(TargetOpcode::G_FMA)
+  getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA})
       .legalFor(allScalars)
       .legalFor(allowedVectorTypes)
       .moreElementsToNextPow2(0)
@@ -295,9 +295,6 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
       .legalFor(allIntScalarsAndVectors)
       .legalIf(extendedScalarsAndVectors);
 
-  getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA})
-      .legalFor(allFloatScalarsAndVectors);
-
   getActionDefinitionsBuilder(G_STRICT_FLDEXP)
       .legalForCartesianProduct(allFloatScalarsAndVectors, allIntScalars);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index 1170a7a2f3270..99edb937c3daa 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -179,6 +179,7 @@ static SPIRVType *deduceTypeFromUses(Register Reg, MachineFunction &MF,
     case TargetOpcode::G_FDIV:
     case TargetOpcode::G_FREM:
     case TargetOpcode::G_FMA:
+    case TargetOpcode::G_STRICT_FMA:
       ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB);
       break;
     case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
diff --git a/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic-6.ll b/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic-6.ll
index 4f889a3043efc..131255396834c 100644
--- a/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic-6.ll
+++ b/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic-6.ll
@@ -153,4 +153,39 @@ entry:
   ret void
 }
 
+; Test constrained fma vector arithmetic operations
+define void @test_constrained_fma_vector() local_unnamed_addr #0 {
+; CHECK: OpFunction
+entry:
+  %2 = getelementptr [4 x [6 x float] ], ptr addrspace(10) @f1, i32 0, i32 0
+  %3 = load <6 x float>, ptr addrspace(10) %2, align 4
+  %4 = getelementptr [4 x [6 x float] ], ptr addrspace(10) @f1, i32 0, i32 1
+  %5 = load <6 x float>, ptr addrspace(10) %4, align 4
+  %6 = getelementptr [4 x [6 x float] ], ptr addrspace(10) @f1, i32 0, i32 2
+  %7 = load <6 x float>, ptr addrspace(10) %6, align 4
+
+  ; CHECK: %[[#Fma1:]] = OpExtInst %[[#v4f32]] {{.*}} Fma
+  ; CHECK: %[[#Fma2:]] = OpExtInst %[[#v4f32]] {{.*}} Fma
+  %8 = call <6 x float> @llvm.experimental.constrained.fma.v6f32(<6 x float> %3, <6 x float> %5, <6 x float> %7, metadata !"round.dynamic", metadata !"fpexcept.strict")
+
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 2
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 3
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma2]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma2]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+
+  %9 = getelementptr [4 x [6 x float] ], ptr addrspace(10) @f2, i32 0, i32 0
+  store <6 x float> %8, ptr addrspace(10) %9, align 4
+  ret void
+}
+
+declare <6 x float> @llvm.experimental.constrained.fma.v6f32(<6 x float>, <6 x float>, <6 x float>, metadata, metadata)
+
 attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
\ No newline at end of file
diff --git a/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic.ll b/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic.ll
index e5476f5b15055..dd1e3d60a52bf 100644
--- a/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic.ll
+++ b/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic.ll
@@ -239,4 +239,61 @@ entry:
   ret void
 }
 
+; Test constrained fma vector arithmetic operations
+define void @test_constrained_fma_vector() local_unnamed_addr #0 {
+; CHECK: OpFunction
+entry:
+  %2 = getelementptr [4 x [16 x float] ], ptr addrspace(10) @f1, i32 0, i32 0
+  %3 = load <16 x float>, ptr addrspace(10) %2, align 4
+  %4 = getelementptr [4 x [16 x float] ], ptr addrspace(10) @f1, i32 0, i32 1
+  %5 = load <16 x float>, ptr addrspace(10) %4, align 4
+  %6 = getelementptr [4 x [16 x float] ], ptr addrspace(10) @f1, i32 0, i32 2
+  %7 = load <16 x float>, ptr addrspace(10) %6, align 4
+
+  ; CHECK: %[[#Fma1:]] = OpExtInst %[[#v4f32]] {{.*}} Fma
+  ; CHECK: %[[#Fma2:]] = OpExtInst %[[#v4f32]] {{.*}} Fma
+  ; CHECK: %[[#Fma3:]] = OpExtInst %[[#v4f32]] {{.*}} Fma
+  ; CHECK: %[[#Fma4:]] = OpExtInst %[[#v4f32]] {{.*}} Fma
+  %8 = call <16 x float> @llvm.experimental.constrained.fma.v16f32(<16 x float> %3, <16 x float> %5, <16 x float> %7, metadata !"round.dynamic", metadata !"fpexcept.strict")
+
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 2
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 3
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma2]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma2]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma2]] 2
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma2]] 3
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma3]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma3]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma3]] 2
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma3]] 3
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma4]] 0
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma4]] 1
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma4]] 2
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma4]] 3
+  ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
+
+  %9 = getelementptr [4 x [16 x float] ], ptr addrspace(10) @f2, i32 0, i32 0
+  store <16 x float> %8, ptr addrspace(10) %9, align 4
+  ret void
+}
+
+declare <16 x float> @llvm.experimental.constrained.fma.v16f32(<16 x float>, <16 x float>, <16 x float>, metadata, metadata)
+
 attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }

>From 9ef180b05e883779720bfa54e1aebf8be60af979 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Wed, 10 Dec 2025 15:26:50 -0500
Subject: [PATCH 7/8] Scalarize rem and div with an illegal vector size that
 is not a power of 2.
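
For G_UREM, G_SREM, G_UDIV, G_SDIV, and G_FREM, vectors whose element
count is not a power of 2 are now scalarized rather than widened, which
avoids dividing the undefined padding lanes that widening would
introduce. As a sketch (based on the updated vector-arithmetic-6.ll
test; value names are illustrative):

  %r = sdiv <6 x i32> %x, splat (i32 2)

now lowers to six scalar OpCompositeExtract/OpSDiv pairs instead of two
<4 x i32> OpSDiv instructions fed by a widened <8 x i32> value.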

---
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp  | 21 +++++-
 .../SPIRV/legalization/vector-arithmetic-6.ll | 69 ++++++++++++++-----
 2 files changed, 71 insertions(+), 19 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 8bcc295208f67..55f17b13fd07f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -173,7 +173,15 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   uint32_t MaxVectorSize = ST.isShader() ? 4 : 16;
 
   for (auto Opc : getTypeFoldingSupportedOpcodes()) {
-    if (Opc != G_EXTRACT_VECTOR_ELT)
+    switch (Opc) {
+    case G_EXTRACT_VECTOR_ELT:
+    case G_UREM:
+    case G_SREM:
+    case G_UDIV:
+    case G_SDIV:
+    case G_FREM:
+      break;
+    default:
       getActionDefinitionsBuilder(Opc)
           .customFor(allScalars)
           .customFor(allowedVectorTypes)
@@ -182,8 +190,19 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
                            LegalizeMutations::changeElementCountTo(
                                0, ElementCount::getFixed(MaxVectorSize)))
           .custom();
+      break;
+    }
   }
 
+  getActionDefinitionsBuilder({G_UREM, G_SREM, G_SDIV, G_UDIV, G_FREM})
+      .customFor(allScalars)
+      .customFor(allowedVectorTypes)
+      .scalarizeIf(numElementsNotPow2(0), 0)
+      .fewerElementsIf(vectorElementCountIsGreaterThan(0, MaxVectorSize),
+                       LegalizeMutations::changeElementCountTo(
+                           0, ElementCount::getFixed(MaxVectorSize)))
+      .custom();
+
   getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA})
       .legalFor(allScalars)
       .legalFor(allowedVectorTypes)
diff --git a/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic-6.ll b/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic-6.ll
index 131255396834c..d1cbfd4811c30 100644
--- a/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic-6.ll
+++ b/llvm/test/CodeGen/SPIRV/legalization/vector-arithmetic-6.ll
@@ -10,6 +10,7 @@
 ; CHECK-DAG: %[[#v6i32:]] = OpTypeArray %[[#int]] %[[#c6]]
 ; CHECK-DAG: %[[#ptr_ssbo_v6i32:]] = OpTypePointer Private %[[#v6i32]]
 ; CHECK-DAG: %[[#v4i32:]] = OpTypeVector %[[#int]] 4
+; CHECK-DAG: %[[#UNDEF:]] = OpUndef %[[#int]]
 
 @f1 = internal addrspace(10) global [4 x [6 x float] ] zeroinitializer
 @f2 = internal addrspace(10) global [4 x [6 x float] ] zeroinitializer
@@ -80,33 +81,61 @@ entry:
   ; CHECK: %[[#Mul2:]] = OpIMul %[[#v4i32]] %[[#Sub2]]
   %8 = mul <6 x i32> %7, splat (i32 2)
 
-  ; CHECK: %[[#SDiv1:]] = OpSDiv %[[#v4i32]] %[[#Mul1]]
-  ; CHECK: %[[#SDiv2:]] = OpSDiv %[[#v4i32]] %[[#Mul2]]
+  ; CHECK-DAG: %[[#E1:]] = OpCompositeExtract %[[#int]] %[[#Mul1]] 0
+  ; CHECK-DAG: %[[#E2:]] = OpCompositeExtract %[[#int]] %[[#Mul1]] 1
+  ; CHECK-DAG: %[[#E3:]] = OpCompositeExtract %[[#int]] %[[#Mul1]] 2
+  ; CHECK-DAG: %[[#E4:]] = OpCompositeExtract %[[#int]] %[[#Mul1]] 3
+  ; CHECK-DAG: %[[#E5:]] = OpCompositeExtract %[[#int]] %[[#Mul2]] 0
+  ; CHECK-DAG: %[[#E6:]] = OpCompositeExtract %[[#int]] %[[#Mul2]] 1
+  ; CHECK: %[[#SDiv1:]] = OpSDiv %[[#int]] %[[#E1]]
+  ; CHECK: %[[#SDiv2:]] = OpSDiv %[[#int]] %[[#E2]]
+  ; CHECK: %[[#SDiv3:]] = OpSDiv %[[#int]] %[[#E3]]
+  ; CHECK: %[[#SDiv4:]] = OpSDiv %[[#int]] %[[#E4]]
+  ; CHECK: %[[#SDiv5:]] = OpSDiv %[[#int]] %[[#E5]]
+  ; CHECK: %[[#SDiv6:]] = OpSDiv %[[#int]] %[[#E6]]
   %9 = sdiv <6 x i32> %8, splat (i32 2)
 
-  ; CHECK: %[[#UDiv1:]] = OpUDiv %[[#v4i32]] %[[#SDiv1]]
-  ; CHECK: %[[#UDiv2:]] = OpUDiv %[[#v4i32]] %[[#SDiv2]]
+  ; CHECK: %[[#UDiv1:]] = OpUDiv %[[#int]] %[[#SDiv1]]
+  ; CHECK: %[[#UDiv2:]] = OpUDiv %[[#int]] %[[#SDiv2]]
+  ; CHECK: %[[#UDiv3:]] = OpUDiv %[[#int]] %[[#SDiv3]]
+  ; CHECK: %[[#UDiv4:]] = OpUDiv %[[#int]] %[[#SDiv4]]
+  ; CHECK: %[[#UDiv5:]] = OpUDiv %[[#int]] %[[#SDiv5]]
+  ; CHECK: %[[#UDiv6:]] = OpUDiv %[[#int]] %[[#SDiv6]]
   %10 = udiv <6 x i32> %9, splat (i32 1)
 
-  ; CHECK: %[[#SRem1:]] = OpSRem %[[#v4i32]] %[[#UDiv1]]
-  ; CHECK: %[[#SRem2:]] = OpSRem %[[#v4i32]] %[[#UDiv2]]
+  ; CHECK: %[[#SRem1:]] = OpSRem %[[#int]] %[[#UDiv1]]
+  ; CHECK: %[[#SRem2:]] = OpSRem %[[#int]] %[[#UDiv2]]
+  ; CHECK: %[[#SRem3:]] = OpSRem %[[#int]] %[[#UDiv3]]
+  ; CHECK: %[[#SRem4:]] = OpSRem %[[#int]] %[[#UDiv4]]
+  ; CHECK: %[[#SRem5:]] = OpSRem %[[#int]] %[[#UDiv5]]
+  ; CHECK: %[[#SRem6:]] = OpSRem %[[#int]] %[[#UDiv6]]
   %11 = srem <6 x i32> %10, splat (i32 3)
 
-  ; CHECK: %[[#UMod1:]] = OpUMod %[[#v4i32]] %[[#SRem1]]
-  ; CHECK: %[[#UMod2:]] = OpUMod %[[#v4i32]] %[[#SRem2]]
+  ; CHECK: %[[#UMod1:]] = OpUMod %[[#int]] %[[#SRem1]]
+  ; CHECK: %[[#UMod2:]] = OpUMod %[[#int]] %[[#SRem2]]
+  ; CHECK: %[[#UMod3:]] = OpUMod %[[#int]] %[[#SRem3]]
+  ; CHECK: %[[#UMod4:]] = OpUMod %[[#int]] %[[#SRem4]]
+  ; CHECK: %[[#UMod5:]] = OpUMod %[[#int]] %[[#SRem5]]
+  ; CHECK: %[[#UMod6:]] = OpUMod %[[#int]] %[[#SRem6]]
   %12 = urem <6 x i32> %11, splat (i32 3)
 
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod1]] 0
+  ; CHECK: %[[#Construct1:]] = OpCompositeConstruct %[[#v4i32]] %[[#UMod1]] %[[#UMod2]] %[[#UMod3]] %[[#UMod4]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#Construct1]] 0
   ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod1]] 1
+  ; CHECK: %[[#Construct2:]] = OpCompositeConstruct %[[#v4i32]] %[[#UMod1]] %[[#UMod2]] %[[#UMod3]] %[[#UMod4]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#Construct2]] 1
   ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod1]] 2
+  ; CHECK: %[[#Construct3:]] = OpCompositeConstruct %[[#v4i32]] %[[#UMod1]] %[[#UMod2]] %[[#UMod3]] %[[#UMod4]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#Construct3]] 2
   ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod1]] 3
+  ; CHECK: %[[#Construct4:]] = OpCompositeConstruct %[[#v4i32]] %[[#UMod1]] %[[#UMod2]] %[[#UMod3]] %[[#UMod4]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#Construct4]] 3
   ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod2]] 0
+  ; CHECK: %[[#Construct5:]] = OpCompositeConstruct %[[#v4i32]] %[[#UMod5]] %[[#UMod6]] %[[#UNDEF]] %[[#UNDEF]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#Construct5]] 0
   ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
-  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#UMod2]] 1
+  ; CHECK: %[[#Construct6:]] = OpCompositeConstruct %[[#v4i32]] %[[#UMod5]] %[[#UMod6]] %[[#UNDEF]] %[[#UNDEF]]
+  ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#int]] %[[#Construct6]] 1
   ; CHECK: OpStore {{.*}} %[[#EXTRACT]]
 
   %13 = getelementptr [4 x [6 x i32] ], ptr addrspace(10) @i2, i32 0, i32 0
@@ -127,12 +156,16 @@ entry:
   ; CHECK: %[[#FDiv2:]] = OpFDiv %[[#v4f32]]
   %6 = fdiv reassoc nnan ninf nsz arcp afn <6 x float> %3, splat (float 2.000000e+00)
 
-  ; CHECK: %[[#FRem1:]] = OpFRem %[[#v4f32]] %[[#FDiv1]]
-  ; CHECK: %[[#FRem2:]] = OpFRem %[[#v4f32]] %[[#FDiv2]]
+  ; CHECK: OpFRem %[[#float]]
+  ; CHECK: OpFRem %[[#float]]
+  ; CHECK: OpFRem %[[#float]]
+  ; CHECK: OpFRem %[[#float]]
+  ; CHECK: OpFRem %[[#float]]
+  ; CHECK: OpFRem %[[#float]]
   %7 = frem reassoc nnan ninf nsz arcp afn <6 x float> %6, splat (float 3.000000e+00)
 
-  ; CHECK: %[[#Fma1:]] = OpExtInst %[[#v4f32]] {{.*}} Fma {{.*}} %[[#FDiv1]] %[[#FRem1]]
-  ; CHECK: %[[#Fma2:]] = OpExtInst %[[#v4f32]] {{.*}} Fma {{.*}} %[[#FDiv2]] %[[#FRem2]]
+  ; CHECK: %[[#Fma1:]] = OpExtInst %[[#v4f32]] {{.*}} Fma
+  ; CHECK: %[[#Fma2:]] = OpExtInst %[[#v4f32]] {{.*}} Fma
   %8 = call reassoc nnan ninf nsz arcp afn <6 x float> @llvm.fma.v6f32(<6 x float> %5, <6 x float> %6, <6 x float> %7)
 
   ; CHECK: %[[#EXTRACT:]] = OpCompositeExtract %[[#float]] %[[#Fma1]] 0

>From f9d9c4e6bb9743805ce1f279990a23aba9695092 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Tue, 2 Dec 2025 16:07:14 -0500
Subject: [PATCH 8/8] [SPIRV] Implement lowering for llvm.matrix.transpose and
 llvm.matrix.multiply

This patch implements the lowering for the llvm.matrix.transpose and
llvm.matrix.multiply intrinsics in the SPIR-V backend.

- llvm.matrix.transpose is lowered to a G_SHUFFLE_VECTOR with a
  mask calculated to transpose the elements.
- llvm.matrix.multiply is lowered by decomposing the operation into
  dot products of rows and columns:
  - Rows and columns are extracted using G_UNMERGE_VALUES or shuffles.
  - Dot products are computed using OpDot for floating-point vectors,
    OpUDot (or an expanded multiply/add sequence where unsupported) for
    integer vectors, and plain multiplies for scalar operands.
  - The result is reconstructed using G_BUILD_VECTOR.

This change also updates SPIRVPostLegalizer to improve type deduction
for G_UNMERGE_VALUES, enabling correct type assignment for the
intermediate virtual registers generated during lowering.

New tests are added to verify support for various matrix sizes and
element types (float and int).
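
As a sketch of the transpose lowering (value names are illustrative, and
the shuffle is shown in IR syntax although the combiner emits a
G_SHUFFLE_VECTOR directly), a column-major 2x3 transpose

  %t = call <6 x float> @llvm.matrix.transpose.v6f32.i32(
           <6 x float> %m, i32 2, i32 3)

becomes

  %t = shufflevector <6 x float> %m, <6 x float> %m,
                     <6 x i32> <i32 0, i32 2, i32 4, i32 1, i32 3, i32 5>

and a 2x2 float multiply decomposes into four OpDot results, one per
row/column pair, that are recombined into the result vector (see the new
matrix-multiply.ll test).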
---
 llvm/lib/Target/SPIRV/SPIRVCombine.td         |  18 +-
 llvm/lib/Target/SPIRV/SPIRVCombinerHelper.cpp | 189 ++++++++++++++++++
 llvm/lib/Target/SPIRV/SPIRVCombinerHelper.h   |  21 ++
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp  |   7 +-
 llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp  |  95 +++++----
 .../SPIRV/llvm-intrinsics/matrix-multiply.ll  | 168 ++++++++++++++++
 .../SPIRV/llvm-intrinsics/matrix-transpose.ll | 124 ++++++++++++
 7 files changed, 567 insertions(+), 55 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/llvm-intrinsics/matrix-multiply.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/llvm-intrinsics/matrix-transpose.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVCombine.td b/llvm/lib/Target/SPIRV/SPIRVCombine.td
index 991a5de1c4e83..7d69465de4ffb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCombine.td
+++ b/llvm/lib/Target/SPIRV/SPIRVCombine.td
@@ -22,8 +22,22 @@ def vector_select_to_faceforward_lowering : GICombineRule <
   (apply [{ Helper.applySPIRVFaceForward(*${root}); }])
 >;
 
+def matrix_transpose_lowering
+    : GICombineRule<(defs root:$root),
+                    (match (wip_match_opcode G_INTRINSIC):$root,
+                        [{ return Helper.matchMatrixTranspose(*${root}); }]),
+                    (apply [{ Helper.applyMatrixTranspose(*${root}); }])>;
+
+def matrix_multiply_lowering
+    : GICombineRule<(defs root:$root),
+                    (match (wip_match_opcode G_INTRINSIC):$root,
+                        [{ return Helper.matchMatrixMultiply(*${root}); }]),
+                    (apply [{ Helper.applyMatrixMultiply(*${root}); }])>;
+
 def SPIRVPreLegalizerCombiner
     : GICombiner<"SPIRVPreLegalizerCombinerImpl",
-                       [vector_length_sub_to_distance_lowering, vector_select_to_faceforward_lowering]> {
-    let CombineAllMethodName = "tryCombineAllImpl";
+                 [vector_length_sub_to_distance_lowering,
+                  vector_select_to_faceforward_lowering,
+                  matrix_transpose_lowering, matrix_multiply_lowering]> {
+  let CombineAllMethodName = "tryCombineAllImpl";
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVCombinerHelper.cpp b/llvm/lib/Target/SPIRV/SPIRVCombinerHelper.cpp
index fad2b676fee04..693b74c1e06d7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCombinerHelper.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCombinerHelper.cpp
@@ -7,9 +7,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "SPIRVCombinerHelper.h"
+#include "SPIRVGlobalRegistry.h"
+#include "SPIRVUtils.h"
 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/IR/LLVMContext.h" // Explicitly include for LLVMContext
 #include "llvm/Target/TargetMachine.h"
 
 using namespace llvm;
@@ -209,3 +213,188 @@ void SPIRVCombinerHelper::applySPIRVFaceForward(MachineInstr &MI) const {
   GR->invalidateMachineInstr(FalseInstr);
   FalseInstr->eraseFromParent();
 }
+
+bool SPIRVCombinerHelper::matchMatrixTranspose(MachineInstr &MI) const {
+  return MI.getOpcode() == TargetOpcode::G_INTRINSIC &&
+         cast<GIntrinsic>(MI).getIntrinsicID() == Intrinsic::matrix_transpose;
+}
+
+void SPIRVCombinerHelper::applyMatrixTranspose(MachineInstr &MI) const {
+  Register ResReg = MI.getOperand(0).getReg();
+  Register InReg = MI.getOperand(2).getReg();
+  uint32_t Rows = MI.getOperand(3).getImm();
+  uint32_t Cols = MI.getOperand(4).getImm();
+
+  Builder.setInstrAndDebugLoc(MI);
+
+  if (Rows == 1 && Cols == 1) {
+    Builder.buildCopy(ResReg, InReg);
+    MI.eraseFromParent();
+    return;
+  }
+
+  SmallVector<int, 16> Mask;
+  for (uint32_t K = 0; K < Rows * Cols; ++K) {
+    uint32_t R = K / Cols;
+    uint32_t C = K % Cols;
+    Mask.push_back(C * Rows + R);
+  }
+
+  Builder.buildShuffleVector(ResReg, InReg, InReg, Mask);
+  MI.eraseFromParent();
+}
+
+bool SPIRVCombinerHelper::matchMatrixMultiply(MachineInstr &MI) const {
+  return MI.getOpcode() == TargetOpcode::G_INTRINSIC &&
+         cast<GIntrinsic>(MI).getIntrinsicID() == Intrinsic::matrix_multiply;
+}
+
+SmallVector<Register, 4>
+SPIRVCombinerHelper::extractColumns(Register MatrixReg, uint32_t NumberOfCols,
+                                    SPIRVType *SpvColType,
+                                    SPIRVGlobalRegistry *GR) const {
+  // If the matrix is a single column, return that column.
+  if (NumberOfCols == 1)
+    return {MatrixReg};
+
+  SmallVector<Register, 4> Cols;
+  LLT ColTy = GR->getRegType(SpvColType);
+  for (uint32_t J = 0; J < NumberOfCols; ++J)
+    Cols.push_back(MRI.createGenericVirtualRegister(ColTy));
+  Builder.buildUnmerge(Cols, MatrixReg);
+  for (Register R : Cols) {
+    setRegClassType(R, SpvColType, GR, &MRI, Builder.getMF());
+  }
+  return Cols;
+}
+
+SmallVector<Register, 4>
+SPIRVCombinerHelper::extractRows(Register MatrixReg, uint32_t NumRows,
+                                 uint32_t NumCols, SPIRVType *SpvRowType,
+                                 SPIRVGlobalRegistry *GR) const {
+  SmallVector<Register, 4> Rows;
+  LLT VecTy = GR->getRegType(SpvRowType);
+
+  // If there is only one column, then each row is a scalar that needs
+  // to be extracted.
+  if (NumCols == 1) {
+    assert(SpvRowType->getOpcode() != SPIRV::OpTypeVector);
+    for (uint32_t I = 0; I < NumRows; ++I)
+      Rows.push_back(MRI.createGenericVirtualRegister(VecTy));
+    Builder.buildUnmerge(Rows, MatrixReg);
+    for (Register R : Rows) {
+      setRegClassType(R, SpvRowType, GR, &MRI, Builder.getMF());
+    }
+    return Rows;
+  }
+
+  // If the matrix is a single row, return that row.
+  if (NumRows == 1) {
+    return {MatrixReg};
+  }
+
+  for (uint32_t I = 0; I < NumRows; ++I) {
+    SmallVector<int, 4> Mask;
+    for (uint32_t K = 0; K < NumCols; ++K)
+      Mask.push_back(K * NumRows + I);
+    Rows.push_back(Builder.buildShuffleVector(VecTy, MatrixReg, MatrixReg, Mask)
+                       .getReg(0));
+  }
+  for (Register R : Rows) {
+    setRegClassType(R, SpvRowType, GR, &MRI, Builder.getMF());
+  }
+  return Rows;
+}
+
+Register SPIRVCombinerHelper::computeDotProduct(Register RowA, Register ColB,
+                                                SPIRVType *SpvVecType,
+                                                SPIRVGlobalRegistry *GR) const {
+  bool IsVectorOp = SpvVecType->getOpcode() == SPIRV::OpTypeVector;
+  SPIRVType *SpvScalarType = GR->getScalarOrVectorComponentType(SpvVecType);
+  bool IsFloatOp = SpvScalarType->getOpcode() == SPIRV::OpTypeFloat;
+  LLT VecTy = GR->getRegType(SpvVecType);
+
+  Register DotRes;
+  if (IsVectorOp) {
+    LLT ScalarTy = VecTy.getElementType();
+    Intrinsic::SPVIntrinsics DotIntrinsic =
+        (IsFloatOp ? Intrinsic::spv_fdot : Intrinsic::spv_udot);
+    DotRes = Builder.buildIntrinsic(DotIntrinsic, {ScalarTy})
+                 .addUse(RowA)
+                 .addUse(ColB)
+                 .getReg(0);
+  } else {
+    if (IsFloatOp)
+      DotRes = Builder.buildFMul(VecTy, RowA, ColB).getReg(0);
+    else
+      DotRes = Builder.buildMul(VecTy, RowA, ColB).getReg(0);
+  }
+  setRegClassType(DotRes, SpvScalarType, GR, &MRI, Builder.getMF());
+  return DotRes;
+}
+
+SmallVector<Register, 16>
+SPIRVCombinerHelper::computeDotProducts(const SmallVector<Register, 4> &RowsA,
+                                        const SmallVector<Register, 4> &ColsB,
+                                        SPIRVType *SpvVecType,
+                                        SPIRVGlobalRegistry *GR) const {
+  SmallVector<Register, 16> ResultScalars;
+  for (uint32_t J = 0; J < ColsB.size(); ++J) {
+    for (uint32_t I = 0; I < RowsA.size(); ++I) {
+      ResultScalars.push_back(
+          computeDotProduct(RowsA[I], ColsB[J], SpvVecType, GR));
+    }
+  }
+  return ResultScalars;
+}
+
+SPIRVType *
+SPIRVCombinerHelper::getDotProductVectorType(Register ResReg, uint32_t K,
+                                             SPIRVGlobalRegistry *GR) const {
+  // Scan the non-debug uses of ResReg to recover its scalar element type.
+  Type *ScalarResType = nullptr;
+  for (auto &UseMI : MRI.use_instructions(ResReg)) {
+    if (UseMI.getOpcode() != TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS)
+      continue;
+
+    if (!isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type))
+      continue;
+
+    Type *Ty = getMDOperandAsType(UseMI.getOperand(2).getMetadata(), 0);
+    if (Ty->isVectorTy())
+      ScalarResType = cast<VectorType>(Ty)->getElementType();
+    else
+      ScalarResType = Ty;
+    assert(ScalarResType->isIntegerTy() || ScalarResType->isFloatingPointTy());
+    break;
+  }
+  Type *VecType =
+      (K > 1 ? FixedVectorType::get(ScalarResType, K) : ScalarResType);
+  return GR->getOrCreateSPIRVType(VecType, Builder,
+                                  SPIRV::AccessQualifier::None, false);
+}
+
+void SPIRVCombinerHelper::applyMatrixMultiply(MachineInstr &MI) const {
+  Register ResReg = MI.getOperand(0).getReg();
+  Register AReg = MI.getOperand(2).getReg();
+  Register BReg = MI.getOperand(3).getReg();
+  uint32_t NumRowsA = MI.getOperand(4).getImm();
+  uint32_t NumColsA = MI.getOperand(5).getImm();
+  uint32_t NumColsB = MI.getOperand(6).getImm();
+
+  Builder.setInstrAndDebugLoc(MI);
+
+  SPIRVGlobalRegistry *GR =
+      MI.getMF()->getSubtarget<SPIRVSubtarget>().getSPIRVGlobalRegistry();
+
+  SPIRVType *SpvVecType = getDotProductVectorType(ResReg, NumColsA, GR);
+  SmallVector<Register, 4> ColsB =
+      extractColumns(BReg, NumColsB, SpvVecType, GR);
+  SmallVector<Register, 4> RowsA =
+      extractRows(AReg, NumRowsA, NumColsA, SpvVecType, GR);
+  SmallVector<Register, 16> ResultScalars =
+      computeDotProducts(RowsA, ColsB, SpvVecType, GR);
+
+  Builder.buildBuildVector(ResReg, ResultScalars);
+  MI.eraseFromParent();
+}
\ No newline at end of file
diff --git a/llvm/lib/Target/SPIRV/SPIRVCombinerHelper.h b/llvm/lib/Target/SPIRV/SPIRVCombinerHelper.h
index 3118cdc744b8f..b6b3b36f03ade 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCombinerHelper.h
+++ b/llvm/lib/Target/SPIRV/SPIRVCombinerHelper.h
@@ -33,6 +33,27 @@ class SPIRVCombinerHelper : public CombinerHelper {
   void applySPIRVDistance(MachineInstr &MI) const;
   bool matchSelectToFaceForward(MachineInstr &MI) const;
   void applySPIRVFaceForward(MachineInstr &MI) const;
+  bool matchMatrixTranspose(MachineInstr &MI) const;
+  void applyMatrixTranspose(MachineInstr &MI) const;
+  bool matchMatrixMultiply(MachineInstr &MI) const;
+  void applyMatrixMultiply(MachineInstr &MI) const;
+
+private:
+  SPIRVType *getDotProductVectorType(Register ResReg, uint32_t K,
+                                     SPIRVGlobalRegistry *GR) const;
+  SmallVector<Register, 4> extractColumns(Register BReg, uint32_t N,
+                                          SPIRVType *SpvVecType,
+                                          SPIRVGlobalRegistry *GR) const;
+  SmallVector<Register, 4> extractRows(Register AReg, uint32_t NumRows,
+                                       uint32_t NumCols, SPIRVType *SpvRowType,
+                                       SPIRVGlobalRegistry *GR) const;
+  SmallVector<Register, 16>
+  computeDotProducts(const SmallVector<Register, 4> &RowsA,
+                     const SmallVector<Register, 4> &ColsB,
+                     SPIRVType *SpvVecType, SPIRVGlobalRegistry *GR) const;
+  Register computeDotProduct(Register RowA, Register ColB,
+                             SPIRVType *SpvVecType,
+                             SPIRVGlobalRegistry *GR) const;
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 40cb7dda53ba4..a7626c3d43c34 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -173,6 +173,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   // non-shader contexts, vector sizes of 8 and 16 are also permitted, but
   // arbitrary sizes (e.g., 6 or 11) are not.
   uint32_t MaxVectorSize = ST.isShader() ? 4 : 16;
+  LLVM_DEBUG(dbgs() << "MaxVectorSize: " << MaxVectorSize << "\n");
 
   for (auto Opc : getTypeFoldingSupportedOpcodes()) {
     switch (Opc) {
@@ -221,8 +222,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
       .moreElementsToNextPow2(0)
       .lowerIf(vectorElementCountIsGreaterThan(0, MaxVectorSize))
       .moreElementsToNextPow2(1)
-      .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize))
-      .alwaysLegal();
+      .lowerIf(vectorElementCountIsGreaterThan(1, MaxVectorSize));
 
   getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
       .moreElementsToNextPow2(1)
@@ -263,8 +263,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
 
   // If the result is still illegal, the combiner should be able to remove it.
   getActionDefinitionsBuilder(G_CONCAT_VECTORS)
-      .legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes)
-      .moreElementsToNextPow2(0);
+      .legalForCartesianProduct(allowedVectorTypes, allowedVectorTypes);
 
   getActionDefinitionsBuilder(G_SPLAT_VECTOR)
       .legalFor(allowedVectorTypes)
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index 99edb937c3daa..5f52f60da37e1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -51,54 +51,6 @@ static SPIRVType *deduceIntTypeFromResult(Register ResVReg,
   return GR->getOrCreateSPIRVIntegerType(Ty.getScalarSizeInBits(), MIB);
 }
 
-static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
-                                           SPIRVGlobalRegistry *GR) {
-  MachineRegisterInfo &MRI = MF.getRegInfo();
-  Register SrcReg = I->getOperand(I->getNumOperands() - 1).getReg();
-  SPIRVType *ScalarType = nullptr;
-  if (SPIRVType *DefType = GR->getSPIRVTypeForVReg(SrcReg)) {
-    assert(DefType->getOpcode() == SPIRV::OpTypeVector);
-    ScalarType = GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
-  }
-
-  if (!ScalarType) {
-    // If we could not deduce the type from the source, try to deduce it from
-    // the uses of the results.
-    for (unsigned i = 0; i < I->getNumDefs() && !ScalarType; ++i) {
-      for (const auto &Use :
-           MRI.use_nodbg_instructions(I->getOperand(i).getReg())) {
-        if (Use.getOpcode() != TargetOpcode::G_BUILD_VECTOR)
-          continue;
-
-        if (auto *VecType =
-                GR->getSPIRVTypeForVReg(Use.getOperand(0).getReg())) {
-          ScalarType = GR->getScalarOrVectorComponentType(VecType);
-          break;
-        }
-      }
-    }
-  }
-
-  if (!ScalarType)
-    return false;
-
-  for (unsigned i = 0; i < I->getNumDefs(); ++i) {
-    Register DefReg = I->getOperand(i).getReg();
-    if (GR->getSPIRVTypeForVReg(DefReg))
-      continue;
-
-    LLT DefLLT = MRI.getType(DefReg);
-    SPIRVType *ResType =
-        DefLLT.isVector()
-            ? GR->getOrCreateSPIRVVectorType(
-                  ScalarType, DefLLT.getNumElements(), *I,
-                  *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo())
-            : ScalarType;
-    setRegClassType(DefReg, ResType, GR, &MRI, MF);
-  }
-  return true;
-}
-
 static SPIRVType *deduceTypeFromSingleOperand(MachineInstr *I,
                                               MachineIRBuilder &MIB,
                                               SPIRVGlobalRegistry *GR,
@@ -179,6 +131,7 @@ static SPIRVType *deduceTypeFromUses(Register Reg, MachineFunction &MF,
     case TargetOpcode::G_FDIV:
     case TargetOpcode::G_FREM:
     case TargetOpcode::G_FMA:
+    case TargetOpcode::COPY:
     case TargetOpcode::G_STRICT_FMA:
       ResType = deduceTypeFromResultRegister(&Use, Reg, GR, MIB);
       break;
@@ -223,6 +176,50 @@ static SPIRVType *deduceResultTypeFromOperands(MachineInstr *I,
   }
 }
 
+static bool deduceAndAssignTypeForGUnmerge(MachineInstr *I, MachineFunction &MF,
+                                           SPIRVGlobalRegistry *GR,
+                                           MachineIRBuilder &MIB) {
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  Register SrcReg = I->getOperand(I->getNumOperands() - 1).getReg();
+  SPIRVType *ScalarType = nullptr;
+  if (SPIRVType *DefType = GR->getSPIRVTypeForVReg(SrcReg)) {
+    assert(DefType->getOpcode() == SPIRV::OpTypeVector);
+    ScalarType = GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+  }
+
+  if (!ScalarType) {
+    // If we could not deduce the type from the source, try to deduce it from
+    // the uses of the results.
+    for (unsigned i = 0; i < I->getNumDefs(); ++i) {
+      Register DefReg = I->getOperand(i).getReg();
+      ScalarType = deduceTypeFromUses(DefReg, MF, GR, MIB);
+      if (ScalarType) {
+        ScalarType = GR->getScalarOrVectorComponentType(ScalarType);
+        break;
+      }
+    }
+  }
+
+  if (!ScalarType)
+    return false;
+
+  for (unsigned i = 0; i < I->getNumOperands(); ++i) {
+    Register DefReg = I->getOperand(i).getReg();
+    if (GR->getSPIRVTypeForVReg(DefReg))
+      continue;
+
+    LLT DefLLT = MRI.getType(DefReg);
+    SPIRVType *ResType =
+        DefLLT.isVector()
+            ? GR->getOrCreateSPIRVVectorType(
+                  ScalarType, DefLLT.getNumElements(), *I,
+                  *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo())
+            : ScalarType;
+    setRegClassType(DefReg, ResType, GR, &MRI, MF);
+  }
+  return true;
+}
+
 static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF,
                                      SPIRVGlobalRegistry *GR,
                                      MachineIRBuilder &MIB) {
@@ -234,7 +231,7 @@ static bool deduceAndAssignSpirvType(MachineInstr *I, MachineFunction &MF,
   // unlike the other instructions which have a single result register. The main
   // deduction logic is designed for the single-definition case.
   if (I->getOpcode() == TargetOpcode::G_UNMERGE_VALUES)
-    return deduceAndAssignTypeForGUnmerge(I, MF, GR);
+    return deduceAndAssignTypeForGUnmerge(I, MF, GR, MIB);
 
   LLVM_DEBUG(dbgs() << "Inferring type from operands\n");
   SPIRVType *ResType = deduceResultTypeFromOperands(I, GR, MIB);
diff --git a/llvm/test/CodeGen/SPIRV/llvm-intrinsics/matrix-multiply.ll b/llvm/test/CodeGen/SPIRV/llvm-intrinsics/matrix-multiply.ll
new file mode 100644
index 0000000000000..4f8dfd0494009
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/llvm-intrinsics/matrix-multiply.ll
@@ -0,0 +1,168 @@
+; RUN: llc -O0 -mtriple=spirv1.5-unknown-vulkan1.2 %s -o - | FileCheck %s --check-prefixes=CHECK,VK1_1
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.5-unknown-vulkan1.2 %s -o - -filetype=obj | spirv-val --target-env vulkan1.2 %}
+
+; RUN: llc -O0 -mtriple=spirv1.6-unknown-vulkan1.3 %s -o - | FileCheck %s --check-prefixes=CHECK,VK1_3
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.6-unknown-vulkan1.3 %s -o - -filetype=obj | spirv-val --target-env vulkan1.3 %}
+
+@private_v4f32 = internal addrspace(10) global [4 x float] poison
+@private_v4i32 = internal addrspace(10) global [4 x i32] poison
+@private_v6f32 = internal addrspace(10) global [6 x float] poison
+@private_v2f32 = internal addrspace(10) global [2 x float] poison
+@private_v1f32 = internal addrspace(10) global [1 x float] poison
+
+
+; CHECK-DAG: %[[Float_ID:[0-9]+]] = OpTypeFloat 32
+; CHECK-DAG: %[[V2F32_ID:[0-9]+]] = OpTypeVector %[[Float_ID]] 2
+; CHECK-DAG: %[[V3F32_ID:[0-9]+]] = OpTypeVector %[[Float_ID]] 3
+; CHECK-DAG: %[[V4F32_ID:[0-9]+]] = OpTypeVector %[[Float_ID]] 4
+; CHECK-DAG: %[[Int_ID:[0-9]+]] = OpTypeInt 32 0
+; CHECK-DAG: %[[V2I32_ID:[0-9]+]] = OpTypeVector %[[Int_ID]] 2
+; CHECK-DAG: %[[V4I32_ID:[0-9]+]] = OpTypeVector %[[Int_ID]] 4
+
+; Test Matrix Multiply 2x2 * 2x2 float
+; CHECK-LABEL: ; -- Begin function test_matrix_multiply_f32_2x2_2x2
+; CHECK:       %[[A:[0-9]+]] = OpCompositeInsert %[[V4F32_ID]] {{.*}} {{.*}} 3
+; CHECK:       %[[B:[0-9]+]] = OpCompositeInsert %[[V4F32_ID]] {{.*}} {{.*}} 3
+; CHECK-DAG:   %[[B_Col0:[0-9]+]] = OpVectorShuffle %[[V2F32_ID]] %[[B]] %[[#]] 0 1
+; CHECK-DAG:   %[[B_Col1:[0-9]+]] = OpVectorShuffle %[[V2F32_ID]] %[[B]] %[[#]] 2 3
+; CHECK-DAG:   %[[A_Row0:[0-9]+]] = OpVectorShuffle %[[V2F32_ID]] %[[A]] %[[A]] 0 2
+; CHECK-DAG:   %[[A_Row1:[0-9]+]] = OpVectorShuffle %[[V2F32_ID]] %[[A]] %[[A]] 1 3
+; CHECK-DAG:   %[[C00:[0-9]+]] = OpDot %[[Float_ID]] %[[A_Row0]] %[[B_Col0]]
+; CHECK-DAG:   %[[C10:[0-9]+]] = OpDot %[[Float_ID]] %[[A_Row1]] %[[B_Col0]]
+; CHECK-DAG:   %[[C01:[0-9]+]] = OpDot %[[Float_ID]] %[[A_Row0]] %[[B_Col1]]
+; CHECK-DAG:   %[[C11:[0-9]+]] = OpDot %[[Float_ID]] %[[A_Row1]] %[[B_Col1]]
+; CHECK:       OpCompositeConstruct %[[V4F32_ID]] %[[C00]] %[[C10]] %[[C01]] %[[C11]]
+define internal void @test_matrix_multiply_f32_2x2_2x2() {
+  %1 = load <4 x float>, ptr addrspace(10) @private_v4f32
+  %2 = load <4 x float>, ptr addrspace(10) @private_v4f32
+  %3 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %1, <4 x float> %2, i32 2, i32 2, i32 2)
+  store <4 x float> %3, ptr addrspace(10) @private_v4f32
+  ret void
+}
+
+; Test Matrix Multiply 2x2 * 2x2 int
+; CHECK-LABEL: ; -- Begin function test_matrix_multiply_i32_2x2_2x2
+; CHECK:       %[[A:[0-9]+]] = OpCompositeInsert %[[V4I32_ID]] {{.*}} {{.*}} 3
+; CHECK:       %[[B:[0-9]+]] = OpCompositeInsert %[[V4I32_ID]] {{.*}} {{.*}} 3
+; CHECK-DAG:   %[[B_Col0:[0-9]+]] = OpVectorShuffle %[[V2I32_ID]] %[[B]] %[[#]] 0 1
+; CHECK-DAG:   %[[B_Col1:[0-9]+]] = OpVectorShuffle %[[V2I32_ID]] %[[B]] %[[#]] 2 3
+; CHECK-DAG:   %[[A_Row0:[0-9]+]] = OpVectorShuffle %[[V2I32_ID]] %[[A]] %[[A]] 0 2
+; CHECK-DAG:   %[[A_Row1:[0-9]+]] = OpVectorShuffle %[[V2I32_ID]] %[[A]] %[[A]] 1 3
+;
+; -- C00 = dot(A_Row0, B_Col0)
+; VK1_1-DAG:   %[[Mul00:[0-9]+]] = OpIMul %[[V2I32_ID]] %[[A_Row0]] %[[B_Col0]]
+; VK1_1-DAG:   %[[E00_0:[0-9]+]] = OpCompositeExtract %[[Int_ID]] %[[Mul00]] 0
+; VK1_1-DAG:   %[[E00_1:[0-9]+]] = OpCompositeExtract %[[Int_ID]] %[[Mul00]] 1
+; VK1_1-DAG:   %[[C00:[0-9]+]] = OpIAdd %[[Int_ID]] %[[E00_0]] %[[E00_1]]
+; VK1_3-DAG:   %[[C00:[0-9]+]] = OpUDot %[[Int_ID]] %[[A_Row0]] %[[B_Col0]]
+;
+; -- C10 = dot(A_Row1, B_Col0)
+; VK1_1-DAG:   %[[Mul10:[0-9]+]] = OpIMul %[[V2I32_ID]] %[[A_Row1]] %[[B_Col0]]
+; VK1_1-DAG:   %[[E10_0:[0-9]+]] = OpCompositeExtract %[[Int_ID]] %[[Mul10]] 0
+; VK1_1-DAG:   %[[E10_1:[0-9]+]] = OpCompositeExtract %[[Int_ID]] %[[Mul10]] 1
+; VK1_1-DAG:   %[[C10:[0-9]+]] = OpIAdd %[[Int_ID]] %[[E10_0]] %[[E10_1]]
+; VK1_3-DAG:   %[[C10:[0-9]+]] = OpUDot %[[Int_ID]] %[[A_Row1]] %[[B_Col0]]
+;
+; -- C11 = dot(A_Row1, B_Col1)
+; VK1_1-DAG:   %[[Mul11:[0-9]+]] = OpIMul %[[V2I32_ID]] %[[A_Row1]] %[[B_Col1]]
+; VK1_1-DAG:   %[[E11_0:[0-9]+]] = OpCompositeExtract %[[Int_ID]] %[[Mul11]] 0
+; VK1_1-DAG:   %[[E11_1:[0-9]+]] = OpCompositeExtract %[[Int_ID]] %[[Mul11]] 1
+; VK1_1-DAG:   %[[C11:[0-9]+]] = OpIAdd %[[Int_ID]] %[[E11_0]] %[[E11_1]]
+; VK1_3-DAG:   %[[C11:[0-9]+]] = OpUDot %[[Int_ID]] %[[A_Row1]] %[[B_Col1]]
+;
+; -- C01 = dot(A_Row0, B_Col1)
+; VK1_1-DAG:   %[[Mul01:[0-9]+]] = OpIMul %[[V2I32_ID]] %[[A_Row0]] %[[B_Col1]]
+; VK1_1-DAG:   %[[E01_0:[0-9]+]] = OpCompositeExtract %[[Int_ID]] %[[Mul01]] 0
+; VK1_1-DAG:   %[[E01_1:[0-9]+]] = OpCompositeExtract %[[Int_ID]] %[[Mul01]] 1
+; VK1_1-DAG:   %[[C01:[0-9]+]] = OpIAdd %[[Int_ID]] %[[E01_0]] %[[E01_1]]
+; VK1_3-DAG:   %[[C01:[0-9]+]] = OpUDot %[[Int_ID]] %[[A_Row0]] %[[B_Col1]]
+;
+; CHECK:       OpCompositeConstruct %[[V4I32_ID]] %[[C00]] %[[C10]] %[[C01]] %[[C11]]
+define internal void @test_matrix_multiply_i32_2x2_2x2() {
+  %1 = load <4 x i32>, ptr addrspace(10) @private_v4i32
+  %2 = load <4 x i32>, ptr addrspace(10) @private_v4i32
+  %3 = call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32> %1, <4 x i32> %2, i32 2, i32 2, i32 2)
+  store <4 x i32> %3, ptr addrspace(10) @private_v4i32
+  ret void
+}
+
+; Test Matrix Multiply 2x3 * 3x2 float (Result 2x2 float)
+; CHECK-LABEL: ; -- Begin function test_matrix_multiply_f32_2x3_3x2
+; CHECK-DAG:   %[[B:[0-9]+]] = OpCompositeInsert %[[V4F32_ID]]
+; CHECK-DAG:   %[[A:[0-9]+]] = OpCompositeInsert %[[V4F32_ID]]
+;
+; CHECK-DAG:   %[[B_Col0:[0-9]+]] = OpCompositeConstruct %[[V3F32_ID]]
+; CHECK-DAG:   %[[B_Col1:[0-9]+]] = OpCompositeConstruct %[[V3F32_ID]]
+; CHECK-DAG:   %[[A_Row0:[0-9]+]] = OpCompositeConstruct %[[V3F32_ID]]
+; CHECK-DAG:   %[[A_Row1:[0-9]+]] = OpCompositeConstruct %[[V3F32_ID]]
+;
+; CHECK-DAG:   %[[C00:[0-9]+]] = OpDot %[[Float_ID]] %[[A_Row0]] %[[B_Col0]]
+; CHECK-DAG:   %[[C10:[0-9]+]] = OpDot %[[Float_ID]] %[[A_Row1]] %[[B_Col0]]
+; CHECK-DAG:   %[[C01:[0-9]+]] = OpDot %[[Float_ID]] %[[A_Row0]] %[[B_Col1]]
+; CHECK-DAG:   %[[C11:[0-9]+]] = OpDot %[[Float_ID]] %[[A_Row1]] %[[B_Col1]]
+; CHECK:       OpCompositeConstruct %[[V4F32_ID]] %[[C00]] %[[C10]] %[[C01]] %[[C11]]
+define internal void @test_matrix_multiply_f32_2x3_3x2() {
+  %1 = load <6 x float>, ptr addrspace(10) @private_v6f32
+  %2 = load <6 x float>, ptr addrspace(10) @private_v6f32
+  %3 = call <4 x float> @llvm.matrix.multiply.v4f32.v6f32.v6f32(<6 x float> %1, <6 x float> %2, i32 2, i32 3, i32 2)
+  store <4 x float> %3, ptr addrspace(10) @private_v4f32
+  ret void
+}
+
+; Test Matrix Multiply 2x2 * 2x1 float (Result 2x1 vector)
+; CHECK-LABEL: ; -- Begin function test_matrix_multiply_f32_2x2_2x1_vec
+; CHECK:       %[[A:[0-9]+]] = OpCompositeInsert %[[V4F32_ID]] {{.*}} {{.*}} 3
+; CHECK:       %[[B:[0-9]+]] = OpCompositeInsert %[[V2F32_ID]] {{.*}} {{.*}} 1
+; CHECK-DAG:   %[[A_Row0:[0-9]+]] = OpVectorShuffle %[[V2F32_ID]] %[[A]] %[[A]] 0 2
+; CHECK-DAG:   %[[A_Row1:[0-9]+]] = OpVectorShuffle %[[V2F32_ID]] %[[A]] %[[A]] 1 3
+; CHECK-DAG:   %[[C00:[0-9]+]] = OpDot %[[Float_ID]] %[[A_Row0]] %[[B]]
+; CHECK-DAG:   %[[C10:[0-9]+]] = OpDot %[[Float_ID]] %[[A_Row1]] %[[B]]
+; CHECK:       OpCompositeConstruct %[[V2F32_ID]] %[[C00]] %[[C10]]
+define internal void @test_matrix_multiply_f32_2x2_2x1_vec() {
+  %1 = load <4 x float>, ptr addrspace(10) @private_v4f32
+  %2 = load <2 x float>, ptr addrspace(10) @private_v2f32
+  %3 = call <2 x float> @llvm.matrix.multiply.v2f32.v4f32.v2f32(<4 x float> %1, <2 x float> %2, i32 2, i32 2, i32 1)
+  store <2 x float> %3, ptr addrspace(10) @private_v2f32
+  ret void
+}
+
+; Test Matrix Multiply 1x2 * 2x2 float (Result 1x2 vector)
+; CHECK-LABEL: ; -- Begin function test_matrix_multiply_f32_1x2_2x2_vec
+; CHECK:       %[[A:[0-9]+]] = OpCompositeInsert %[[V2F32_ID]] {{.*}} {{.*}} 1
+; CHECK:       %[[B:[0-9]+]] = OpCompositeInsert %[[V4F32_ID]] {{.*}} {{.*}} 3
+; CHECK-DAG:   %[[B_Col0:[0-9]+]] = OpVectorShuffle %[[V2F32_ID]] %[[B]] %[[#]] 0 1
+; CHECK-DAG:   %[[B_Col1:[0-9]+]] = OpVectorShuffle %[[V2F32_ID]] %[[B]] %[[#]] 2 3
+; CHECK-DAG:   %[[C00:[0-9]+]] = OpDot %[[Float_ID]] %[[A]] %[[B_Col0]]
+; CHECK-DAG:   %[[C01:[0-9]+]] = OpDot %[[Float_ID]] %[[A]] %[[B_Col1]]
+; CHECK:       OpCompositeConstruct %[[V2F32_ID]] %[[C00]] %[[C01]]
+define internal void @test_matrix_multiply_f32_1x2_2x2_vec() {
+  %1 = load <2 x float>, ptr addrspace(10) @private_v2f32
+  %2 = load <4 x float>, ptr addrspace(10) @private_v4f32
+  %3 = call <2 x float> @llvm.matrix.multiply.v2f32.v2f32.v4f32(<2 x float> %1, <4 x float> %2, i32 1, i32 2, i32 2)
+  store <2 x float> %3, ptr addrspace(10) @private_v2f32
+  ret void
+}
+
+; Test Matrix Multiply 1x2 * 2x1 float (Result 1x1 scalar - OpDot)
+; TODO(171175): The SPIR-V backend does not legalize single-element vectors.
+; CHECK-DISABLE: ; -- Begin function test_matrix_multiply_f32_1x2_2x1_scalar
+; define internal void @test_matrix_multiply_f32_1x2_2x1_scalar() {
+;   %1 = load <2 x float>, ptr addrspace(10) @private_v2f32
+;   %2 = load <2 x float>, ptr addrspace(10) @private_v2f32
+;   %3 = call <1 x float> @llvm.matrix.multiply.v1f32.v2f32.v2f32(<2 x float> %1, <2 x float> %2, i32 1, i32 2, i32 1)
+;   store <1 x float> %3, ptr addrspace(10) @private_v1f32
+;   ret void
+; }
+
+define void @main() #0 {
+  ret void
+}
+
+declare <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float>, <4 x float>, i32, i32, i32)
+declare <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32>, <4 x i32>, i32, i32, i32)
+declare <4 x float> @llvm.matrix.multiply.v4f32.v6f32.v6f32(<6 x float>, <6 x float>, i32, i32, i32)
+declare <2 x float> @llvm.matrix.multiply.v2f32.v4f32.v2f32(<4 x float>, <2 x float>, i32, i32, i32)
+declare <2 x float> @llvm.matrix.multiply.v2f32.v2f32.v4f32(<2 x float>, <4 x float>, i32, i32, i32)
+; declare <1 x float> @llvm.matrix.multiply.v1f32.v2f32.v2f32(<2 x float>, <2 x float>, i32, i32, i32)
+
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
diff --git a/llvm/test/CodeGen/SPIRV/llvm-intrinsics/matrix-transpose.ll b/llvm/test/CodeGen/SPIRV/llvm-intrinsics/matrix-transpose.ll
new file mode 100644
index 0000000000000..3474fecae9957
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/llvm-intrinsics/matrix-transpose.ll
@@ -0,0 +1,124 @@
+; RUN: llc -O0 -mtriple=spirv1.6-unknown-vulkan1.3 %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.6-unknown-vulkan1.3 %s -o - -filetype=obj | spirv-val --target-env vulkan1.3 %}
+
+@private_v4f32 = internal addrspace(10) global [4 x float] poison
+@private_v6f32 = internal addrspace(10) global [6 x float] poison
+@private_v1f32 = internal addrspace(10) global [1 x float] poison
+
+; CHECK-DAG: %[[Float_ID:[0-9]+]] = OpTypeFloat 32
+; CHECK-DAG: %[[V4F32_ID:[0-9]+]] = OpTypeVector %[[Float_ID]] 4
+
+; Test Transpose 2x2 float
+; CHECK-LABEL: ; -- Begin function test_transpose_f32_2x2
+; CHECK: %[[Shuffle:[0-9]+]] = OpVectorShuffle %[[V4F32_ID]] {{.*}} 0 2 1 3
+define internal void @test_transpose_f32_2x2() {
+  %1 = load <4 x float>, ptr addrspace(10) @private_v4f32
+  %2 = call <4 x float> @llvm.matrix.transpose.v4f32.i32(<4 x float> %1, i32 2, i32 2)
+  store <4 x float> %2, ptr addrspace(10) @private_v4f32
+  ret void
+}
+
+; Test Transpose 2x3 float (Result is 3x2 float)
+; Note: We should extend the prelegalizer combiner to remove the inserts and extracts.
+;       This test should then reduce to a series of access chains, loads, and stores.
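+; The extract order in the checks below (source lanes 0, 2, 4, then 1, 3, 5
+; across the two split registers) follows the same column-major transpose
+; mapping with Rows = 2, Cols = 3.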
+; CHECK-LABEL: ; -- Begin function test_transpose_f32_2x3
+define internal void @test_transpose_f32_2x3() {
+; -- Load input 2x3 matrix elements
+; CHECK: %[[AccessChain1:[0-9]+]] = OpAccessChain %[[_ptr_Float_ID:[0-9]+]] %[[private_v6f32:[0-9]+]] %[[int_0:[0-9]+]]
+; CHECK: %[[Load1:[0-9]+]] = OpLoad %[[Float_ID]] %[[AccessChain1]]
+; CHECK: %[[AccessChain2:[0-9]+]] = OpAccessChain %[[_ptr_Float_ID]] %[[private_v6f32]] %[[int_1:[0-9]+]]
+; CHECK: %[[Load2:[0-9]+]] = OpLoad %[[Float_ID]] %[[AccessChain2]]
+; CHECK: %[[AccessChain3:[0-9]+]] = OpAccessChain %[[_ptr_Float_ID]] %[[private_v6f32]] %[[int_2:[0-9]+]]
+; CHECK: %[[Load3:[0-9]+]] = OpLoad %[[Float_ID]] %[[AccessChain3]]
+; CHECK: %[[AccessChain4:[0-9]+]] = OpAccessChain %[[_ptr_Float_ID]] %[[private_v6f32]] %[[int_3:[0-9]+]]
+; CHECK: %[[Load4:[0-9]+]] = OpLoad %[[Float_ID]] %[[AccessChain4]]
+; CHECK: %[[AccessChain5:[0-9]+]] = OpAccessChain %[[_ptr_Float_ID]] %[[private_v6f32]] %[[int_4:[0-9]+]]
+; CHECK: %[[Load5:[0-9]+]] = OpLoad %[[Float_ID]] %[[AccessChain5]]
+; CHECK: %[[AccessChain6:[0-9]+]] = OpAccessChain %[[_ptr_Float_ID]] %[[private_v6f32]] %[[int_5:[0-9]+]]
+; CHECK: %[[Load6:[0-9]+]] = OpLoad %[[Float_ID]] %[[AccessChain6]]
+;
+; -- Construct intermediate vectors
+; CHECK: %[[CompositeInsert1:[0-9]+]] = OpCompositeInsert %[[V4F32_ID]] %[[Load1]] %[[undef_V4F32_ID:[0-9]+]] 0
+; CHECK: %[[CompositeInsert2:[0-9]+]] = OpCompositeInsert %[[V4F32_ID]] %[[Load2]] %[[CompositeInsert1]] 1
+; CHECK: %[[CompositeInsert3:[0-9]+]] = OpCompositeInsert %[[V4F32_ID]] %[[Load3]] %[[CompositeInsert2]] 2
+; CHECK: %[[CompositeInsert4:[0-9]+]] = OpCompositeInsert %[[V4F32_ID]] %[[Load4]] %[[CompositeInsert3]] 3
+; CHECK: %[[CompositeInsert5:[0-9]+]] = OpCompositeInsert %[[V4F32_ID]] %[[Load5]] %[[undef_V4F32_ID]] 0
+; CHECK: %[[CompositeInsert6:[0-9]+]] = OpCompositeInsert %[[V4F32_ID]] %[[Load6]] %[[CompositeInsert5]] 1
+  %1 = load <6 x float>, ptr addrspace(10) @private_v6f32
+
+; -- Extract elements for transposition
+; CHECK: %[[Extract1:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeInsert4]] 0
+; CHECK: %[[Extract2:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeInsert4]] 2
+; CHECK: %[[Extract3:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeInsert6]] 0
+; CHECK: %[[Extract4:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeInsert4]] 1
+; CHECK: %[[Extract5:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeInsert4]] 3
+; CHECK: %[[Extract6:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeInsert6]] 1
+  %2 = call <6 x float> @llvm.matrix.transpose.v6f32.i32(<6 x float> %1, i32 2, i32 3)
+
+; -- Store output 3x2 matrix elements
+; CHECK: %[[AccessChain7:[0-9]+]] = OpAccessChain %[[_ptr_Float_ID]] %[[private_v6f32]] %[[int_0]]
+; CHECK: %[[CompositeConstruct1:[0-9]+]] = OpCompositeConstruct %[[V4F32_ID]] %[[Extract1]] %[[Extract2]] %[[Extract3]] %[[Extract4]]
+; CHECK: %[[Extract7:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeConstruct1]] 0
+; CHECK: OpStore %[[AccessChain7]] %[[Extract7]]
+; CHECK: %[[AccessChain8:[0-9]+]] = OpAccessChain %[[_ptr_Float_ID]] %[[private_v6f32]] %[[int_1]]
+; CHECK: %[[CompositeConstruct2:[0-9]+]] = OpCompositeConstruct %[[V4F32_ID]] %[[Extract1]] %[[Extract2]] %[[Extract3]] %[[Extract4]]
+; CHECK: %[[Extract8:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeConstruct2]] 1
+; CHECK: OpStore %[[AccessChain8]] %[[Extract8]]
+; CHECK: %[[AccessChain9:[0-9]+]] = OpAccessChain %[[_ptr_Float_ID]] %[[private_v6f32]] %[[int_2]]
+; CHECK: %[[CompositeConstruct3:[0-9]+]] = OpCompositeConstruct %[[V4F32_ID]] %[[Extract1]] %[[Extract2]] %[[Extract3]] %[[Extract4]]
+; CHECK: %[[Extract9:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeConstruct3]] 2
+; CHECK: OpStore %[[AccessChain9]] %[[Extract9]]
+; CHECK: %[[AccessChain10:[0-9]+]] = OpAccessChain %[[_ptr_Float_ID]] %[[private_v6f32]] %[[int_3]]
+; CHECK: %[[CompositeConstruct4:[0-9]+]] = OpCompositeConstruct %[[V4F32_ID]] %[[Extract1]] %[[Extract2]] %[[Extract3]] %[[Extract4]]
+; CHECK: %[[Extract10:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeConstruct4]] 3
+; CHECK: OpStore %[[AccessChain10]] %[[Extract10]]
+; CHECK: %[[AccessChain11:[0-9]+]] = OpAccessChain %[[_ptr_Float_ID]] %[[private_v6f32]] %[[int_4]]
+; CHECK: %[[CompositeConstruct5:[0-9]+]] = OpCompositeConstruct %[[V4F32_ID]] %[[Extract5]] %[[Extract6]] %[[undef_Float_ID:[0-9]+]] %[[undef_Float_ID]]
+; CHECK: %[[Extract11:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeConstruct5]] 0
+; CHECK: OpStore %[[AccessChain11]] %[[Extract11]]
+; CHECK: %[[AccessChain12:[0-9]+]] = OpAccessChain %[[_ptr_Float_ID]] %[[private_v6f32]] %[[int_5]]
+; CHECK: %[[CompositeConstruct6:[0-9]+]] = OpCompositeConstruct %[[V4F32_ID]] %[[Extract5]] %[[Extract6]] %[[undef_Float_ID]] %[[undef_Float_ID]]
+; CHECK: %[[Extract12:[0-9]+]] = OpCompositeExtract %[[Float_ID]] %[[CompositeConstruct6]] 1
+; CHECK: OpStore %[[AccessChain12]] %[[Extract12]]
+  store <6 x float> %2, ptr addrspace(10) @private_v6f32
+  ret void
+}
+
+; Test Transpose 1x4 float (Result is 4x1 float); this should lower to a plain copy of the 4-element vector
+; CHECK-LABEL: ; -- Begin function test_transpose_f32_1x4_to_4x1
+; CHECK: %[[Shuffle:[0-9]+]] = OpVectorShuffle %[[V4F32_ID]] {{.*}} 0 1 2 3
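+; (With Rows = 1 the column-major transpose mapping is the identity
+; 0 1 2 3, i.e. a copy; the 4x1 case below is the identity for the same
+; reason.)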
+define internal void @test_transpose_f32_1x4_to_4x1() {
+  %1 = load <4 x float>, ptr addrspace(10) @private_v4f32
+  %2 = call <4 x float> @llvm.matrix.transpose.v4f32.i32(<4 x float> %1, i32 1, i32 4)
+  store <4 x float> %2, ptr addrspace(10) @private_v4f32
+  ret void
+}
+
+; Test Transpose 4x1 float (Result is 1x4 float); this should lower to a plain copy of the 4-element vector
+; CHECK-LABEL: ; -- Begin function test_transpose_f32_4x1_to_1x4
+; CHECK: %[[Shuffle:[0-9]+]] = OpVectorShuffle %[[V4F32_ID]] {{.*}} 0 1 2 3
+define internal void @test_transpose_f32_4x1_to_1x4() {
+  %1 = load <4 x float>, ptr addrspace(10) @private_v4f32
+  %2 = call <4 x float> @llvm.matrix.transpose.v4f32.i32(<4 x float> %1, i32 4, i32 1)
+  store <4 x float> %2, ptr addrspace(10) @private_v4f32
+  ret void
+}
+
+; Test Transpose 1x1 float (Result is 1x1 float); this should lower to a plain scalar copy
+; TODO(171175): The SPIR-V backend does not legalize single-element vectors.
+; define internal void @test_transpose_f32_1x1() {
+;   %1 = load <1 x float>, ptr addrspace(10) @private_v1f32
+;   %2 = call <1 x float> @llvm.matrix.transpose.v1f32.i32(<1 x float> %1, i32 1, i32 1)
+;   store <1 x float> %2, ptr addrspace(10) @private_v1f32
+;   ret void
+; }
+
+define void @main() #0 {
+  ret void
+}
+
+declare <4 x float> @llvm.matrix.transpose.v4f32.i32(<4 x float>, i32, i32)
+declare <6 x float> @llvm.matrix.transpose.v6f32.i32(<6 x float>, i32, i32)
+; declare <1 x float> @llvm.matrix.transpose.v1f32.i32(<1 x float>, i32, i32)
+
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
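
The shuffle masks and extract orders checked above all follow from one
index identity. A minimal standalone C++ sketch (illustrative only, not
part of the patch; the helper name transposeMask is made up):

  #include <vector>

  // Source lane feeding each destination lane when transposing a
  // column-major Rows x Cols matrix stored as a flat vector.
  std::vector<unsigned> transposeMask(unsigned Rows, unsigned Cols) {
    std::vector<unsigned> Mask(Rows * Cols);
    for (unsigned D = 0; D < Rows * Cols; ++D)
      Mask[D] = D / Cols + (D % Cols) * Rows;
    return Mask;
  }

transposeMask(2, 2) yields {0, 2, 1, 3} (test_transpose_f32_2x2);
transposeMask(2, 3) yields {0, 2, 4, 1, 3, 5}, the extract order in
test_transpose_f32_2x3; and transposeMask(1, 4) and transposeMask(4, 1)
are both the identity {0, 1, 2, 3}, which is why the 1x4 and 4x1 tests
reduce to a copy.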


