[Mlir-commits] [mlir] [MLIR][SPIRV] Lower SPIR-V Tan/Tanh ops to LLVM intrinsics (PR #168419)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 17 10:19:46 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

Author: Hank (hankluo6)

<details>
<summary>Changes</summary>

Fixes #148354

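Lower the SPIR-V `GL.Tan` and `GL.Tanh` ops directly to the corresponding LLVM intrinsics (`llvm.intr.tan` and `llvm.intr.tanh`). Previously, Tan was expanded to `sin(x) / cos(x)`, and Tanh was expanded to `(exp(2x) - 1) / (exp(2x) + 1)`. Using the intrinsics reduces the number of emitted instructions. It also fixes the Tanh overflow: for large `x`, `exp(2x)` overflows to infinity, so the old expansion computed `inf / inf = NaN` instead of 1.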

---
Full diff: https://github.com/llvm/llvm-project/pull/168419.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp (+4-22) 
- (modified) mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir (+2-10) 


``````````diff
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 50fca564b5b64..02b61bd989368 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1520,20 +1520,12 @@ class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
     if (!dstType)
       return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
 
-    Location loc = tanOp.getLoc();
-    Value sin = LLVM::SinOp::create(rewriter, loc, dstType, tanOp.getOperand());
-    Value cos = LLVM::CosOp::create(rewriter, loc, dstType, tanOp.getOperand());
-    rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
+    rewriter.replaceOpWithNewOp<LLVM::TanOp>(tanOp, dstType,
+                                             adaptor.getOperands());
     return success();
   }
 };
 
-/// Convert `spirv.Tanh` to
-///
-///   exp(2x) - 1
-///   -----------
-///   exp(2x) + 1
-///
 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
 public:
   using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion;
@@ -1546,18 +1538,8 @@ class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
     if (!dstType)
       return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
 
-    Location loc = tanhOp.getLoc();
-    Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
-    Value multiplied =
-        LLVM::FMulOp::create(rewriter, loc, dstType, two, tanhOp.getOperand());
-    Value exponential = LLVM::ExpOp::create(rewriter, loc, dstType, multiplied);
-    Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
-    Value numerator =
-        LLVM::FSubOp::create(rewriter, loc, dstType, exponential, one);
-    Value denominator =
-        LLVM::FAddOp::create(rewriter, loc, dstType, exponential, one);
-    rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
-                                              denominator);
+    rewriter.replaceOpWithNewOp<LLVM::TanhOp>(tanhOp, dstType,
+                                              adaptor.getOperands());
     return success();
   }
 };
diff --git a/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir
index e1936e2fd8abe..b17e1c40cb9a7 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir
@@ -162,9 +162,7 @@ spirv.func @sqrt(%arg0: f32, %arg1: vector<3xf16>) "None" {
 
 // CHECK-LABEL: @tan
 spirv.func @tan(%arg0: f32) "None" {
-  // CHECK: %[[SIN:.*]] = llvm.intr.sin(%{{.*}}) : (f32) -> f32
-  // CHECK: %[[COS:.*]] = llvm.intr.cos(%{{.*}}) : (f32) -> f32
-  // CHECK: llvm.fdiv %[[SIN]], %[[COS]] : f32
+  // CHECK: llvm.intr.tan(%{{.*}}) : (f32) -> f32
   %0 = spirv.GL.Tan %arg0 : f32
   spirv.Return
 }
@@ -175,13 +173,7 @@ spirv.func @tan(%arg0: f32) "None" {
 
 // CHECK-LABEL: @tanh
 spirv.func @tanh(%arg0: f32) "None" {
-  // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2.000000e+00 : f32) : f32
-  // CHECK: %[[X2:.*]] = llvm.fmul %[[TWO]], %{{.*}} : f32
-  // CHECK: %[[EXP:.*]] = llvm.intr.exp(%[[X2]]) : (f32) -> f32
-  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
-  // CHECK: %[[T0:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : f32
-  // CHECK: %[[T1:.*]] = llvm.fadd %[[EXP]], %[[ONE]] : f32
-  // CHECK: llvm.fdiv %[[T0]], %[[T1]] : f32
+  // CHECK: llvm.intr.tanh(%{{.*}}) : (f32) -> f32
   %0 = spirv.GL.Tanh %arg0 : f32
   spirv.Return
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/168419


More information about the Mlir-commits mailing list