[llvm] [SPIRV] Added support for the constrained arithmetic intrinsic (PR #157441)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 8 05:39:36 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-support
Author: Subash B (SubashBoopathi)
<details>
<summary>Changes</summary>
Added SPIR-V support for the constrained arithmetic intrinsic `llvm.experimental.constrained.fmuladd`, including legalization, instruction selection, and tests. The intrinsic is lowered to a sequence of OpFMul and OpFAdd, consistent with the SPIR-V translator.
---
Full diff: https://github.com/llvm/llvm-project/pull/157441.diff
6 Files Affected:
- (modified) llvm/include/llvm/Support/TargetOpcodes.def (+1)
- (modified) llvm/include/llvm/Target/GenericOpcodes.td (+1)
- (modified) llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp (+2)
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+48)
- (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+1-1)
- (modified) llvm/test/CodeGen/SPIRV/llvm-intrinsics/constrained-arithmetic.ll (+4)
``````````diff
diff --git a/llvm/include/llvm/Support/TargetOpcodes.def b/llvm/include/llvm/Support/TargetOpcodes.def
index b905576b61791..dcf80c56c47b3 100644
--- a/llvm/include/llvm/Support/TargetOpcodes.def
+++ b/llvm/include/llvm/Support/TargetOpcodes.def
@@ -643,6 +643,7 @@ HANDLE_TARGET_OPCODE(G_FMA)
/// Generic FP multiply and add. Behaves as separate fmul and fadd.
HANDLE_TARGET_OPCODE(G_FMAD)
+HANDLE_TARGET_OPCODE(G_STRICT_FMULADD)
/// Generic FP division.
HANDLE_TARGET_OPCODE(G_FDIV)
diff --git a/llvm/include/llvm/Target/GenericOpcodes.td b/llvm/include/llvm/Target/GenericOpcodes.td
index ce4750db88c9a..08fbd2253edf2 100644
--- a/llvm/include/llvm/Target/GenericOpcodes.td
+++ b/llvm/include/llvm/Target/GenericOpcodes.td
@@ -1716,6 +1716,7 @@ def G_STRICT_FREM : ConstrainedInstruction<G_FREM>;
def G_STRICT_FMA : ConstrainedInstruction<G_FMA>;
def G_STRICT_FSQRT : ConstrainedInstruction<G_FSQRT>;
def G_STRICT_FLDEXP : ConstrainedInstruction<G_FLDEXP>;
+def G_STRICT_FMULADD : ConstrainedInstruction<G_FMAD>;
//------------------------------------------------------------------------------
// Memory intrinsics
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index d7280eaba2440..3742df2f1732b 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2061,6 +2061,8 @@ static unsigned getConstrainedOpcode(Intrinsic::ID ID) {
return TargetOpcode::G_STRICT_FSQRT;
case Intrinsic::experimental_constrained_ldexp:
return TargetOpcode::G_STRICT_FLDEXP;
+ case Intrinsic::experimental_constrained_fmuladd:
+ return TargetOpcode::G_STRICT_FMULADD;
default:
return 0;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 6608b3f2cbefd..9d56c18f1c096 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -227,6 +227,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectExt(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
bool IsSigned) const;
+ bool selectStrictFMulAdd(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
+
bool selectTrunc(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
@@ -689,6 +692,9 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
case TargetOpcode::G_FMA:
return selectExtInst(ResVReg, ResType, I, CL::fma, GL::Fma);
+ case TargetOpcode::G_STRICT_FMULADD:
+ return selectStrictFMulAdd(ResVReg, ResType, I);
+
case TargetOpcode::G_STRICT_FLDEXP:
return selectExtInst(ResVReg, ResType, I, CL::ldexp);
@@ -1038,6 +1044,48 @@ bool SPIRVInstructionSelector::selectOpWithSrcs(Register ResVReg,
return MIB.constrainAllUses(TII, TRI, RBI);
}
+bool SPIRVInstructionSelector::selectStrictFMulAdd(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I) const {
+ assert(I.getNumOperands() == 4 &&
+ "FMulAdd should have 3 operands and result");
+ assert(I.getOperand(1).isReg() && I.getOperand(2).isReg() &&
+ I.getOperand(3).isReg() && "Operands should be registers");
+ MachineBasicBlock &BB = *I.getParent();
+ Register MulLHS = I.getOperand(1).getReg();
+ Register MulRHS = I.getOperand(2).getReg();
+ Register AddRHS = I.getOperand(3).getReg();
+ SPIRVType *MulLHSType = GR.getSPIRVTypeForVReg(MulLHS);
+ SPIRVType *MulRHSType = GR.getSPIRVTypeForVReg(MulRHS);
+ SPIRVType *AddRHSType = GR.getSPIRVTypeForVReg(AddRHS);
+ if (!MulLHSType || !MulRHSType || !AddRHSType)
+ report_fatal_error("Input Type could not be determined.");
+ if (!GR.isScalarOrVectorOfType(MulLHS, SPIRV::OpTypeFloat) ||
+ !GR.isScalarOrVectorOfType(MulRHS, SPIRV::OpTypeFloat) ||
+ !GR.isScalarOrVectorOfType(AddRHS, SPIRV::OpTypeFloat)) {
+ report_fatal_error("FMulAdd requires floating-point operands");
+ }
+ bool IsScalar = (MulLHSType->getOpcode() == SPIRV::OpTypeFloat);
+ bool IsVector = (MulLHSType->getOpcode() == SPIRV::OpTypeVector);
+ if (!IsScalar && !IsVector)
+ report_fatal_error("Unsupported type for FMulAdd operation");
+ unsigned MulOpcode = IsScalar ? SPIRV::OpFMulS : SPIRV::OpFMulV;
+ unsigned AddOpcode = IsScalar ? SPIRV::OpFAddS : SPIRV::OpFAddV;
+ Register MulTemp = MRI->createVirtualRegister(MRI->getRegClass(MulLHS));
+ BuildMI(BB, I, I.getDebugLoc(), TII.get(MulOpcode))
+ .addDef(MulTemp)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(MulLHS)
+ .addUse(MulRHS)
+ .constrainAllUses(TII, TRI, RBI);
+ return BuildMI(BB, I, I.getDebugLoc(), TII.get(AddOpcode))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(MulTemp)
+ .addUse(AddRHS)
+ .constrainAllUses(TII, TRI, RBI);
+}
+
bool SPIRVInstructionSelector::selectUnOp(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I,
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 721f64a329d31..8041ef67cfa56 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -193,7 +193,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
.legalFor(allIntScalarsAndVectors)
.legalIf(extendedScalarsAndVectors);
- getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA})
+ getActionDefinitionsBuilder({G_FMA, G_STRICT_FMA, G_STRICT_FMULADD})
.legalFor(allFloatScalarsAndVectors);
getActionDefinitionsBuilder(G_STRICT_FLDEXP)
diff --git a/llvm/test/CodeGen/SPIRV/llvm-intrinsics/constrained-arithmetic.ll b/llvm/test/CodeGen/SPIRV/llvm-intrinsics/constrained-arithmetic.ll
index 11bedfa605f9b..b2b8bbfce7321 100644
--- a/llvm/test/CodeGen/SPIRV/llvm-intrinsics/constrained-arithmetic.ll
+++ b/llvm/test/CodeGen/SPIRV/llvm-intrinsics/constrained-arithmetic.ll
@@ -7,6 +7,7 @@
; CHECK-DAG: OpName %[[#r4:]] "r4"
; CHECK-DAG: OpName %[[#r5:]] "r5"
; CHECK-DAG: OpName %[[#r6:]] "r6"
+; CHECK-DAG: OpName %[[#r7:]] "r7"
; CHECK-NOT: OpDecorate %[[#r5]] FPRoundingMode
; CHECK-NOT: OpDecorate %[[#r6]] FPRoundingMode
@@ -22,6 +23,8 @@
; CHECK: OpFMul %[[#]] %[[#]]
; CHECK: OpExtInst %[[#]] %[[#]] fma %[[#]] %[[#]] %[[#]]
; CHECK: OpFRem
+; CHECK: OpFMul %[[#]] %[[#]]
+; CHECK: OpFAdd %[[#]] %[[#]]
; Function Attrs: norecurse nounwind strictfp
define dso_local spir_kernel void @test(float %a, i32 %in, i32 %ui) {
@@ -32,6 +35,7 @@ entry:
%r4 = tail call float @llvm.experimental.constrained.fmul.f32(float %a, float %a, metadata !"round.downward", metadata !"fpexcept.strict")
%r5 = tail call float @llvm.experimental.constrained.fma.f32(float %a, float %a, float %a, metadata !"round.dynamic", metadata !"fpexcept.strict")
%r6 = tail call float @llvm.experimental.constrained.frem.f32(float %a, float %a, metadata !"round.dynamic", metadata !"fpexcept.strict")
+ %r7 = tail call float @llvm.experimental.constrained.fmuladd.f32(float %a, float %a, float %a, metadata !"round.dynamic", metadata !"fpexcept.strict")
ret void
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/157441
More information about the llvm-commits mailing list