[Mlir-commits] [mlir] [mlir][spirv] Add common SPIRV Extended Ops for Vectors (PR #122322)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 9 09:34:13 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: None (mishaobu)

<details>
<summary>Changes</summary>

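Adds the GLSL.std.450 extended instructions Distance, Cross, Normalize, and Reflect to the MLIR SPIR-V dialect, with tests in mlir/test/Dialect/SPIRV/IR/gl-ops.mlir and mlir/test/Target/SPIRV/gl-ops.mlir. (Spec: https://registry.khronos.org/SPIR-V/specs/1.0/GLSL.std.450.html)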

---
Full diff: https://github.com/llvm/llvm-project/pull/122322.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td (+108) 
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+36) 
- (modified) mlir/test/Dialect/SPIRV/IR/gl-ops.mlir (+114) 
- (modified) mlir/test/Target/SPIRV/gl-ops.mlir (+20) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
index 3fcfb086f9662c..c99e8a506f1dba 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
@@ -1029,6 +1029,114 @@ def SPIRV_GLFMixOp :
   let hasVerifier = 0;
 }
 
+// -----
+
+def SPIRV_GLDistanceOp : SPIRV_GLOp<"Distance", 67, [Pure]> {
+  let summary = "Return distance between two points";
+
+  let description = [{
+    Result is the distance between p0 and p1, i.e., length(p0 - p1).
+
+    The operands must both be scalars, or both be vectors of the same size,
+    whose component type is floating-point.
+    Result Type must be a scalar of the same type as the component type of the operands.
+
+    #### Example:
+
+    ```mlir
+    %2 = spirv.GL.Distance %0, %1 : vector<3xf32>, vector<3xf32> -> f32
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_ScalarOrVectorOf<SPIRV_Float>:$p0,
+    SPIRV_ScalarOrVectorOf<SPIRV_Float>:$p1
+  );
+
+  let results = (outs
+    SPIRV_Float:$result
+  );
+
+  let assemblyFormat = [{
+    operands attr-dict `:` type($p0) `,` type($p1) `->` type($result)
+  }];
+}
+
+// -----
+
+def SPIRV_GLCrossOp : SPIRV_GLBinaryArithmeticOp<"Cross", 68, SPIRV_Float, [
+    PredOpTrait<"operands are 3-component vectors", CPred<[{
+      ::llvm::all_of($_op.getOperandTypes(), [](::mlir::Type type) {
+        auto vectorType = ::llvm::dyn_cast<::mlir::VectorType>(type);
+        return vectorType && vectorType.getNumElements() == 3;
+      })}]>>]> {
+  let summary = "Return the cross product of two 3-component vectors";
+
+  let description = [{
+    Result is the cross product of x and y, i.e., the resulting components
+    are, in order: `x[1] * y[2] - y[1] * x[2]`, `x[2] * y[0] - y[2] * x[0]`,
+    and `x[0] * y[1] - y[0] * x[1]`.
+
+    All the operands must be vectors of 3 components of a floating-point type.
+    Result Type and the type of all operands must be the same type.
+
+    #### Example:
+
+    ```mlir
+    %2 = spirv.GL.Cross %0, %1 : vector<3xf32>
+    %3 = spirv.GL.Cross %0, %1 : vector<3xf16>
+    ```
+  }];
+}
+
+// -----
+
+def SPIRV_GLNormalizeOp : SPIRV_GLUnaryArithmeticOp<"Normalize", 69, SPIRV_Float> {
+  let summary = "Normalizes a scalar or vector operand";
+
+  let description = [{
+    Result is the vector in the same direction as x but with a length of 1.
+
+    The operand x must be a scalar or vector whose component type is floating-point.
+
+    Result Type and the type of x must be the same type.
+
+    #### Example:
+
+    ```mlir
+    %2 = spirv.GL.Normalize %0 : vector<3xf32>
+    %3 = spirv.GL.Normalize %1 : vector<4xf16>
+    ```
+  }];
+}
+
+// -----
+
+def SPIRV_GLReflectOp : SPIRV_GLBinaryArithmeticOp<"Reflect", 71, SPIRV_Float> {
+  let summary = "Calculate reflection direction vector";
+
+  let description = [{
+    For the incident vector I and surface orientation N, the result is the reflection direction:
+
+    I - 2 * dot(N, I) * N
+
+    N must already be normalized in order to achieve the desired result.
+
+    The operands must all be a scalar or vector whose component type is floating-point.
+
+    Result Type and the type of all operands must be the same type.
+
+    #### Example:
+
+    ```mlir
+    %2 = spirv.GL.Reflect %0, %1 : f32
+    %3 = spirv.GL.Reflect %0, %1 : vector<3xf32>
+    ```
+  }];
+}
+
+// -----
+
 def SPIRV_GLFindUMsbOp : SPIRV_GLUnaryArithmeticOp<"FindUMsb", 75, SPIRV_Int32> {
   let summary = "Unsigned-integer most-significant bit";
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 26559c1321db5e..b789ead75f5092 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2094,3 +2094,39 @@ LogicalResult spirv::VectorTimesScalarOp::verify() {
     return emitOpError("scalar operand and result element type match");
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// spirv.GLDistanceOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::GLDistanceOp::verify() {
+  // ODS restricts operand/result types; this checks how they relate.
+  Type p0Type = getP0().getType();
+  Type p1Type = getP1().getType();
+  auto p0Vec = llvm::dyn_cast<VectorType>(p0Type);
+  auto p1Vec = llvm::dyn_cast<VectorType>(p1Type);
+
+  // Reject scalar/vector mixes before comparing shapes so the diagnostic
+  // does not depend on which operand happens to be the vector.
+  if (static_cast<bool>(p0Vec) != static_cast<bool>(p1Vec))
+    return emitOpError(
+        "expected both operands to be scalars or both to be vectors");
+  if (p0Vec && p0Vec.getShape() != p1Vec.getShape())
+    return emitOpError("vector operands must have same shape");
+
+  // The component type of a scalar operand is the scalar type itself.
+  auto getFloatType = [](Type type) -> FloatType {
+    if (auto vectorType = llvm::dyn_cast<VectorType>(type))
+      return llvm::dyn_cast<FloatType>(vectorType.getElementType());
+    return llvm::dyn_cast<FloatType>(type);
+  };
+  FloatType p0FloatType = getFloatType(p0Type);
+  FloatType p1FloatType = getFloatType(p1Type);
+  FloatType resultFloatType = llvm::dyn_cast<FloatType>(getResult().getType());
+  if (!p0FloatType || !p1FloatType || !resultFloatType)
+    return emitOpError(
+        "operands and result must be float scalar or vector of float");
+  if (p0FloatType != resultFloatType || p1FloatType != resultFloatType)
+    return emitOpError("operand and result element types must match");
+  return success();
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
index 3683e5b469b17b..0daa7aca81d296 100644
--- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
@@ -541,3 +541,141 @@ func.func @findumsb(%arg0 : i64) -> () {
   %2 = spirv.GL.FindUMsb %arg0 : i64
   return
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.GL.Distance
+//===----------------------------------------------------------------------===//
+
+func.func @distance_scalar(%arg0 : f32, %arg1 : f32) {
+  // CHECK: spirv.GL.Distance {{%.*}}, {{%.*}} : f32, f32 -> f32
+  %0 = spirv.GL.Distance %arg0, %arg1 : f32, f32 -> f32
+  return
+}
+
+func.func @distance_vector(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) {
+  // CHECK: spirv.GL.Distance {{%.*}}, {{%.*}} : vector<3xf32>, vector<3xf32> -> f32
+  %0 = spirv.GL.Distance %arg0, %arg1 : vector<3xf32>, vector<3xf32> -> f32
+  return
+}
+
+// -----
+
+func.func @distance_invalid_type(%arg0 : i32, %arg1 : i32) {
+  // expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+  %0 = spirv.GL.Distance %arg0, %arg1 : i32, i32 -> f32
+  return
+}
+
+// -----
+
+func.func @distance_invalid_vector_size(%arg0 : vector<5xf32>, %arg1 : vector<5xf32>) {
+  // expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+  %0 = spirv.GL.Distance %arg0, %arg1 : vector<5xf32>, vector<5xf32> -> f32
+  return
+}
+
+// -----
+
+func.func @distance_invalid_result(%arg0 : f32, %arg1 : f32) {
+  // expected-error @+1 {{'spirv.GL.Distance' op result #0 must be 16/32/64-bit float}}
+  %0 = spirv.GL.Distance %arg0, %arg1 : f32, f32 -> i32
+  return
+}
+
+// -----
+
+func.func @distance_invalid_element_type(%arg0 : f32, %arg1 : f16) {
+  // expected-error @+1 {{'spirv.GL.Distance' op operand and result element types must match}}
+  %0 = spirv.GL.Distance %arg0, %arg1 : f32, f16 -> f32
+  return
+}
+
+// -----
+
+func.func @distance_invalid_vector_shape(%arg0 : vector<3xf32>, %arg1 : vector<4xf32>) {
+  // expected-error @+1 {{'spirv.GL.Distance' op vector operands must have same shape}}
+  %0 = spirv.GL.Distance %arg0, %arg1 : vector<3xf32>, vector<4xf32> -> f32
+  return
+}
+
+// -----
+
+func.func @distance_invalid_scalar_vector_mix(%arg0 : f32, %arg1 : vector<3xf32>) {
+  // expected-error @+1 {{'spirv.GL.Distance' op expected both operands to be scalars or both to be vectors}}
+  %0 = spirv.GL.Distance %arg0, %arg1 : f32, vector<3xf32> -> f32
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.GL.Cross
+//===----------------------------------------------------------------------===//
+
+func.func @cross(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) {
+  // CHECK: %{{.+}} = spirv.GL.Cross %{{.+}}, %{{.+}} : vector<3xf32>
+  %2 = spirv.GL.Cross %arg0, %arg1 : vector<3xf32>
+  return
+}
+
+// -----
+
+func.func @cross_invalid_type(%arg0 : vector<3xi32>, %arg1 : vector<3xi32>) {
+  // expected-error @+1 {{'spirv.GL.Cross' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}}
+  %0 = spirv.GL.Cross %arg0, %arg1 : vector<3xi32>
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.GL.Normalize
+//===----------------------------------------------------------------------===//
+
+func.func @normalize_scalar(%arg0 : f32) {
+  // CHECK: %{{.+}} = spirv.GL.Normalize %{{.+}} : f32
+  %2 = spirv.GL.Normalize %arg0 : f32
+  return
+}
+
+func.func @normalize_vector(%arg0 : vector<3xf32>) {
+  // CHECK: %{{.+}} = spirv.GL.Normalize %{{.+}} : vector<3xf32>
+  %2 = spirv.GL.Normalize %arg0 : vector<3xf32>
+  return
+}
+
+// -----
+
+func.func @normalize_invalid_type(%arg0 : i32) {
+  // expected-error @+1 {{'spirv.GL.Normalize' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.GL.Normalize %arg0 : i32
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.GL.Reflect
+//===----------------------------------------------------------------------===//
+
+func.func @reflect_scalar(%arg0 : f32, %arg1 : f32) {
+  // CHECK: %{{.+}} = spirv.GL.Reflect %{{.+}}, %{{.+}} : f32
+  %2 = spirv.GL.Reflect %arg0, %arg1 : f32
+  return
+}
+
+func.func @reflect_vector(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) {
+  // CHECK: %{{.+}} = spirv.GL.Reflect %{{.+}}, %{{.+}} : vector<3xf32>
+  %2 = spirv.GL.Reflect %arg0, %arg1 : vector<3xf32>
+  return
+}
+
+// -----
+
+func.func @reflect_invalid_type(%arg0 : i32, %arg1 : i32) {
+  // expected-error @+1 {{'spirv.GL.Reflect' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.GL.Reflect %arg0, %arg1 : i32
+  return
+}
diff --git a/mlir/test/Target/SPIRV/gl-ops.mlir b/mlir/test/Target/SPIRV/gl-ops.mlir
index fff1adf0ae12c6..119304cea7d4ad 100644
--- a/mlir/test/Target/SPIRV/gl-ops.mlir
+++ b/mlir/test/Target/SPIRV/gl-ops.mlir
@@ -81,4 +81,22 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %2 = spirv.GL.FindUMsb %arg0 : i32
     spirv.Return
   }
+
+  spirv.func @vector(%arg0 : f32, %arg1 : vector<3xf32>, %arg2 : vector<3xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.GL.Cross {{%.*}}, {{%.*}} : vector<3xf32>
+    %0 = spirv.GL.Cross %arg1, %arg2 : vector<3xf32>
+    // CHECK: {{%.*}} = spirv.GL.Normalize {{%.*}} : f32
+    %1 = spirv.GL.Normalize %arg0 : f32
+    // CHECK: {{%.*}} = spirv.GL.Normalize {{%.*}} : vector<3xf32>
+    %2 = spirv.GL.Normalize %arg1 : vector<3xf32>
+    // CHECK: {{%.*}} = spirv.GL.Reflect {{%.*}}, {{%.*}} : f32
+    %3 = spirv.GL.Reflect %arg0, %arg0 : f32
+    // CHECK: {{%.*}} = spirv.GL.Reflect {{%.*}}, {{%.*}} : vector<3xf32>
+    %4 = spirv.GL.Reflect %arg1, %arg2 : vector<3xf32>
+    // CHECK: {{%.*}} = spirv.GL.Distance {{%.*}}, {{%.*}} : f32, f32 -> f32
+    %5 = spirv.GL.Distance %arg0, %arg0 : f32, f32 -> f32
+    // CHECK: {{%.*}} = spirv.GL.Distance {{%.*}}, {{%.*}} : vector<3xf32>, vector<3xf32> -> f32
+    %6 = spirv.GL.Distance %arg1, %arg2 : vector<3xf32>, vector<3xf32> -> f32
+    spirv.Return
+  }
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/122322


More information about the Mlir-commits mailing list