[llvm] [SPIRV] Implement lowering for llvm.matrix.transpose and llvm.matrix.multiply (PR #172050)

Viktoria Maximova via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 22 03:33:54 PST 2026


================
@@ -209,3 +213,188 @@ void SPIRVCombinerHelper::applySPIRVFaceForward(MachineInstr &MI) const {
   GR->invalidateMachineInstr(FalseInstr);
   FalseInstr->eraseFromParent();
 }
+
+// Returns true when MI is a G_INTRINSIC whose intrinsic ID is
+// Intrinsic::matrix_transpose.
+bool SPIRVCombinerHelper::matchMatrixTranspose(MachineInstr &MI) const {
+  if (MI.getOpcode() != TargetOpcode::G_INTRINSIC)
+    return false;
+  const auto &Intrin = cast<GIntrinsic>(MI);
+  return Intrin.getIntrinsicID() == Intrinsic::matrix_transpose;
+}
+
+// Replaces a matched matrix-transpose intrinsic with either a copy (1x1
+// input) or a single shuffle of the input with itself, then erases MI.
+// Operand 0 is the result, operand 2 the input, and operands 3 and 4 are the
+// immediate row and column counts.
+void SPIRVCombinerHelper::applyMatrixTranspose(MachineInstr &MI) const {
+  const Register DstReg = MI.getOperand(0).getReg();
+  const Register SrcReg = MI.getOperand(2).getReg();
+  const uint32_t NumRows = MI.getOperand(3).getImm();
+  const uint32_t NumCols = MI.getOperand(4).getImm();
+
+  Builder.setInstrAndDebugLoc(MI);
+
+  if (NumRows != 1 || NumCols != 1) {
+    // Result element Idx is taken from source element
+    // (Idx % NumCols) * NumRows + (Idx / NumCols).
+    SmallVector<int, 16> ShuffleMask;
+    const uint32_t NumElts = NumRows * NumCols;
+    for (uint32_t Idx = 0; Idx != NumElts; ++Idx)
+      ShuffleMask.push_back((Idx % NumCols) * NumRows + Idx / NumCols);
+    Builder.buildShuffleVector(DstReg, SrcReg, SrcReg, ShuffleMask);
+  } else {
+    // A 1x1 matrix is its own transpose.
+    Builder.buildCopy(DstReg, SrcReg);
+  }
+  MI.eraseFromParent();
+}
+
+// Returns true when MI is a G_INTRINSIC whose intrinsic ID is
+// Intrinsic::matrix_multiply.
+bool SPIRVCombinerHelper::matchMatrixMultiply(MachineInstr &MI) const {
+  if (MI.getOpcode() != TargetOpcode::G_INTRINSIC)
+    return false;
+  const auto &Intrin = cast<GIntrinsic>(MI);
+  return Intrin.getIntrinsicID() == Intrinsic::matrix_multiply;
+}
+
+// Splits MatrixReg into NumberOfCols registers of type SpvColType using a
+// single unmerge, and records SpvColType on each new register. A one-column
+// matrix is returned unchanged as its only column.
+SmallVector<Register, 4>
+SPIRVCombinerHelper::extractColumns(Register MatrixReg, uint32_t NumberOfCols,
+                                    SPIRVType *SpvColType,
+                                    SPIRVGlobalRegistry *GR) const {
+  SmallVector<Register, 4> Columns;
+
+  // If the matrix is a single column, return that single column.
+  if (NumberOfCols == 1) {
+    Columns.push_back(MatrixReg);
+    return Columns;
+  }
+
+  const LLT ColumnTy = GR->getRegType(SpvColType);
+  for (uint32_t Idx = 0; Idx != NumberOfCols; ++Idx)
+    Columns.push_back(MRI.createGenericVirtualRegister(ColumnTy));
+  Builder.buildUnmerge(Columns, MatrixReg);
+
+  // Tag the destination registers only after the unmerge has been built.
+  for (Register Column : Columns)
+    setRegClassType(Column, SpvColType, GR, &MRI, Builder.getMF());
+  return Columns;
+}
+
+// Splits MatrixReg into one register per row and records SpvRowType on each
+// newly created register. The shuffle indices (Col * NumRows + Row) read the
+// matrix as NumCols consecutive groups of NumRows elements.
+SmallVector<Register, 4>
+SPIRVCombinerHelper::extractRows(Register MatrixReg, uint32_t NumRows,
+                                 uint32_t NumCols, SPIRVType *SpvRowType,
+                                 SPIRVGlobalRegistry *GR) const {
+  const LLT RowTy = GR->getRegType(SpvRowType);
+  SmallVector<Register, 4> RowRegs;
+
+  // Records SpvRowType on every register collected in RowRegs so far.
+  auto AssignRowTypes = [&]() {
+    for (Register Row : RowRegs)
+      setRegClassType(Row, SpvRowType, GR, &MRI, Builder.getMF());
+  };
+
+  if (NumCols == 1) {
+    // With a single column every row is a scalar, so the rows are obtained
+    // by unmerging the matrix into its elements.
+    assert(SpvRowType->getOpcode() != SPIRV::OpTypeVector);
+    for (uint32_t Row = 0; Row != NumRows; ++Row)
+      RowRegs.push_back(MRI.createGenericVirtualRegister(RowTy));
+    Builder.buildUnmerge(RowRegs, MatrixReg);
+    AssignRowTypes();
+    return RowRegs;
+  }
+
+  if (NumRows == 1) {
+    // A single-row matrix is already its only row.
+    RowRegs.push_back(MatrixReg);
+    return RowRegs;
+  }
+
+  // General case: gather row Row by picking element Row of every column.
+  for (uint32_t Row = 0; Row != NumRows; ++Row) {
+    SmallVector<int, 4> RowMask;
+    for (uint32_t Col = 0; Col != NumCols; ++Col)
+      RowMask.push_back(Col * NumRows + Row);
+    auto Shuffle =
+        Builder.buildShuffleVector(RowTy, MatrixReg, MatrixReg, RowMask);
+    RowRegs.push_back(Shuffle.getReg(0));
+  }
+  AssignRowTypes();
+  return RowRegs;
+}
+
+// Builds the product of RowA and ColB and returns the result register, tagged
+// with the component type of SpvVecType. Vector operands use the spv_fdot
+// (float) or spv_udot (otherwise) intrinsic; scalar operands use a plain
+// G_FMUL / G_MUL.
+Register SPIRVCombinerHelper::computeDotProduct(Register RowA, Register ColB,
+                                                SPIRVType *SpvVecType,
+                                                SPIRVGlobalRegistry *GR) const {
+  SPIRVType *SpvElemType = GR->getScalarOrVectorComponentType(SpvVecType);
+  const bool ElemIsFloat = SpvElemType->getOpcode() == SPIRV::OpTypeFloat;
+  const LLT OperandTy = GR->getRegType(SpvVecType);
+
+  Register Product;
+  if (SpvVecType->getOpcode() != SPIRV::OpTypeVector) {
+    // Scalar operands: the "dot product" is a single multiplication.
+    Product = ElemIsFloat ? Builder.buildFMul(OperandTy, RowA, ColB).getReg(0)
+                          : Builder.buildMul(OperandTy, RowA, ColB).getReg(0);
+  } else {
+    LLT ElemTy = OperandTy.getElementType();
+    Intrinsic::SPVIntrinsics DotID =
+        (ElemIsFloat ? Intrinsic::spv_fdot : Intrinsic::spv_udot);
+    auto Dot = Builder.buildIntrinsic(DotID, {ElemTy});
+    Dot.addUse(RowA);
+    Dot.addUse(ColB);
+    Product = Dot.getReg(0);
+  }
+
+  setRegClassType(Product, SpvElemType, GR, &MRI, Builder.getMF());
+  return Product;
+}
+
+// Computes the product of every (row of A, column of B) pair. Results are
+// appended column by column: all rows for the first column of B, then all
+// rows for the second, and so on.
+SmallVector<Register, 16>
+SPIRVCombinerHelper::computeDotProducts(const SmallVector<Register, 4> &RowsA,
+                                        const SmallVector<Register, 4> &ColsB,
+                                        SPIRVType *SpvVecType,
+                                        SPIRVGlobalRegistry *GR) const {
+  SmallVector<Register, 16> Products;
+  for (Register Col : ColsB)
+    for (Register Row : RowsA)
+      Products.push_back(computeDotProduct(Row, Col, SpvVecType, GR));
+  return Products;
+}
+
+SPIRVType *
+SPIRVCombinerHelper::getDotProductVectorType(Register ResReg, uint32_t K,
+                                             SPIRVGlobalRegistry *GR) const {
+  // Loop over all non debug uses of ResReg
+  Type *ScalarResType = nullptr;
+  for (auto &UseMI : MRI.use_instructions(ResReg)) {
+    if (UseMI.getOpcode() != TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS)
+      continue;
+
+    if (!isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type))
+      continue;
+
+    Type *Ty = getMDOperandAsType(UseMI.getOperand(2).getMetadata(), 0);
+    if (Ty->isVectorTy())
+      ScalarResType = cast<VectorType>(Ty)->getElementType();
+    else
+      ScalarResType = Ty;
+    assert(ScalarResType->isIntegerTy() || ScalarResType->isFloatingPointTy());
+    break;
+  }
+  Type *VecType =
+      (K > 1 ? FixedVectorType::get(ScalarResType, K) : ScalarResType);
----------------
vmaksimo wrote:

Hi @s-perron! A static analysis tool reports a possible null pointer dereference of `ScalarResType` here.
I don't think simply inserting an `assert` before this line is the right fix. Could you please suggest the best option here? I could create a follow-up PR myself if that would be easier.

https://github.com/llvm/llvm-project/pull/172050


More information about the llvm-commits mailing list