[Mlir-commits] [mlir] [mlir][SPIR-V] Lower math.{exp2, log2, log10} operations (PR #196723)
Arseniy Obolenskiy
llvmlistbot at llvm.org
Mon May 11 02:49:14 PDT 2026
https://github.com/aobolensk updated https://github.com/llvm/llvm-project/pull/196723
>From 99fa90da172e3a0dbfa177dfffae13bcd46fa160 Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Sat, 9 May 2026 15:24:50 +0200
Subject: [PATCH] [mlir][SPIR-V] Add CL.{exp2,exp10,log2,log10} and lower math.{exp2,log2,log10}
---
.../mlir/Dialect/SPIRV/IR/SPIRVCLOps.td | 84 +++++++++++++++++++
.../Conversion/MathToSPIRV/MathToSPIRV.cpp | 5 +-
.../MathToSPIRV/math-to-opencl-spirv.mlir | 20 ++---
mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir | 72 ++++++++++++++++
mlir/test/Target/SPIRV/ocl-ops.mlir | 8 ++
5 files changed, 175 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
index 71f9c9579db81..37989e6e7e54a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
@@ -386,6 +386,48 @@ def SPIRV_CLExpOp : SPIRV_CLUnaryArithmeticOp<"exp", 19, SPIRV_Float> {
// -----
+def SPIRV_CLExp2Op : SPIRV_CLUnaryArithmeticOp<"exp2", 20, SPIRV_Float> {
+ let summary = "Compute the base-2 exponential of x.";
+
+ let description = [{
+ Result Type and x must be floating-point or vector(2,3,4,8,16) of
+ floating-point values.
+
+ All of the operands, including the Result Type operand,
+ must be of the same type.
+
+ #### Example:
+
+ ```mlir
+ %2 = spirv.CL.exp2 %0 : f32
+ %3 = spirv.CL.exp2 %1 : vector<3xf16>
+ ```
+ }];
+}
+
+// -----
+
+def SPIRV_CLExp10Op : SPIRV_CLUnaryArithmeticOp<"exp10", 21, SPIRV_Float> {
+ let summary = "Compute the base-10 exponential of x.";
+
+ let description = [{
+ Result Type and x must be floating-point or vector(2,3,4,8,16) of
+ floating-point values.
+
+ All of the operands, including the Result Type operand,
+ must be of the same type.
+
+ #### Example:
+
+ ```mlir
+ %2 = spirv.CL.exp10 %0 : f32
+ %3 = spirv.CL.exp10 %1 : vector<3xf16>
+ ```
+ }];
+}
+
+// -----
+
def SPIRV_CLFAbsOp : SPIRV_CLUnaryArithmeticOp<"fabs", 23, SPIRV_Float> {
let summary = "Absolute value of operand";
@@ -568,6 +610,48 @@ def SPIRV_CLLogOp : SPIRV_CLUnaryArithmeticOp<"log", 37, SPIRV_Float> {
// -----
+def SPIRV_CLLog2Op : SPIRV_CLUnaryArithmeticOp<"log2", 38, SPIRV_Float> {
+ let summary = "Compute the base-2 logarithm of x.";
+
+ let description = [{
+ Result Type and x must be floating-point or vector(2,3,4,8,16) of
+ floating-point values.
+
+ All of the operands, including the Result Type operand, must be of the
+ same type.
+
+ #### Example:
+
+ ```mlir
+ %2 = spirv.CL.log2 %0 : f32
+ %3 = spirv.CL.log2 %1 : vector<3xf16>
+ ```
+ }];
+}
+
+// -----
+
+def SPIRV_CLLog10Op : SPIRV_CLUnaryArithmeticOp<"log10", 39, SPIRV_Float> {
+ let summary = "Compute the base-10 logarithm of x.";
+
+ let description = [{
+ Result Type and x must be floating-point or vector(2,3,4,8,16) of
+ floating-point values.
+
+ All of the operands, including the Result Type operand, must be of the
+ same type.
+
+ #### Example:
+
+ ```mlir
+ %2 = spirv.CL.log10 %0 : f32
+ %3 = spirv.CL.log10 %1 : vector<3xf16>
+ ```
+ }];
+}
+
+// -----
+
def SPIRV_CLMixOp : SPIRV_CLTernaryArithmeticOp<"mix", 99, SPIRV_Float> {
let summary = "Returns the linear blend of x & y implemented as: x + (y - x) * a";
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index ea6be76373573..46face49febd8 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -542,8 +542,6 @@ void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
// OpenCL patterns
patterns.add<
Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
- Log2Log10OpPattern<math::Log2Op, spirv::CLLogOp>,
- Log2Log10OpPattern<math::Log10Op, spirv::CLLogOp>,
CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
CheckedElementwiseOpPattern<math::CountLeadingZerosOp, spirv::CLClzOp>,
@@ -553,9 +551,12 @@ void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
+ CheckedElementwiseOpPattern<math::Exp2Op, spirv::CLExp2Op>,
CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
+ CheckedElementwiseOpPattern<math::Log2Op, spirv::CLLog2Op>,
+ CheckedElementwiseOpPattern<math::Log10Op, spirv::CLLog10Op>,
CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
index dae1b43402718..fb97edf146c14 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir
@@ -20,14 +20,12 @@ func.func @float32_unary_scalar(%arg0: f32) {
// CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
// CHECK: spirv.CL.log %[[ADDONE]]
%5 = math.log1p %arg0 : f32
- // CHECK: %[[LOG2_RECIPROCAL:.+]] = spirv.Constant 1.44269502 : f32
- // CHECK: %[[LOG0:.+]] = spirv.CL.log {{.+}}
- // CHECK: spirv.FMul %[[LOG0]], %[[LOG2_RECIPROCAL]]
+ // CHECK: spirv.CL.log2 %{{.*}}: f32
%6 = math.log2 %arg0 : f32
- // CHECK: %[[LOG10_RECIPROCAL:.+]] = spirv.Constant 0.434294492 : f32
- // CHECK: %[[LOG1:.+]] = spirv.CL.log {{.+}}
- // CHECK: spirv.FMul %[[LOG1]], %[[LOG10_RECIPROCAL]]
+ // CHECK: spirv.CL.log10 %{{.*}}: f32
%7 = math.log10 %arg0 : f32
+ // CHECK: spirv.CL.exp2 %{{.*}}: f32
+ %exp2_scalar = math.exp2 %arg0 : f32
// CHECK: spirv.CL.rint %{{.*}}: f32
%8 = math.roundeven %arg0 : f32
// CHECK: spirv.CL.rsqrt %{{.*}}: f32
@@ -85,14 +83,12 @@ func.func @float32_unary_vector(%arg0: vector<3xf32>) {
// CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
// CHECK: spirv.CL.log %[[ADDONE]]
%5 = math.log1p %arg0 : vector<3xf32>
- // CHECK: %[[LOG2_RECIPROCAL:.+]] = spirv.Constant dense<1.44269502> : vector<3xf32>
- // CHECK: %[[LOG0:.+]] = spirv.CL.log {{.+}}
- // CHECK: spirv.FMul %[[LOG0]], %[[LOG2_RECIPROCAL]]
+ // CHECK: spirv.CL.log2 %{{.*}}: vector<3xf32>
%6 = math.log2 %arg0 : vector<3xf32>
- // CHECK: %[[LOG10_RECIPROCAL:.+]] = spirv.Constant dense<0.434294492> : vector<3xf32>
- // CHECK: %[[LOG1:.+]] = spirv.CL.log {{.+}}
- // CHECK: spirv.FMul %[[LOG1]], %[[LOG10_RECIPROCAL]]
+ // CHECK: spirv.CL.log10 %{{.*}}: vector<3xf32>
%7 = math.log10 %arg0 : vector<3xf32>
+ // CHECK: spirv.CL.exp2 %{{.*}}: vector<3xf32>
+ %exp2_vec = math.exp2 %arg0 : vector<3xf32>
// CHECK: spirv.CL.rint %{{.*}}: vector<3xf32>
%8 = math.roundeven %arg0 : vector<3xf32>
// CHECK: spirv.CL.rsqrt %{{.*}}: vector<3xf32>
diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
index d6d3c53f23356..3b1ddf2494ca0 100644
--- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
@@ -50,6 +50,78 @@ func.func @exp(%arg0 : i32) -> () {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.CL.exp2
+//===----------------------------------------------------------------------===//
+
+func.func @exp2(%arg0 : f32) -> () {
+ // CHECK: spirv.CL.exp2 {{%.*}} : f32
+ %2 = spirv.CL.exp2 %arg0 : f32
+ return
+}
+
+func.func @exp2vec(%arg0 : vector<3xf16>) -> () {
+ // CHECK: spirv.CL.exp2 {{%.*}} : vector<3xf16>
+ %2 = spirv.CL.exp2 %arg0 : vector<3xf16>
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.exp10
+//===----------------------------------------------------------------------===//
+
+func.func @exp10(%arg0 : f32) -> () {
+ // CHECK: spirv.CL.exp10 {{%.*}} : f32
+ %2 = spirv.CL.exp10 %arg0 : f32
+ return
+}
+
+func.func @exp10vec(%arg0 : vector<3xf16>) -> () {
+ // CHECK: spirv.CL.exp10 {{%.*}} : vector<3xf16>
+ %2 = spirv.CL.exp10 %arg0 : vector<3xf16>
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.log2
+//===----------------------------------------------------------------------===//
+
+func.func @log2(%arg0 : f32) -> () {
+ // CHECK: spirv.CL.log2 {{%.*}} : f32
+ %2 = spirv.CL.log2 %arg0 : f32
+ return
+}
+
+func.func @log2vec(%arg0 : vector<3xf16>) -> () {
+ // CHECK: spirv.CL.log2 {{%.*}} : vector<3xf16>
+ %2 = spirv.CL.log2 %arg0 : vector<3xf16>
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.log10
+//===----------------------------------------------------------------------===//
+
+func.func @log10(%arg0 : f32) -> () {
+ // CHECK: spirv.CL.log10 {{%.*}} : f32
+ %2 = spirv.CL.log10 %arg0 : f32
+ return
+}
+
+func.func @log10vec(%arg0 : vector<3xf16>) -> () {
+ // CHECK: spirv.CL.log10 {{%.*}} : vector<3xf16>
+ %2 = spirv.CL.log10 %arg0 : vector<3xf16>
+ return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.CL.fabs
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/ocl-ops.mlir b/mlir/test/Target/SPIRV/ocl-ops.mlir
index 7a4abbd9dd344..e43223e65db5c 100644
--- a/mlir/test/Target/SPIRV/ocl-ops.mlir
+++ b/mlir/test/Target/SPIRV/ocl-ops.mlir
@@ -17,6 +17,14 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, Vec
%3 = spirv.CL.cos %arg0 : f32
// CHECK: {{%.*}} = spirv.CL.log {{%.*}} : f32
%4 = spirv.CL.log %arg0 : f32
+ // CHECK: {{%.*}} = spirv.CL.exp2 {{%.*}} : f32
+ %exp2 = spirv.CL.exp2 %arg0 : f32
+ // CHECK: {{%.*}} = spirv.CL.exp10 {{%.*}} : f32
+ %exp10 = spirv.CL.exp10 %arg0 : f32
+ // CHECK: {{%.*}} = spirv.CL.log2 {{%.*}} : f32
+ %log2 = spirv.CL.log2 %arg0 : f32
+ // CHECK: {{%.*}} = spirv.CL.log10 {{%.*}} : f32
+ %log10 = spirv.CL.log10 %arg0 : f32
// CHECK: {{%.*}} = spirv.CL.sqrt {{%.*}} : f32
%5 = spirv.CL.sqrt %arg0 : f32
// CHECK: {{%.*}} = spirv.CL.ceil {{%.*}} : f32
More information about the Mlir-commits mailing list