[Mlir-commits] [mlir] [mlir][SPIR-V] Add OpenCL.std ldexp, pown, and rootn ops (PR #194791)

Arseniy Obolenskiy llvmlistbot at llvm.org
Tue Apr 28 22:46:53 PDT 2026


https://github.com/aobolensk created https://github.com/llvm/llvm-project/pull/194791

Add operations that follow the `float op(float, int)` pattern, mirroring the existing `spirv.GL.Ldexp` op.

>From 0feabd5a8f0a48619059f7bacd4df1f5d0ee16ae Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Wed, 29 Apr 2026 07:44:41 +0200
Subject: [PATCH] [mlir][SPIR-V] Add OpenCL.std ldexp, pown, and rootn ops

Add operations that follow `float op(float, int)` pattern, mirroring the existing spirv.GL.Ldexp op
---
 .../mlir/Dialect/SPIRV/IR/SPIRVCLOps.td       | 120 ++++++++++++++++++
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        |  46 +++++--
 mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir       | 115 +++++++++++++++++
 mlir/test/Target/SPIRV/ocl-ops.mlir           |  20 +++
 4 files changed, 293 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
index 5d086325fa5b1..5f77129b81fa7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
@@ -505,6 +505,48 @@ def SPIRV_CLFmaOp : SPIRV_CLTernaryArithmeticOp<"fma", 26, SPIRV_Float> {
 
 // -----
 
+def SPIRV_CLLdexpOp : SPIRV_CLOp<"ldexp", 34, [Pure]> {
+  let summary = "Builds y such that y = x * 2^exp.";
+
+  let description = [{
+    Builds a floating-point number from x and the corresponding
+    integral exponent of two in exp:
+
+    x * 2^exp
+
+    The operand x must be a scalar or vector whose component type is
+    floating-point.
+
+    The exp operand must be a scalar or vector with integer component type.
+    The number of components in x and exp must be the same.
+
+    Result Type must be the same type as the type of x. Results are computed
+    per component.
+
+    #### Example:
+
+    ```mlir
+    %y = spirv.CL.ldexp %x : f32, %exp : i32 -> f32
+    %y = spirv.CL.ldexp %x : vector<3xf32>, %exp : vector<3xi32> -> vector<3xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_ScalarOrVectorOf<SPIRV_Float>:$x,
+    SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$exp
+  );
+
+  let results = (outs
+    SPIRV_ScalarOrVectorOf<SPIRV_Float>:$y
+  );
+
+  let assemblyFormat = [{
+    attr-dict $x `:` type($x) `,` $exp `:` type($exp) `->` type($y)
+  }];
+}
+
+// -----
+
 def SPIRV_CLLogOp : SPIRV_CLUnaryArithmeticOp<"log", 37, SPIRV_Float> {
   let summary = "Compute the natural logarithm of x.";
 
@@ -571,6 +613,45 @@ def SPIRV_CLPowOp : SPIRV_CLBinaryArithmeticOp<"pow", 48, SPIRV_Float> {
 
 // -----
 
+def SPIRV_CLPownOp : SPIRV_CLOp<"pown", 49, [Pure]> {
+  let summary = "Compute x to the power y, where y is an integer.";
+
+  let description = [{
+    Result is x raised to the power y, where y is an integer.
+
+    The operand x must be a scalar or vector whose component type is
+    floating-point.
+
+    The y operand must be a scalar or vector with integer component type.
+    The number of components in x and y must be the same.
+
+    Result Type must be the same type as the type of x. Results are computed
+    per component.
+
+    #### Example:
+
+    ```mlir
+    %2 = spirv.CL.pown %0 : f32, %1 : i32 -> f32
+    %2 = spirv.CL.pown %0 : vector<3xf32>, %1 : vector<3xi32> -> vector<3xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_ScalarOrVectorOf<SPIRV_Float>:$x,
+    SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$y
+  );
+
+  let results = (outs
+    SPIRV_ScalarOrVectorOf<SPIRV_Float>:$result
+  );
+
+  let assemblyFormat = [{
+    attr-dict $x `:` type($x) `,` $y `:` type($y) `->` type($result)
+  }];
+}
+
+// -----
+
 def SPIRV_CLRintOp : SPIRV_CLUnaryArithmeticOp<"rint", 53, SPIRV_Float> {
   let summary = [{
     Round x to integral value (using round to nearest even rounding mode) in
@@ -595,6 +676,45 @@ def SPIRV_CLRintOp : SPIRV_CLUnaryArithmeticOp<"rint", 53, SPIRV_Float> {
 
 // -----
 
+def SPIRV_CLRootnOp : SPIRV_CLOp<"rootn", 54, [Pure]> {
+  let summary = "Compute the n-th root of x, where n is an integer.";
+
+  let description = [{
+    Result is the n-th root of x, where n is an integer.
+
+    The operand x must be a scalar or vector whose component type is
+    floating-point.
+
+    The n operand must be a scalar or vector with integer component type.
+    The number of components in x and n must be the same.
+
+    Result Type must be the same type as the type of x. Results are computed
+    per component.
+
+    #### Example:
+
+    ```mlir
+    %2 = spirv.CL.rootn %0 : f32, %1 : i32 -> f32
+    %2 = spirv.CL.rootn %0 : vector<3xf32>, %1 : vector<3xi32> -> vector<3xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_ScalarOrVectorOf<SPIRV_Float>:$x,
+    SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$n
+  );
+
+  let results = (outs
+    SPIRV_ScalarOrVectorOf<SPIRV_Float>:$result
+  );
+
+  let assemblyFormat = [{
+    attr-dict $x `:` type($x) `,` $n `:` type($n) `->` type($result)
+  }];
+}
+
+// -----
+
 def SPIRV_CLRoundOp : SPIRV_CLUnaryArithmeticOp<"round", 55, SPIRV_Float> {
   let summary = [{
     Return the integral value nearest to x rounding halfway cases away from
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 49a14b5b30f0f..f8aa5bfeba452 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2049,12 +2049,10 @@ LogicalResult spirv::GLFrexpStructOp::verify() {
 // spirv.GL.Ldexp
 //===----------------------------------------------------------------------===//
 
-LogicalResult spirv::GLLdexpOp::verify() {
-  Type significandType = getX().getType();
-  Type exponentType = getExp().getType();
-
-  if (isa<FloatType>(significandType) != isa<IntegerType>(exponentType))
-    return emitOpError("operands must both be scalars or vectors");
+static LogicalResult verifyFloatIntegerBuiltin(Operation *op, Type floatType,
+                                               Type integerType) {
+  if (isa<FloatType>(floatType) != isa<IntegerType>(integerType))
+    return op->emitOpError("operands must both be scalars or vectors");
 
   auto getNumElements = [](Type type) -> unsigned {
     if (auto vectorType = dyn_cast<VectorType>(type))
@@ -2062,12 +2060,44 @@ LogicalResult spirv::GLLdexpOp::verify() {
     return 1;
   };
 
-  if (getNumElements(significandType) != getNumElements(exponentType))
-    return emitOpError("operands must have the same number of elements");
+  if (getNumElements(floatType) != getNumElements(integerType))
+    return op->emitOpError("operands must have the same number of elements");
 
   return success();
 }
 
+LogicalResult spirv::GLLdexpOp::verify() {
+  return verifyFloatIntegerBuiltin(getOperation(), getX().getType(),
+                                   getExp().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.ldexp
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::CLLdexpOp::verify() {
+  return verifyFloatIntegerBuiltin(getOperation(), getX().getType(),
+                                   getExp().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.pown
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::CLPownOp::verify() {
+  return verifyFloatIntegerBuiltin(getOperation(), getX().getType(),
+                                   getY().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.rootn
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::CLRootnOp::verify() {
+  return verifyFloatIntegerBuiltin(getOperation(), getX().getType(),
+                                   getN().getType());
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.ShiftLeftLogicalOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
index 6aaaa6012fefe..138dd485dc711 100644
--- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
@@ -440,3 +440,118 @@ func.func @atan2(%arg0 : vector<4xf16>, %arg1 : vector<4xf16>) -> () {
   return
 }
 
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.ldexp
+//===----------------------------------------------------------------------===//
+
+func.func @ldexp(%arg0 : f32, %arg1 : i32) -> () {
+  // CHECK: {{%.*}} = spirv.CL.ldexp {{%.*}} : f32, {{%.*}} : i32 -> f32
+  %0 = spirv.CL.ldexp %arg0 : f32, %arg1 : i32 -> f32
+  return
+}
+
+// -----
+
+func.func @ldexp_vec(%arg0 : vector<3xf32>, %arg1 : vector<3xi32>) -> () {
+  // CHECK: {{%.*}} = spirv.CL.ldexp {{%.*}} : vector<3xf32>, {{%.*}} : vector<3xi32> -> vector<3xf32>
+  %0 = spirv.CL.ldexp %arg0 : vector<3xf32>, %arg1 : vector<3xi32> -> vector<3xf32>
+  return
+}
+
+// -----
+
+func.func @ldexp_wrong_type_scalar(%arg0 : f32, %arg1 : vector<2xi32>) -> () {
+  // expected-error @+1 {{operands must both be scalars or vectors}}
+  %0 = spirv.CL.ldexp %arg0 : f32, %arg1 : vector<2xi32> -> f32
+  return
+}
+
+// -----
+
+func.func @ldexp_wrong_type_vec_1(%arg0 : vector<3xf32>, %arg1 : i32) -> () {
+  // expected-error @+1 {{operands must both be scalars or vectors}}
+  %0 = spirv.CL.ldexp %arg0 : vector<3xf32>, %arg1 : i32 -> vector<3xf32>
+  return
+}
+
+// -----
+
+func.func @ldexp_wrong_type_vec_2(%arg0 : vector<3xf32>, %arg1 : vector<2xi32>) -> () {
+  // expected-error @+1 {{operands must have the same number of elements}}
+  %0 = spirv.CL.ldexp %arg0 : vector<3xf32>, %arg1 : vector<2xi32> -> vector<3xf32>
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.pown
+//===----------------------------------------------------------------------===//
+
+func.func @pown(%arg0 : f32, %arg1 : i32) -> () {
+  // CHECK: {{%.*}} = spirv.CL.pown {{%.*}} : f32, {{%.*}} : i32 -> f32
+  %0 = spirv.CL.pown %arg0 : f32, %arg1 : i32 -> f32
+  return
+}
+
+// -----
+
+func.func @pown_vec(%arg0 : vector<3xf32>, %arg1 : vector<3xi32>) -> () {
+  // CHECK: {{%.*}} = spirv.CL.pown {{%.*}} : vector<3xf32>, {{%.*}} : vector<3xi32> -> vector<3xf32>
+  %0 = spirv.CL.pown %arg0 : vector<3xf32>, %arg1 : vector<3xi32> -> vector<3xf32>
+  return
+}
+
+// -----
+
+func.func @pown_wrong_type_scalar(%arg0 : f32, %arg1 : vector<2xi32>) -> () {
+  // expected-error @+1 {{operands must both be scalars or vectors}}
+  %0 = spirv.CL.pown %arg0 : f32, %arg1 : vector<2xi32> -> f32
+  return
+}
+
+// -----
+
+func.func @pown_wrong_type_vec(%arg0 : vector<3xf32>, %arg1 : vector<2xi32>) -> () {
+  // expected-error @+1 {{operands must have the same number of elements}}
+  %0 = spirv.CL.pown %arg0 : vector<3xf32>, %arg1 : vector<2xi32> -> vector<3xf32>
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.rootn
+//===----------------------------------------------------------------------===//
+
+func.func @rootn(%arg0 : f32, %arg1 : i32) -> () {
+  // CHECK: {{%.*}} = spirv.CL.rootn {{%.*}} : f32, {{%.*}} : i32 -> f32
+  %0 = spirv.CL.rootn %arg0 : f32, %arg1 : i32 -> f32
+  return
+}
+
+// -----
+
+func.func @rootn_vec(%arg0 : vector<3xf32>, %arg1 : vector<3xi32>) -> () {
+  // CHECK: {{%.*}} = spirv.CL.rootn {{%.*}} : vector<3xf32>, {{%.*}} : vector<3xi32> -> vector<3xf32>
+  %0 = spirv.CL.rootn %arg0 : vector<3xf32>, %arg1 : vector<3xi32> -> vector<3xf32>
+  return
+}
+
+// -----
+
+func.func @rootn_wrong_type_scalar(%arg0 : f32, %arg1 : vector<2xi32>) -> () {
+  // expected-error @+1 {{operands must both be scalars or vectors}}
+  %0 = spirv.CL.rootn %arg0 : f32, %arg1 : vector<2xi32> -> f32
+  return
+}
+
+// -----
+
+func.func @rootn_wrong_type_vec(%arg0 : vector<3xf32>, %arg1 : vector<2xi32>) -> () {
+  // expected-error @+1 {{operands must have the same number of elements}}
+  %0 = spirv.CL.rootn %arg0 : vector<3xf32>, %arg1 : vector<2xi32> -> vector<3xf32>
+  return
+}
diff --git a/mlir/test/Target/SPIRV/ocl-ops.mlir b/mlir/test/Target/SPIRV/ocl-ops.mlir
index 17accd93e8249..74c451720be5e 100644
--- a/mlir/test/Target/SPIRV/ocl-ops.mlir
+++ b/mlir/test/Target/SPIRV/ocl-ops.mlir
@@ -50,6 +50,26 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, Vec
     spirv.Return
   }
 
+  spirv.func @float_int_insts(%arg0 : f32, %arg1 : i32) "None" {
+    // CHECK: {{%.*}} = spirv.CL.ldexp {{%.*}} : f32, {{%.*}} : i32 -> f32
+    %0 = spirv.CL.ldexp %arg0 : f32, %arg1 : i32 -> f32
+    // CHECK: {{%.*}} = spirv.CL.pown {{%.*}} : f32, {{%.*}} : i32 -> f32
+    %1 = spirv.CL.pown %arg0 : f32, %arg1 : i32 -> f32
+    // CHECK: {{%.*}} = spirv.CL.rootn {{%.*}} : f32, {{%.*}} : i32 -> f32
+    %2 = spirv.CL.rootn %arg0 : f32, %arg1 : i32 -> f32
+    spirv.Return
+  }
+
+  spirv.func @float_int_vec_insts(%arg0 : vector<3xf32>, %arg1 : vector<3xi32>) "None" {
+    // CHECK: {{%.*}} = spirv.CL.ldexp {{%.*}} : vector<3xf32>, {{%.*}} : vector<3xi32> -> vector<3xf32>
+    %0 = spirv.CL.ldexp %arg0 : vector<3xf32>, %arg1 : vector<3xi32> -> vector<3xf32>
+    // CHECK: {{%.*}} = spirv.CL.pown {{%.*}} : vector<3xf32>, {{%.*}} : vector<3xi32> -> vector<3xf32>
+    %1 = spirv.CL.pown %arg0 : vector<3xf32>, %arg1 : vector<3xi32> -> vector<3xf32>
+    // CHECK: {{%.*}} = spirv.CL.rootn {{%.*}} : vector<3xf32>, {{%.*}} : vector<3xi32> -> vector<3xf32>
+    %2 = spirv.CL.rootn %arg0 : vector<3xf32>, %arg1 : vector<3xi32> -> vector<3xf32>
+    spirv.Return
+  }
+
   spirv.func @maxmin(%arg0 : f32, %arg1 : f32, %arg2 : i32, %arg3 : i32) "None" {
     // CHECK: {{%.*}} = spirv.CL.fmax {{%.*}}, {{%.*}} : f32
     %1 = spirv.CL.fmax %arg0, %arg1 : f32



More information about the Mlir-commits mailing list