[Mlir-commits] [mlir] f9c0cf5 - [mlir][spirv] Lower math.ctlz to OpenCL.std clz for Kernel targets (#195470)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 6 01:21:56 PDT 2026
Author: Levin Dabhi
Date: 2026-05-06T09:21:51+01:00
New Revision: f9c0cf57d3339636a9aeeb0d4746de0659636e0b
URL: https://github.com/llvm/llvm-project/commit/f9c0cf57d3339636a9aeeb0d4746de0659636e0b
DIFF: https://github.com/llvm/llvm-project/commit/f9c0cf57d3339636a9aeeb0d4746de0659636e0b.diff
LOG: [mlir][spirv] Lower math.ctlz to OpenCL.std clz for Kernel targets (#195470)
Lower `math.ctlz` to `spirv.CL.clz` for targets with Kernel capability.
Shader targets keep the existing GLSL-based fallback implemented via
`spirv.GL.FindUMsb`.
Previously, `math.ctlz` was lowered through the GLSL path using
`spirv.GL.FindUMsb` plus additional SPIR-V ops. That worked for Shader
targets, but failed legalization for OpenCL/Kernel targets where Shader
capability is not supported.
Added:
Modified:
mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index e4b5da7a5ea92..ea6be76373573 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -183,9 +183,12 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
/// Converts math.ctlz to SPIR-V ops.
///
-/// SPIR-V does not have a direct operations for counting leading zeros. If
-/// Shader capability is supported, we can leverage GL FindUMsb to calculate
-/// it.
+/// OpenCL targets lower math.ctlz directly to OpenCL.std clz via the generic
+/// elementwise pattern. This pattern handles the shader fallback.
+///
+/// The GLSL extended instruction set has no direct operation for counting
+/// leading zeros. If Shader capability is supported, we can leverage GL
+/// FindUMsb to calculate it.
struct CountLeadingZerosPattern final
: public OpConversionPattern<math::CountLeadingZerosOp> {
using Base::Base;
@@ -200,7 +203,11 @@ struct CountLeadingZerosPattern final
if (!type)
return failure();
- // We can only support 32-bit integer types for now.
+ auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ if (!typeConverter.getTargetEnv().allows(spirv::Capability::Shader))
+ return rewriter.notifyMatchFailure(countOp, "requires Shader capability");
+
+ // The GL FindUMsb fallback only supports 32-bit integer types for now.
unsigned bitwidth = 0;
if (isa<IntegerType>(type))
bitwidth = type.getIntOrFloatBitWidth();
@@ -533,35 +540,37 @@ void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
typeConverter, patterns.getContext());
// OpenCL patterns
- patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
- Log2Log10OpPattern<math::Log2Op, spirv::CLLogOp>,
- Log2Log10OpPattern<math::Log10Op, spirv::CLLogOp>,
- CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
- CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
- CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
- CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
- CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
- CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
- CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
- CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
- CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
- CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
- CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
- CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
- CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
- CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
- CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
- CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
- CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
- CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>,
- CheckedElementwiseOpPattern<math::TanOp, spirv::CLTanOp>,
- CheckedElementwiseOpPattern<math::AsinOp, spirv::CLAsinOp>,
- CheckedElementwiseOpPattern<math::AcosOp, spirv::CLAcosOp>,
- CheckedElementwiseOpPattern<math::SinhOp, spirv::CLSinhOp>,
- CheckedElementwiseOpPattern<math::CoshOp, spirv::CLCoshOp>,
- CheckedElementwiseOpPattern<math::AsinhOp, spirv::CLAsinhOp>,
- CheckedElementwiseOpPattern<math::AcoshOp, spirv::CLAcoshOp>,
- CheckedElementwiseOpPattern<math::AtanhOp, spirv::CLAtanhOp>>(
+ patterns.add<
+ Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
+ Log2Log10OpPattern<math::Log2Op, spirv::CLLogOp>,
+ Log2Log10OpPattern<math::Log10Op, spirv::CLLogOp>,
+ CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
+ CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
+ CheckedElementwiseOpPattern<math::CountLeadingZerosOp, spirv::CLClzOp>,
+ CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
+ CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
+ CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
+ CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
+ CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
+ CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
+ CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
+ CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
+ CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
+ CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
+ CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
+ CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
+ CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
+ CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
+ CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
+ CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>,
+ CheckedElementwiseOpPattern<math::TanOp, spirv::CLTanOp>,
+ CheckedElementwiseOpPattern<math::AsinOp, spirv::CLAsinOp>,
+ CheckedElementwiseOpPattern<math::AcosOp, spirv::CLAcosOp>,
+ CheckedElementwiseOpPattern<math::SinhOp, spirv::CLSinhOp>,
+ CheckedElementwiseOpPattern<math::CoshOp, spirv::CLCoshOp>,
+ CheckedElementwiseOpPattern<math::AsinhOp, spirv::CLAsinhOp>,
+ CheckedElementwiseOpPattern<math::AcoshOp, spirv::CLAcoshOp>,
+ CheckedElementwiseOpPattern<math::AtanhOp, spirv::CLAtanhOp>>(
typeConverter, patterns.getContext());
}
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
index 56a0d4dafec8c..dae1b43402718 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
@@ -159,6 +159,15 @@ func.func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>,
func.func @int_unary(%arg0: i32) {
// CHECK: spirv.CL.s_abs %{{.*}}
%0 = math.absi %arg0 : i32
+ // CHECK: spirv.CL.clz %{{.*}} : i32
+ %1 = math.ctlz %arg0 : i32
+ return
+}
+
+// CHECK-LABEL: @int_unary_vector
+func.func @int_unary_vector(%arg0: vector<2xi32>) {
+ // CHECK: spirv.CL.clz %{{.*}} : vector<2xi32>
+ %0 = math.ctlz %arg0 : vector<2xi32>
return
}
More information about the Mlir-commits mailing list