[Mlir-commits] [mlir] [mlir][SPIR-V] Add OpenCL.std ldexp, pown, and rootn ops (PR #194791)
Arseniy Obolenskiy
llvmlistbot at llvm.org
Tue Apr 28 22:46:53 PDT 2026
https://github.com/aobolensk created https://github.com/llvm/llvm-project/pull/194791
Add the OpenCL.std `ldexp`, `pown`, and `rootn` operations, which follow the `float op(float, int)` pattern, mirroring the existing `spirv.GL.Ldexp` op.
>From 0feabd5a8f0a48619059f7bacd4df1f5d0ee16ae Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Wed, 29 Apr 2026 07:44:41 +0200
Subject: [PATCH] [mlir][SPIR-V] Add OpenCL.std ldexp, pown, and rootn ops
Add the OpenCL.std ldexp, pown, and rootn operations, which follow the `float op(float, int)` pattern, mirroring the existing spirv.GL.Ldexp op.
---
.../mlir/Dialect/SPIRV/IR/SPIRVCLOps.td | 120 ++++++++++++++++++
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 46 +++++--
mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir | 115 +++++++++++++++++
mlir/test/Target/SPIRV/ocl-ops.mlir | 20 +++
4 files changed, 293 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
index 5d086325fa5b1..5f77129b81fa7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
@@ -505,6 +505,48 @@ def SPIRV_CLFmaOp : SPIRV_CLTernaryArithmeticOp<"fma", 26, SPIRV_Float> {
// -----
+def SPIRV_CLLdexpOp :
+  SPIRV_CLOp<"ldexp", 34, [Pure, AllTypesMatch<["x", "y"]>]> {
+  let summary = "Builds y such that y = x * 2^exp.";
+
+  let description = [{
+    Builds a floating-point number from x and the corresponding
+    integral exponent of two in exp:
+
+    x * 2^exp
+
+    The operand x must be a scalar or vector whose component type is
+    floating-point.
+
+    The exp operand must be a scalar or vector with integer component type.
+    The number of components in x and exp must be the same.
+
+    Result Type must be the same type as the type of x. Results are computed
+    per component.
+
+    #### Example:
+
+    ```mlir
+    %y = spirv.CL.ldexp %x : f32, %exp : i32 -> f32
+    %y = spirv.CL.ldexp %x : vector<3xf32>, %exp : vector<3xi32> -> vector<3xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_ScalarOrVectorOf<SPIRV_Float>:$x,
+    SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$exp
+  );
+
+  let results = (outs
+    SPIRV_ScalarOrVectorOf<SPIRV_Float>:$y
+  );
+
+  let assemblyFormat = [{
+    attr-dict $x `:` type($x) `,` $exp `:` type($exp) `->` type($y)
+  }];
+}
+
+// -----
+
def SPIRV_CLLogOp : SPIRV_CLUnaryArithmeticOp<"log", 37, SPIRV_Float> {
let summary = "Compute the natural logarithm of x.";
@@ -571,6 +613,45 @@ def SPIRV_CLPowOp : SPIRV_CLBinaryArithmeticOp<"pow", 48, SPIRV_Float> {
// -----
+def SPIRV_CLPownOp :
+  SPIRV_CLOp<"pown", 49, [Pure, AllTypesMatch<["x", "result"]>]> {
+  let summary = "Compute x to the power y, where y is an integer.";
+
+  let description = [{
+    Result is x raised to the power y, where y is an integer.
+
+    The operand x must be a scalar or vector whose component type is
+    floating-point.
+
+    The y operand must be a scalar or vector with integer component type.
+    The number of components in x and y must be the same.
+
+    Result Type must be the same type as the type of x. Results are computed
+    per component.
+
+    #### Example:
+
+    ```mlir
+    %2 = spirv.CL.pown %0 : f32, %1 : i32 -> f32
+    %2 = spirv.CL.pown %0 : vector<3xf32>, %1 : vector<3xi32> -> vector<3xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_ScalarOrVectorOf<SPIRV_Float>:$x,
+    SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$y
+  );
+
+  let results = (outs
+    SPIRV_ScalarOrVectorOf<SPIRV_Float>:$result
+  );
+
+  let assemblyFormat = [{
+    attr-dict $x `:` type($x) `,` $y `:` type($y) `->` type($result)
+  }];
+}
+
+// -----
+
def SPIRV_CLRintOp : SPIRV_CLUnaryArithmeticOp<"rint", 53, SPIRV_Float> {
let summary = [{
Round x to integral value (using round to nearest even rounding mode) in
@@ -595,6 +676,45 @@ def SPIRV_CLRintOp : SPIRV_CLUnaryArithmeticOp<"rint", 53, SPIRV_Float> {
// -----
+def SPIRV_CLRootnOp :
+  SPIRV_CLOp<"rootn", 54, [Pure, AllTypesMatch<["x", "result"]>]> {
+  let summary = "Compute the n-th root of x, where n is an integer.";
+
+  let description = [{
+    Result is the n-th root of x, i.e. x raised to the power 1/n, where n
+    is an integer.
+
+    The operand x must be a scalar or vector whose component type is
+    floating-point.
+
+    The n operand must be a scalar or vector with integer component type.
+    The number of components in x and n must be the same.
+
+    Result Type must be the same type as the type of x. Results are computed
+    per component.
+
+    #### Example:
+
+    ```mlir
+    %2 = spirv.CL.rootn %0 : f32, %1 : i32 -> f32
+    %2 = spirv.CL.rootn %0 : vector<3xf32>, %1 : vector<3xi32> -> vector<3xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_ScalarOrVectorOf<SPIRV_Float>:$x,
+    SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$n
+  );
+
+  let results = (outs
+    SPIRV_ScalarOrVectorOf<SPIRV_Float>:$result
+  );
+
+  let assemblyFormat = [{
+    attr-dict $x `:` type($x) `,` $n `:` type($n) `->` type($result)
+  }];
+}
+
+// -----
+
def SPIRV_CLRoundOp : SPIRV_CLUnaryArithmeticOp<"round", 55, SPIRV_Float> {
let summary = [{
Return the integral value nearest to x rounding halfway cases away from
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 49a14b5b30f0f..f8aa5bfeba452 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2049,12 +2049,10 @@ LogicalResult spirv::GLFrexpStructOp::verify() {
// spirv.GL.Ldexp
//===----------------------------------------------------------------------===//
-LogicalResult spirv::GLLdexpOp::verify() {
- Type significandType = getX().getType();
- Type exponentType = getExp().getType();
-
- if (isa<FloatType>(significandType) != isa<IntegerType>(exponentType))
- return emitOpError("operands must both be scalars or vectors");
+static LogicalResult verifyFloatIntegerBuiltin(Operation *op, Type floatType,
+ Type integerType) {
+ if (isa<FloatType>(floatType) != isa<IntegerType>(integerType))
+ return op->emitOpError("operands must both be scalars or vectors");
auto getNumElements = [](Type type) -> unsigned {
if (auto vectorType = dyn_cast<VectorType>(type))
@@ -2062,12 +2060,44 @@ LogicalResult spirv::GLLdexpOp::verify() {
return 1;
};
- if (getNumElements(significandType) != getNumElements(exponentType))
- return emitOpError("operands must have the same number of elements");
+ if (getNumElements(floatType) != getNumElements(integerType))
+ return op->emitOpError("operands must have the same number of elements");
return success();
}
+LogicalResult spirv::GLLdexpOp::verify() {
+  // x (float) and exp (integer) must agree in scalar-vs-vector shape and in
+  // component count.
+  Operation *op = getOperation();
+  return verifyFloatIntegerBuiltin(op, getX().getType(), getExp().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.ldexp
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::CLLdexpOp::verify() {
+  // Same operand shape rules as spirv.GL.Ldexp: x and exp are both scalars,
+  // or both vectors with the same number of components.
+  Type xType = getX().getType();
+  Type expType = getExp().getType();
+  return verifyFloatIntegerBuiltin(getOperation(), xType, expType);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.pown
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::CLPownOp::verify() {
+  // The float base and the integer power must agree in shape.
+  Value base = getX();
+  Value power = getY();
+  return verifyFloatIntegerBuiltin(getOperation(), base.getType(),
+                                   power.getType());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.rootn
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::CLRootnOp::verify() {
+  // The float radicand and the integer root degree must agree in shape.
+  Operation *op = getOperation();
+  Type radicandType = getX().getType();
+  Type degreeType = getN().getType();
+  return verifyFloatIntegerBuiltin(op, radicandType, degreeType);
+}
+
//===----------------------------------------------------------------------===//
// spirv.ShiftLeftLogicalOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
index 6aaaa6012fefe..138dd485dc711 100644
--- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
@@ -440,3 +440,118 @@ func.func @atan2(%arg0 : vector<4xf16>, %arg1 : vector<4xf16>) -> () {
return
}
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.ldexp
+//===----------------------------------------------------------------------===//
+
+func.func @ldexp(%x : f32, %e : i32) -> () {
+  // CHECK: {{%.*}} = spirv.CL.ldexp {{%.*}} : f32, {{%.*}} : i32 -> f32
+  %r = spirv.CL.ldexp %x : f32, %e : i32 -> f32
+  return
+}
+
+// -----
+
+func.func @ldexp_vec(%x : vector<3xf32>, %e : vector<3xi32>) -> () {
+  // CHECK: {{%.*}} = spirv.CL.ldexp {{%.*}} : vector<3xf32>, {{%.*}} : vector<3xi32> -> vector<3xf32>
+  %r = spirv.CL.ldexp %x : vector<3xf32>, %e : vector<3xi32> -> vector<3xf32>
+  return
+}
+
+// -----
+
+func.func @ldexp_wrong_type_scalar(%x : f32, %e : vector<2xi32>) -> () {
+  // expected-error @+1 {{operands must both be scalars or vectors}}
+  %r = spirv.CL.ldexp %x : f32, %e : vector<2xi32> -> f32
+  return
+}
+
+// -----
+
+func.func @ldexp_wrong_type_vec_1(%x : vector<3xf32>, %e : i32) -> () {
+  // expected-error @+1 {{operands must both be scalars or vectors}}
+  %r = spirv.CL.ldexp %x : vector<3xf32>, %e : i32 -> vector<3xf32>
+  return
+}
+
+// -----
+
+func.func @ldexp_wrong_type_vec_2(%x : vector<3xf32>, %e : vector<2xi32>) -> () {
+  // expected-error @+1 {{operands must have the same number of elements}}
+  %r = spirv.CL.ldexp %x : vector<3xf32>, %e : vector<2xi32> -> vector<3xf32>
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.pown
+//===----------------------------------------------------------------------===//
+
+func.func @pown(%arg0 : f32, %arg1 : i32) -> () {
+  // CHECK: {{%.*}} = spirv.CL.pown {{%.*}} : f32, {{%.*}} : i32 -> f32
+  %0 = spirv.CL.pown %arg0 : f32, %arg1 : i32 -> f32
+  return
+}
+
+// -----
+
+func.func @pown_vec(%arg0 : vector<3xf32>, %arg1 : vector<3xi32>) -> () {
+  // CHECK: {{%.*}} = spirv.CL.pown {{%.*}} : vector<3xf32>, {{%.*}} : vector<3xi32> -> vector<3xf32>
+  %0 = spirv.CL.pown %arg0 : vector<3xf32>, %arg1 : vector<3xi32> -> vector<3xf32>
+  return
+}
+
+// -----
+
+func.func @pown_wrong_type_scalar(%arg0 : f32, %arg1 : vector<2xi32>) -> () {
+  // expected-error @+1 {{operands must both be scalars or vectors}}
+  %0 = spirv.CL.pown %arg0 : f32, %arg1 : vector<2xi32> -> f32
+  return
+}
+
+// -----
+
+// Vector float base paired with a scalar integer power.
+func.func @pown_wrong_type_vec_1(%arg0 : vector<3xf32>, %arg1 : i32) -> () {
+  // expected-error @+1 {{operands must both be scalars or vectors}}
+  %0 = spirv.CL.pown %arg0 : vector<3xf32>, %arg1 : i32 -> vector<3xf32>
+  return
+}
+
+// -----
+
+func.func @pown_wrong_type_vec_2(%arg0 : vector<3xf32>, %arg1 : vector<2xi32>) -> () {
+  // expected-error @+1 {{operands must have the same number of elements}}
+  %0 = spirv.CL.pown %arg0 : vector<3xf32>, %arg1 : vector<2xi32> -> vector<3xf32>
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.rootn
+//===----------------------------------------------------------------------===//
+
+func.func @rootn(%arg0 : f32, %arg1 : i32) -> () {
+  // CHECK: {{%.*}} = spirv.CL.rootn {{%.*}} : f32, {{%.*}} : i32 -> f32
+  %0 = spirv.CL.rootn %arg0 : f32, %arg1 : i32 -> f32
+  return
+}
+
+// -----
+
+func.func @rootn_vec(%arg0 : vector<3xf32>, %arg1 : vector<3xi32>) -> () {
+  // CHECK: {{%.*}} = spirv.CL.rootn {{%.*}} : vector<3xf32>, {{%.*}} : vector<3xi32> -> vector<3xf32>
+  %0 = spirv.CL.rootn %arg0 : vector<3xf32>, %arg1 : vector<3xi32> -> vector<3xf32>
+  return
+}
+
+// -----
+
+func.func @rootn_wrong_type_scalar(%arg0 : f32, %arg1 : vector<2xi32>) -> () {
+  // expected-error @+1 {{operands must both be scalars or vectors}}
+  %0 = spirv.CL.rootn %arg0 : f32, %arg1 : vector<2xi32> -> f32
+  return
+}
+
+// -----
+
+// Vector float radicand paired with a scalar integer degree.
+func.func @rootn_wrong_type_vec_1(%arg0 : vector<3xf32>, %arg1 : i32) -> () {
+  // expected-error @+1 {{operands must both be scalars or vectors}}
+  %0 = spirv.CL.rootn %arg0 : vector<3xf32>, %arg1 : i32 -> vector<3xf32>
+  return
+}
+
+// -----
+
+func.func @rootn_wrong_type_vec_2(%arg0 : vector<3xf32>, %arg1 : vector<2xi32>) -> () {
+  // expected-error @+1 {{operands must have the same number of elements}}
+  %0 = spirv.CL.rootn %arg0 : vector<3xf32>, %arg1 : vector<2xi32> -> vector<3xf32>
+  return
+}
diff --git a/mlir/test/Target/SPIRV/ocl-ops.mlir b/mlir/test/Target/SPIRV/ocl-ops.mlir
index 17accd93e8249..74c451720be5e 100644
--- a/mlir/test/Target/SPIRV/ocl-ops.mlir
+++ b/mlir/test/Target/SPIRV/ocl-ops.mlir
@@ -50,6 +50,26 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, Vec
spirv.Return
}
+  spirv.func @float_int_insts(%x : f32, %k : i32) "None" {
+    // CHECK: {{%.*}} = spirv.CL.ldexp {{%.*}} : f32, {{%.*}} : i32 -> f32
+    %ldexp = spirv.CL.ldexp %x : f32, %k : i32 -> f32
+    // CHECK: {{%.*}} = spirv.CL.pown {{%.*}} : f32, {{%.*}} : i32 -> f32
+    %pown = spirv.CL.pown %x : f32, %k : i32 -> f32
+    // CHECK: {{%.*}} = spirv.CL.rootn {{%.*}} : f32, {{%.*}} : i32 -> f32
+    %rootn = spirv.CL.rootn %x : f32, %k : i32 -> f32
+    spirv.Return
+  }
+
+  spirv.func @float_int_vec_insts(%x : vector<3xf32>, %k : vector<3xi32>) "None" {
+    // CHECK: {{%.*}} = spirv.CL.ldexp {{%.*}} : vector<3xf32>, {{%.*}} : vector<3xi32> -> vector<3xf32>
+    %ldexp = spirv.CL.ldexp %x : vector<3xf32>, %k : vector<3xi32> -> vector<3xf32>
+    // CHECK: {{%.*}} = spirv.CL.pown {{%.*}} : vector<3xf32>, {{%.*}} : vector<3xi32> -> vector<3xf32>
+    %pown = spirv.CL.pown %x : vector<3xf32>, %k : vector<3xi32> -> vector<3xf32>
+    // CHECK: {{%.*}} = spirv.CL.rootn {{%.*}} : vector<3xf32>, {{%.*}} : vector<3xi32> -> vector<3xf32>
+    %rootn = spirv.CL.rootn %x : vector<3xf32>, %k : vector<3xi32> -> vector<3xf32>
+    spirv.Return
+  }
+
spirv.func @maxmin(%arg0 : f32, %arg1 : f32, %arg2 : i32, %arg3 : i32) "None" {
// CHECK: {{%.*}} = spirv.CL.fmax {{%.*}}, {{%.*}} : f32
%1 = spirv.CL.fmax %arg0, %arg1 : f32
More information about the Mlir-commits
mailing list