[Mlir-commits] [mlir] fe85dc9 - [mlir][SPIR-V] Add OpenCL.std ldexp, pown, and rootn ops (#194791)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 29 08:30:56 PDT 2026
Author: Arseniy Obolenskiy
Date: 2026-04-29T17:30:50+02:00
New Revision: fe85dc92a9f2121ba31cb9f639222fbd5a55f7e1
URL: https://github.com/llvm/llvm-project/commit/fe85dc92a9f2121ba31cb9f639222fbd5a55f7e1
DIFF: https://github.com/llvm/llvm-project/commit/fe85dc92a9f2121ba31cb9f639222fbd5a55f7e1.diff
LOG: [mlir][SPIR-V] Add OpenCL.std ldexp, pown, and rootn ops (#194791)
Add operations that follow the `float op(float, int)` pattern, mirroring the
existing `spirv.GL.Ldexp` op.
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
mlir/test/Target/SPIRV/ocl-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
index 5d086325fa5b1..d36245d5ad6b7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
@@ -505,6 +505,48 @@ def SPIRV_CLFmaOp : SPIRV_CLTernaryArithmeticOp<"fma", 26, SPIRV_Float> {
// -----
+def SPIRV_CLLdexpOp : SPIRV_CLOp<"ldexp", 34, [Pure, AllTypesMatch<["x", "y"]>]> {
+ let summary = "Builds y such that y = x * 2^exp.";
+
+ let description = [{
+ Builds a floating-point number from x and the corresponding
+ integral exponent of two in exp:
+
+ x * 2^exp
+
+ The operand x must be a scalar or vector whose component type is
+ floating-point.
+
+ The exp operand must be a scalar or vector with integer component type.
+ The number of components in x and exp must be the same.
+
+ Result Type must be the same type as the type of x. Results are computed
+ per component.
+
+ #### Example:
+
+ ```mlir
+ %y = spirv.CL.ldexp %x, %exp : f32, i32 -> f32
+ %y = spirv.CL.ldexp %x, %exp : vector<3xf32>, vector<3xi32> -> vector<3xf32>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_ScalarOrVectorOf<SPIRV_Float>:$x,
+ SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$exp
+ );
+
+ let results = (outs
+ SPIRV_ScalarOrVectorOf<SPIRV_Float>:$y
+ );
+
+ let assemblyFormat = [{
+ attr-dict $x `,` $exp `:` type($x) `,` type($exp) `->` type($y)
+ }];
+}
+
+// -----
+
def SPIRV_CLLogOp : SPIRV_CLUnaryArithmeticOp<"log", 37, SPIRV_Float> {
let summary = "Compute the natural logarithm of x.";
@@ -571,6 +613,45 @@ def SPIRV_CLPowOp : SPIRV_CLBinaryArithmeticOp<"pow", 48, SPIRV_Float> {
// -----
+def SPIRV_CLPownOp : SPIRV_CLOp<"pown", 49, [Pure, AllTypesMatch<["x", "result"]>]> {
+ let summary = "Compute x to the power y, where y is an integer.";
+
+ let description = [{
+ Result is x raised to the power y, where y is an integer.
+
+ The operand x must be a scalar or vector whose component type is
+ floating-point.
+
+ The y operand must be a scalar or vector with integer component type.
+ The number of components in x and y must be the same.
+
+ Result Type must be the same type as the type of x. Results are computed
+ per component.
+
+ #### Example:
+
+ ```mlir
+ %2 = spirv.CL.pown %0, %1 : f32, i32 -> f32
+ %2 = spirv.CL.pown %0, %1 : vector<3xf32>, vector<3xi32> -> vector<3xf32>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_ScalarOrVectorOf<SPIRV_Float>:$x,
+ SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$y
+ );
+
+ let results = (outs
+ SPIRV_ScalarOrVectorOf<SPIRV_Float>:$result
+ );
+
+ let assemblyFormat = [{
+ attr-dict $x `,` $y `:` type($x) `,` type($y) `->` type($result)
+ }];
+}
+
+// -----
+
def SPIRV_CLRintOp : SPIRV_CLUnaryArithmeticOp<"rint", 53, SPIRV_Float> {
let summary = [{
Round x to integral value (using round to nearest even rounding mode) in
@@ -595,6 +676,45 @@ def SPIRV_CLRintOp : SPIRV_CLUnaryArithmeticOp<"rint", 53, SPIRV_Float> {
// -----
+def SPIRV_CLRootnOp : SPIRV_CLOp<"rootn", 54, [Pure, AllTypesMatch<["x", "result"]>]> {
+ let summary = "Compute the n-th root of x, where n is an integer.";
+
+ let description = [{
+ Result is the n-th root of x, where n is an integer.
+
+ The operand x must be a scalar or vector whose component type is
+ floating-point.
+
+ The n operand must be a scalar or vector with integer component type.
+ The number of components in x and n must be the same.
+
+ Result Type must be the same type as the type of x. Results are computed
+ per component.
+
+ #### Example:
+
+ ```mlir
+ %2 = spirv.CL.rootn %0, %1 : f32, i32 -> f32
+ %2 = spirv.CL.rootn %0, %1 : vector<3xf32>, vector<3xi32> -> vector<3xf32>
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_ScalarOrVectorOf<SPIRV_Float>:$x,
+ SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$n
+ );
+
+ let results = (outs
+ SPIRV_ScalarOrVectorOf<SPIRV_Float>:$result
+ );
+
+ let assemblyFormat = [{
+ attr-dict $x `,` $n `:` type($x) `,` type($n) `->` type($result)
+ }];
+}
+
+// -----
+
def SPIRV_CLRoundOp : SPIRV_CLUnaryArithmeticOp<"round", 55, SPIRV_Float> {
let summary = [{
Return the integral value nearest to x rounding halfway cases away from
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 49a14b5b30f0f..f8aa5bfeba452 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2049,12 +2049,10 @@ LogicalResult spirv::GLFrexpStructOp::verify() {
// spirv.GL.Ldexp
//===----------------------------------------------------------------------===//
-LogicalResult spirv::GLLdexpOp::verify() {
- Type significandType = getX().getType();
- Type exponentType = getExp().getType();
-
- if (isa<FloatType>(significandType) != isa<IntegerType>(exponentType))
- return emitOpError("operands must both be scalars or vectors");
+static LogicalResult verifyFloatIntegerBuiltin(Operation *op, Type floatType,
+ Type integerType) {
+ if (isa<FloatType>(floatType) != isa<IntegerType>(integerType))
+ return op->emitOpError("operands must both be scalars or vectors");
auto getNumElements = [](Type type) -> unsigned {
if (auto vectorType = dyn_cast<VectorType>(type))
@@ -2062,12 +2060,44 @@ LogicalResult spirv::GLLdexpOp::verify() {
return 1;
};
- if (getNumElements(significandType) != getNumElements(exponentType))
- return emitOpError("operands must have the same number of elements");
+ if (getNumElements(floatType) != getNumElements(integerType))
+ return op->emitOpError("operands must have the same number of elements");
return success();
}
+LogicalResult spirv::GLLdexpOp::verify() {
+ return verifyFloatIntegerBuiltin(getOperation(), getX().getType(),
+ getExp().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.ldexp
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::CLLdexpOp::verify() {
+ return verifyFloatIntegerBuiltin(getOperation(), getX().getType(),
+ getExp().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.pown
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::CLPownOp::verify() {
+ return verifyFloatIntegerBuiltin(getOperation(), getX().getType(),
+ getY().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.rootn
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::CLRootnOp::verify() {
+ return verifyFloatIntegerBuiltin(getOperation(), getX().getType(),
+ getN().getType());
+}
+
//===----------------------------------------------------------------------===//
// spirv.ShiftLeftLogicalOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
index 6aaaa6012fefe..8d81cb42030a5 100644
--- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
@@ -440,3 +440,118 @@ func.func @atan2(%arg0 : vector<4xf16>, %arg1 : vector<4xf16>) -> () {
return
}
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.ldexp
+//===----------------------------------------------------------------------===//
+
+func.func @ldexp(%arg0 : f32, %arg1 : i32) -> () {
+ // CHECK: {{%.*}} = spirv.CL.ldexp {{%.*}}, {{%.*}} : f32, i32 -> f32
+ %0 = spirv.CL.ldexp %arg0, %arg1 : f32, i32 -> f32
+ return
+}
+
+// -----
+
+func.func @ldexp_vec(%arg0 : vector<3xf32>, %arg1 : vector<3xi32>) -> () {
+ // CHECK: {{%.*}} = spirv.CL.ldexp {{%.*}}, {{%.*}} : vector<3xf32>, vector<3xi32> -> vector<3xf32>
+ %0 = spirv.CL.ldexp %arg0, %arg1 : vector<3xf32>, vector<3xi32> -> vector<3xf32>
+ return
+}
+
+// -----
+
+func.func @ldexp_wrong_type_scalar(%arg0 : f32, %arg1 : vector<2xi32>) -> () {
+ // expected-error @+1 {{operands must both be scalars or vectors}}
+ %0 = spirv.CL.ldexp %arg0, %arg1 : f32, vector<2xi32> -> f32
+ return
+}
+
+// -----
+
+func.func @ldexp_wrong_type_vec_1(%arg0 : vector<3xf32>, %arg1 : i32) -> () {
+ // expected-error @+1 {{operands must both be scalars or vectors}}
+ %0 = spirv.CL.ldexp %arg0, %arg1 : vector<3xf32>, i32 -> vector<3xf32>
+ return
+}
+
+// -----
+
+func.func @ldexp_wrong_type_vec_2(%arg0 : vector<3xf32>, %arg1 : vector<2xi32>) -> () {
+ // expected-error @+1 {{operands must have the same number of elements}}
+ %0 = spirv.CL.ldexp %arg0, %arg1 : vector<3xf32>, vector<2xi32> -> vector<3xf32>
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.pown
+//===----------------------------------------------------------------------===//
+
+func.func @pown(%arg0 : f32, %arg1 : i32) -> () {
+ // CHECK: {{%.*}} = spirv.CL.pown {{%.*}}, {{%.*}} : f32, i32 -> f32
+ %0 = spirv.CL.pown %arg0, %arg1 : f32, i32 -> f32
+ return
+}
+
+// -----
+
+func.func @pown_vec(%arg0 : vector<3xf32>, %arg1 : vector<3xi32>) -> () {
+ // CHECK: {{%.*}} = spirv.CL.pown {{%.*}}, {{%.*}} : vector<3xf32>, vector<3xi32> -> vector<3xf32>
+ %0 = spirv.CL.pown %arg0, %arg1 : vector<3xf32>, vector<3xi32> -> vector<3xf32>
+ return
+}
+
+// -----
+
+func.func @pown_wrong_type_scalar(%arg0 : f32, %arg1 : vector<2xi32>) -> () {
+ // expected-error @+1 {{operands must both be scalars or vectors}}
+ %0 = spirv.CL.pown %arg0, %arg1 : f32, vector<2xi32> -> f32
+ return
+}
+
+// -----
+
+func.func @pown_wrong_type_vec(%arg0 : vector<3xf32>, %arg1 : vector<2xi32>) -> () {
+ // expected-error @+1 {{operands must have the same number of elements}}
+ %0 = spirv.CL.pown %arg0, %arg1 : vector<3xf32>, vector<2xi32> -> vector<3xf32>
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.CL.rootn
+//===----------------------------------------------------------------------===//
+
+func.func @rootn(%arg0 : f32, %arg1 : i32) -> () {
+ // CHECK: {{%.*}} = spirv.CL.rootn {{%.*}}, {{%.*}} : f32, i32 -> f32
+ %0 = spirv.CL.rootn %arg0, %arg1 : f32, i32 -> f32
+ return
+}
+
+// -----
+
+func.func @rootn_vec(%arg0 : vector<3xf32>, %arg1 : vector<3xi32>) -> () {
+ // CHECK: {{%.*}} = spirv.CL.rootn {{%.*}}, {{%.*}} : vector<3xf32>, vector<3xi32> -> vector<3xf32>
+ %0 = spirv.CL.rootn %arg0, %arg1 : vector<3xf32>, vector<3xi32> -> vector<3xf32>
+ return
+}
+
+// -----
+
+func.func @rootn_wrong_type_scalar(%arg0 : f32, %arg1 : vector<2xi32>) -> () {
+ // expected-error @+1 {{operands must both be scalars or vectors}}
+ %0 = spirv.CL.rootn %arg0, %arg1 : f32, vector<2xi32> -> f32
+ return
+}
+
+// -----
+
+func.func @rootn_wrong_type_vec(%arg0 : vector<3xf32>, %arg1 : vector<2xi32>) -> () {
+ // expected-error @+1 {{operands must have the same number of elements}}
+ %0 = spirv.CL.rootn %arg0, %arg1 : vector<3xf32>, vector<2xi32> -> vector<3xf32>
+ return
+}
diff --git a/mlir/test/Target/SPIRV/ocl-ops.mlir b/mlir/test/Target/SPIRV/ocl-ops.mlir
index 17accd93e8249..14e45518502de 100644
--- a/mlir/test/Target/SPIRV/ocl-ops.mlir
+++ b/mlir/test/Target/SPIRV/ocl-ops.mlir
@@ -50,6 +50,26 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Addresses, Vec
spirv.Return
}
+ spirv.func @float_int_insts(%arg0 : f32, %arg1 : i32) "None" {
+ // CHECK: {{%.*}} = spirv.CL.ldexp {{%.*}}, {{%.*}} : f32, i32 -> f32
+ %0 = spirv.CL.ldexp %arg0, %arg1 : f32, i32 -> f32
+ // CHECK: {{%.*}} = spirv.CL.pown {{%.*}}, {{%.*}} : f32, i32 -> f32
+ %1 = spirv.CL.pown %arg0, %arg1 : f32, i32 -> f32
+ // CHECK: {{%.*}} = spirv.CL.rootn {{%.*}}, {{%.*}} : f32, i32 -> f32
+ %2 = spirv.CL.rootn %arg0, %arg1 : f32, i32 -> f32
+ spirv.Return
+ }
+
+ spirv.func @float_int_vec_insts(%arg0 : vector<3xf32>, %arg1 : vector<3xi32>) "None" {
+ // CHECK: {{%.*}} = spirv.CL.ldexp {{%.*}}, {{%.*}} : vector<3xf32>, vector<3xi32> -> vector<3xf32>
+ %0 = spirv.CL.ldexp %arg0, %arg1 : vector<3xf32>, vector<3xi32> -> vector<3xf32>
+ // CHECK: {{%.*}} = spirv.CL.pown {{%.*}}, {{%.*}} : vector<3xf32>, vector<3xi32> -> vector<3xf32>
+ %1 = spirv.CL.pown %arg0, %arg1 : vector<3xf32>, vector<3xi32> -> vector<3xf32>
+ // CHECK: {{%.*}} = spirv.CL.rootn {{%.*}}, {{%.*}} : vector<3xf32>, vector<3xi32> -> vector<3xf32>
+ %2 = spirv.CL.rootn %arg0, %arg1 : vector<3xf32>, vector<3xi32> -> vector<3xf32>
+ spirv.Return
+ }
+
spirv.func @maxmin(%arg0 : f32, %arg1 : f32, %arg2 : i32, %arg3 : i32) "None" {
// CHECK: {{%.*}} = spirv.CL.fmax {{%.*}}, {{%.*}} : f32
%1 = spirv.CL.fmax %arg0, %arg1 : f32
More information about the Mlir-commits
mailing list