[Mlir-commits] [mlir] [mlir][spirv] Add common SPIRV Extended Ops for Vectors (PR #122322)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 17 10:02:09 PST 2025
https://github.com/mishaobu updated https://github.com/llvm/llvm-project/pull/122322
From 02974a6e3411786a651c022116599dfc1e62f84c Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Thu, 9 Jan 2025 16:55:37 +0100
Subject: [PATCH 01/10] initial
---
.../mlir/Dialect/SPIRV/IR/SPIRVGLOps.td | 102 ++++++++++++++++++
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 27 +++++
2 files changed, 129 insertions(+)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
index 3fcfb086f9662c..ee3931759f28d2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
@@ -1029,6 +1029,108 @@ def SPIRV_GLFMixOp :
let hasVerifier = 0;
}
+// -----
+
+def SPIRV_GLDistanceOp : SPIRV_GLOp<"Distance", 67, [Pure]> {
+ let summary = "Return distance between two points";
+
+ let description = [{
+ Result is the distance between x and y. This is |x - y|, where |x| is the
+ length of x computed as sqrt(x * x).
+
+ The operands must all be vectors whose component type is floating-point.
+ Result Type must be a scalar floating-point type.
+
+ #### Example:
+
+ ```mlir
+ %2 = spirv.GL.Distance %0, %1 : vector<3xf32>, vector<3xf32> -> f32
+ ```
+ }];
+
+ let arguments = (ins
+ SPIRV_ScalarOrVectorOf<SPIRV_Float>:$p0,
+ SPIRV_ScalarOrVectorOf<SPIRV_Float>:$p1
+ );
+
+ let results = (outs
+ SPIRV_Float:$result
+ );
+
+ let assemblyFormat = [{
+ operands attr-dict `:` type($p0) `,` type($p1) `->` type($result)
+ }];
+}
+
+// -----
+
+def SPIRV_GLCrossOp : SPIRV_GLBinaryArithmeticOp<"Cross", 68, SPIRV_Float> {
+ let summary = "Return the cross product of two 3-component vectors";
+
+ let description = [{
+ Result is the cross product of x and y, i.e. dot(x,y).
+
+ Arguments x and y must be vectors of size 3, of floating-point type.
+ Results are computed per component.
+ The result is a vector of the same type as the inputs.
+
+ #### Example:
+
+ ```mlir
+ %2 = spirv.GL.Cross %0, %1 : vector<3xf32>
+ %3 = spirv.GL.Cross %0, %1 : vector<3xf16>
+ ```
+ }];
+}
+
+// -----
+
+def SPIRV_GLNormalizeOp : SPIRV_GLUnaryArithmeticOp<"Normalize", 69, SPIRV_Float> {
+ let summary = "Normalizes a vector operand";
+
+ let description = [{
+ Result is the vector v/sqrt(dot(v,v)).
+ The result is undefined if dot(v,v) is less than or equal to 0.
+
+ The operand v must be a vector whose component type is floating-point.
+ The Result Type must be the same type as v.
+ Results are computed per component.
+
+ #### Example:
+
+ ```mlir
+ %2 = spirv.GL.Normalize %0 : vector<3xf32>
+ %3 = spirv.GL.Normalize %1 : vector<4xf16>
+ ```
+ }];
+}
+
+// -----
+
+def SPIRV_GLReflectOp : SPIRV_GLBinaryArithmeticOp<"Reflect", 71, SPIRV_Float> {
+ let summary = "Calculate reflection direction vector";
+
+ let description = [{
+ Result is the reflection direction vector: I - 2 * dot(N, I) * N, where I is
+ the incident vector and N is the surface orientation vector.
+
+ N must be normalized to achieve the desired result.
+
+ The operands must all be a scalar or vector whose component type is floating-point.
+ Result Type and the type of all operands must be the same type.
+ Results are computed per component.
+
+ #### Example:
+
+ ```mlir
+ %2 = spirv.GL.Reflect %0, %1 : f32
+ %3 = spirv.GL.Reflect %0, %1 : vector<3xf32>
+ ```
+ }];
+}
+
+// ----
+
def SPIRV_GLFindUMsbOp : SPIRV_GLUnaryArithmeticOp<"FindUMsb", 75, SPIRV_Int32> {
let summary = "Unsigned-integer most-significant bit";
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 26559c1321db5e..1b75f9cf0f5fa6 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2094,3 +2094,30 @@ LogicalResult spirv::VectorTimesScalarOp::verify() {
return emitOpError("scalar operand and result element type match");
return success();
}
+
+//===----------------------------------------------------------------------===//
+// spirv.GLDistanceOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::GLDistanceOp::verify() {
+ auto p0Type = getP0().getType();
+ auto p1Type = getP1().getType();
+ auto resultType = getResult().getType();
+
+ auto p0VectorType = p0Type.dyn_cast<VectorType>();
+ auto p1VectorType = p1Type.dyn_cast<VectorType>();
+ if (!p0VectorType || !p1VectorType)
+ return emitOpError("operands must be vectors");
+
+ if (p0VectorType.getShape() != p1VectorType.getShape())
+ return emitOpError("operands must have same shape");
+
+ if (!resultType.isa<FloatType>())
+ return emitOpError("result must be scalar float");
+
+ if (p0VectorType.getElementType() != resultType ||
+ p1VectorType.getElementType() != resultType)
+ return emitOpError("operand vector elements must match result type");
+
+ return success();
+}
\ No newline at end of file
From 560862ba863445c60f0aa8acf76c65b120087436 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Thu, 9 Jan 2025 17:15:14 +0100
Subject: [PATCH 02/10] fix descriptions
---
.../mlir/Dialect/SPIRV/IR/SPIRVGLOps.td | 42 +++++++++++--------
1 file changed, 24 insertions(+), 18 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
index ee3931759f28d2..c99e8a506f1dba 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
@@ -1035,11 +1035,11 @@ def SPIRV_GLDistanceOp : SPIRV_GLOp<"Distance", 67, [Pure]> {
let summary = "Return distance between two points";
let description = [{
- Result is the distance between x and y. This is |x - y|, where |x| is the
- length of x computed as sqrt(x * x).
+ Result is the distance between p0 and p1, i.e., length(p0 - p1).
- The operands must all be vectors whose component type is floating-point.
- Result Type must be a scalar floating-point type.
+ The operands must all be a scalar or vector whose component type is floating-point.
+
+ Result Type must be a scalar of the same type as the component type of the operands.
#### Example:
@@ -1068,11 +1068,17 @@ def SPIRV_GLCrossOp : SPIRV_GLBinaryArithmeticOp<"Cross", 68, SPIRV_Float> {
let summary = "Return the cross product of two 3-component vectors";
let description = [{
- Result is the cross product of x and y, i.e. dot(x,y).
+ Result is the cross product of x and y, i.e., the resulting components are, in order:
+
+ x[1] * y[2] - y[1] * x[2]
+
+ x[2] * y[0] - y[2] * x[0]
+
+ x[0] * y[1] - y[0] * x[1]
- Arguments x and y must be vectors of size 3, of floating-point type.
- Results are computed per component.
- The result is a vector of the same type as the inputs.
+ All the operands must be vectors of 3 components of a floating-point type.
+
+ Result Type and the type of all operands must be the same type.
#### Example:
@@ -1089,12 +1095,11 @@ def SPIRV_GLNormalizeOp : SPIRV_GLUnaryArithmeticOp<"Normalize", 69, SPIRV_Float
let summary = "Normalizes a vector operand";
let description = [{
- Result is the vector v/sqrt(dot(v,v)).
- The result is undefined if dot(v,v) is less than or equal to 0.
+ Result is the vector in the same direction as x but with a length of 1.
+
+ The operand x must be a scalar or vector whose component type is floating-point.
- The operand v must be a vector whose component type is floating-point.
- The Result Type must be the same type as v.
- Results are computed per component.
+ Result Type and the type of x must be the same type.
#### Example:
@@ -1111,14 +1116,15 @@ def SPIRV_GLReflectOp : SPIRV_GLBinaryArithmeticOp<"Reflect", 71, SPIRV_Float> {
let summary = "Calculate reflection direction vector";
let description = [{
- Result is the reflection direction vector: I - 2 * dot(N, I) * N, where I is
- the incident vector and N is the surface orientation vector.
+ For the incident vector I and surface orientation N, the result is the reflection direction:
- N must be normalized to achieve the desired result.
+ I - 2 * dot(N, I) * N
+
+ N must already be normalized in order to achieve the desired result.
The operands must all be a scalar or vector whose component type is floating-point.
- Result Type and the type of all operands must be the same type.
- Results are computed per component.
+
+ Result Type and the type of all operands must be the same type.
#### Example:
From bc626294f9fbcec5a5d4ee3c60fed49920b4c5de Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Thu, 9 Jan 2025 17:59:54 +0100
Subject: [PATCH 03/10] distance op verify + distance op tests
---
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 31 ++++++++++++-------
mlir/test/Dialect/SPIRV/IR/gl-ops.mlir | 42 ++++++++++++++++++++++++++
2 files changed, 62 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 1b75f9cf0f5fa6..b789ead75f5092 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2104,20 +2104,29 @@ LogicalResult spirv::GLDistanceOp::verify() {
auto p1Type = getP1().getType();
auto resultType = getResult().getType();
- auto p0VectorType = p0Type.dyn_cast<VectorType>();
- auto p1VectorType = p1Type.dyn_cast<VectorType>();
- if (!p0VectorType || !p1VectorType)
- return emitOpError("operands must be vectors");
+ auto getFloatType = [](Type type) -> FloatType {
+ if (auto vectorType = llvm::dyn_cast<VectorType>(type))
+ return llvm::dyn_cast<FloatType>(vectorType.getElementType());
+ return llvm::dyn_cast<FloatType>(type);
+ };
- if (p0VectorType.getShape() != p1VectorType.getShape())
- return emitOpError("operands must have same shape");
+ FloatType p0FloatType = getFloatType(p0Type);
+ FloatType p1FloatType = getFloatType(p1Type);
+ FloatType resultFloatType = llvm::dyn_cast<FloatType>(resultType);
- if (!resultType.isa<FloatType>())
- return emitOpError("result must be scalar float");
+ if (!p0FloatType || !p1FloatType || !resultFloatType)
+ return emitOpError("operands and result must be float scalar or vector of float");
- if (p0VectorType.getElementType() != resultType ||
- p1VectorType.getElementType() != resultType)
- return emitOpError("operand vector elements must match result type");
+ if (p0FloatType != resultFloatType || p1FloatType != resultFloatType)
+ return emitOpError("operand and result element types must match");
+
+ if (auto p0Vec = llvm::dyn_cast<VectorType>(p0Type)) {
+ if (!llvm::dyn_cast<VectorType>(p1Type) ||
+ p0Vec.getShape() != llvm::dyn_cast<VectorType>(p1Type).getShape())
+ return emitOpError("vector operands must have same shape");
+ } else if (llvm::isa<VectorType>(p1Type)) {
+ return emitOpError("expected both operands to be scalars or both to be vectors");
+ }
return success();
}
\ No newline at end of file
diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
index 3683e5b469b17b..d897277ecd226e 100644
--- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
@@ -541,3 +541,45 @@ func.func @findumsb(%arg0 : i64) -> () {
%2 = spirv.GL.FindUMsb %arg0 : i64
return
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.GL.Distance
+//===----------------------------------------------------------------------===//
+
+func.func @distance_scalar(%arg0 : f32, %arg1 : f32) {
+ // CHECK: spirv.GL.Distance {{%.*}}, {{%.*}} : f32, f32 -> f32
+ %0 = spirv.GL.Distance %arg0, %arg1 : f32, f32 -> f32
+ return
+}
+
+func.func @distance_vector(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) {
+ // CHECK: spirv.GL.Distance {{%.*}}, {{%.*}} : vector<3xf32>, vector<3xf32> -> f32
+ %0 = spirv.GL.Distance %arg0, %arg1 : vector<3xf32>, vector<3xf32> -> f32
+ return
+}
+
+// -----
+
+func.func @distance_invalid_type(%arg0 : i32, %arg1 : i32) {
+ // expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+ %0 = spirv.GL.Distance %arg0, %arg1 : i32, i32 -> f32
+ return
+}
+
+// -----
+
+func.func @distance_invalid_vector_size(%arg0 : vector<5xf32>, %arg1 : vector<5xf32>) {
+ // expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+ %0 = spirv.GL.Distance %arg0, %arg1 : vector<5xf32>, vector<5xf32> -> f32
+ return
+}
+
+// -----
+
+func.func @distance_invalid_result(%arg0 : f32, %arg1 : f32) {
+ // expected-error @+1 {{'spirv.GL.Distance' op result #0 must be 16/32/64-bit float}}
+ %0 = spirv.GL.Distance %arg0, %arg1 : f32, f32 -> i32
+ return
+}
From a2887923efe837beef3a600c2e7eb2bdc0e36dab Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Thu, 9 Jan 2025 18:28:37 +0100
Subject: [PATCH 04/10] add target tests
---
mlir/test/Target/SPIRV/gl-ops.mlir | 20 ++++++++++++++++++++
1 file changed, 20 insertions(+)
diff --git a/mlir/test/Target/SPIRV/gl-ops.mlir b/mlir/test/Target/SPIRV/gl-ops.mlir
index fff1adf0ae12c6..119304cea7d4ad 100644
--- a/mlir/test/Target/SPIRV/gl-ops.mlir
+++ b/mlir/test/Target/SPIRV/gl-ops.mlir
@@ -81,4 +81,24 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%2 = spirv.GL.FindUMsb %arg0 : i32
spirv.Return
}
+
+spirv.func @vector(%arg0 : f32, %arg1 : vector<3xf32>, %arg2 : vector<3xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.GL.Cross {{%.*}}, {{%.*}} : vector<3xf32>
+ %0 = spirv.GL.Cross %arg1, %arg2 : vector<3xf32>
+ // CHECK: {{%.*}} = spirv.GL.Normalize {{%.*}} : f32
+ %1 = spirv.GL.Normalize %arg0 : f32
+ // CHECK: {{%.*}} = spirv.GL.Normalize {{%.*}} : vector<3xf32>
+ %2 = spirv.GL.Normalize %arg1 : vector<3xf32>
+ // CHECK: {{%.*}} = spirv.GL.Reflect {{%.*}}, {{%.*}} : f32
+ %3 = spirv.GL.Reflect %arg0, %arg0 : f32
+ // CHECK: {{%.*}} = spirv.GL.Reflect {{%.*}}, {{%.*}} : vector<3xf32>
+ %4 = spirv.GL.Reflect %arg1, %arg2 : vector<3xf32>
+ // CHECK: {{%.*}} = spirv.GL.Distance {{%.*}}, {{%.*}} : f32, f32 -> f32
+ %5 = spirv.GL.Distance %arg0, %arg0 : f32, f32 -> f32
+ // CHECK: {{%.*}} = spirv.GL.Distance {{%.*}}, {{%.*}} : vector<3xf32>, vector<3xf32> -> f32
+ %6 = spirv.GL.Distance %arg1, %arg2 : vector<3xf32>, vector<3xf32> -> f32
+ spirv.Return
+ }
+
+
}
From 52d61dd0be09b211cd4b2feb40f89a7f6ff82272 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Thu, 9 Jan 2025 18:29:27 +0100
Subject: [PATCH 05/10] more dialect tests
---
mlir/test/Dialect/SPIRV/IR/gl-ops.mlir | 72 ++++++++++++++++++++++++++
1 file changed, 72 insertions(+)
diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
index d897277ecd226e..0daa7aca81d296 100644
--- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
@@ -583,3 +583,75 @@ func.func @distance_invalid_result(%arg0 : f32, %arg1 : f32) {
%0 = spirv.GL.Distance %arg0, %arg1 : f32, f32 -> i32
return
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.GL.Cross
+//===----------------------------------------------------------------------===//
+
+func.func @cross(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) {
+ %2 = spirv.GL.Cross %arg0, %arg1 : vector<3xf32>
+ // CHECK: %{{.+}} = spirv.GL.Cross %{{.+}}, %{{.+}} : vector<3xf32>
+ return
+}
+
+// -----
+
+func.func @cross_invalid_type(%arg0 : vector<3xi32>, %arg1 : vector<3xi32>) {
+ // expected-error @+1 {{'spirv.GL.Cross' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}}
+ %0 = spirv.GL.Cross %arg0, %arg1 : vector<3xi32>
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.GL.Normalize
+//===----------------------------------------------------------------------===//
+
+func.func @normalize_scalar(%arg0 : f32) {
+ %2 = spirv.GL.Normalize %arg0 : f32
+ // CHECK: %{{.+}} = spirv.GL.Normalize %{{.+}} : f32
+ return
+}
+
+func.func @normalize_vector(%arg0 : vector<3xf32>) {
+ %2 = spirv.GL.Normalize %arg0 : vector<3xf32>
+ // CHECK: %{{.+}} = spirv.GL.Normalize %{{.+}} : vector<3xf32>
+ return
+}
+
+// -----
+
+func.func @normalize_invalid_type(%arg0 : i32) {
+ // expected-error @+1 {{'spirv.GL.Normalize' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ %0 = spirv.GL.Normalize %arg0 : i32
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.GL.Reflect
+//===----------------------------------------------------------------------===//
+
+func.func @reflect_scalar(%arg0 : f32, %arg1 : f32) {
+ %2 = spirv.GL.Reflect %arg0, %arg1 : f32
+ // CHECK: %{{.+}} = spirv.GL.Reflect %{{.+}}, %{{.+}} : f32
+ return
+}
+
+func.func @reflect_vector(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) {
+ %2 = spirv.GL.Reflect %arg0, %arg1 : vector<3xf32>
+ // CHECK: %{{.+}} = spirv.GL.Reflect %{{.+}}, %{{.+}} : vector<3xf32>
+ return
+}
+
+// -----
+
+func.func @reflect_invalid_type(%arg0 : i32, %arg1 : i32) {
+ // expected-error @+1 {{'spirv.GL.Reflect' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+ %0 = spirv.GL.Reflect %arg0, %arg1 : i32
+ return
+}
\ No newline at end of file
From 28e4cbfeac6d34460f5d5ebf36ff5e4d844a487a Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Thu, 9 Jan 2025 18:51:13 +0100
Subject: [PATCH 06/10] clang-format
---
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index b789ead75f5092..5305fb88e21761 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2115,17 +2115,19 @@ LogicalResult spirv::GLDistanceOp::verify() {
FloatType resultFloatType = llvm::dyn_cast<FloatType>(resultType);
if (!p0FloatType || !p1FloatType || !resultFloatType)
- return emitOpError("operands and result must be float scalar or vector of float");
+ return emitOpError(
+ "operands and result must be float scalar or vector of float");
if (p0FloatType != resultFloatType || p1FloatType != resultFloatType)
return emitOpError("operand and result element types must match");
if (auto p0Vec = llvm::dyn_cast<VectorType>(p0Type)) {
- if (!llvm::dyn_cast<VectorType>(p1Type) ||
+ if (!llvm::dyn_cast<VectorType>(p1Type) ||
p0Vec.getShape() != llvm::dyn_cast<VectorType>(p1Type).getShape())
return emitOpError("vector operands must have same shape");
} else if (llvm::isa<VectorType>(p1Type)) {
- return emitOpError("expected both operands to be scalars or both to be vectors");
+ return emitOpError(
+ "expected both operands to be scalars or both to be vectors");
}
return success();
From f0f74ac0097b627b2d1d07bf97c79016343723a8 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Thu, 9 Jan 2025 19:41:17 +0100
Subject: [PATCH 07/10] newline
---
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 2 +-
mlir/test/Dialect/SPIRV/IR/gl-ops.mlir | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 5305fb88e21761..d9739e7d157b6e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2131,4 +2131,4 @@ LogicalResult spirv::GLDistanceOp::verify() {
}
return success();
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
index 0daa7aca81d296..24899307a77724 100644
--- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
@@ -654,4 +654,4 @@ func.func @reflect_invalid_type(%arg0 : i32, %arg1 : i32) {
// expected-error @+1 {{'spirv.GL.Reflect' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
%0 = spirv.GL.Reflect %arg0, %arg1 : i32
return
-}
\ No newline at end of file
+}
From e8460ff9fe906385dbe809d59a4ef5fcc8526ca6 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Fri, 17 Jan 2025 15:22:50 +0100
Subject: [PATCH 08/10] move verify to tablegen
---
.../mlir/Dialect/SPIRV/IR/SPIRVGLOps.td | 10 ++++-
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 38 -------------------
2 files changed, 9 insertions(+), 39 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
index c99e8a506f1dba..ed470ebb80ddb3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
@@ -1031,7 +1031,13 @@ def SPIRV_GLFMixOp :
// -----
-def SPIRV_GLDistanceOp : SPIRV_GLOp<"Distance", 67, [Pure]> {
+def SPIRV_GLDistanceOp : SPIRV_GLOp<"Distance", 67, [
+ Pure,
+ AllElementTypesMatch<["p0", "p1"]>,
+ TypesMatchWith<"result type must match operand element type",
+ "p0", "result",
+ "::mlir::getElementTypeOrSelf($_self)">
+ ]> {
let summary = "Return distance between two points";
let description = [{
@@ -1060,6 +1066,8 @@ def SPIRV_GLDistanceOp : SPIRV_GLOp<"Distance", 67, [Pure]> {
let assemblyFormat = [{
operands attr-dict `:` type($p0) `,` type($p1) `->` type($result)
}];
+
+ let hasVerifier = 0;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index d9739e7d157b6e..26559c1321db5e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2094,41 +2094,3 @@ LogicalResult spirv::VectorTimesScalarOp::verify() {
return emitOpError("scalar operand and result element type match");
return success();
}
-
-//===----------------------------------------------------------------------===//
-// spirv.GLDistanceOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::GLDistanceOp::verify() {
- auto p0Type = getP0().getType();
- auto p1Type = getP1().getType();
- auto resultType = getResult().getType();
-
- auto getFloatType = [](Type type) -> FloatType {
- if (auto vectorType = llvm::dyn_cast<VectorType>(type))
- return llvm::dyn_cast<FloatType>(vectorType.getElementType());
- return llvm::dyn_cast<FloatType>(type);
- };
-
- FloatType p0FloatType = getFloatType(p0Type);
- FloatType p1FloatType = getFloatType(p1Type);
- FloatType resultFloatType = llvm::dyn_cast<FloatType>(resultType);
-
- if (!p0FloatType || !p1FloatType || !resultFloatType)
- return emitOpError(
- "operands and result must be float scalar or vector of float");
-
- if (p0FloatType != resultFloatType || p1FloatType != resultFloatType)
- return emitOpError("operand and result element types must match");
-
- if (auto p0Vec = llvm::dyn_cast<VectorType>(p0Type)) {
- if (!llvm::dyn_cast<VectorType>(p1Type) ||
- p0Vec.getShape() != llvm::dyn_cast<VectorType>(p1Type).getShape())
- return emitOpError("vector operands must have same shape");
- } else if (llvm::isa<VectorType>(p1Type)) {
- return emitOpError(
- "expected both operands to be scalars or both to be vectors");
- }
-
- return success();
-}
From 736e5a888f9fb0ed485c077b390cde29cc77f062 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Fri, 17 Jan 2025 19:01:08 +0100
Subject: [PATCH 09/10] Arg size mismatch test
---
mlir/test/Dialect/SPIRV/IR/gl-ops.mlir | 8 ++++++++
1 file changed, 8 insertions(+)
diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
index 24899307a77724..beda3872bc8d2f 100644
--- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
@@ -570,6 +570,14 @@ func.func @distance_invalid_type(%arg0 : i32, %arg1 : i32) {
// -----
+func.func @distance_arg_mismatch(%arg0 : vector<3xf32>, %arg1 : vector<4xf32>) {
+ // expected-error @+1 {{'spirv.GL.Distance' op failed to verify that all of {p0, p1} have same type}}
+ %0 = spirv.GL.Distance %arg0, %arg1 : vector<3xf32>, vector<4xf32> -> f32
+ return
+}
+
+// -----
+
func.func @distance_invalid_vector_size(%arg0 : vector<5xf32>, %arg1 : vector<5xf32>) {
// expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
%0 = spirv.GL.Distance %arg0, %arg1 : vector<5xf32>, vector<5xf32> -> f32
From 03cd8f279d5e5c9ff8e18f17b273cfdb68e1a828 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Fri, 17 Jan 2025 19:01:40 +0100
Subject: [PATCH 10/10] fix distance op typechecking
---
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
index ed470ebb80ddb3..1cdfa02f817879 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
@@ -1033,7 +1033,7 @@ def SPIRV_GLFMixOp :
def SPIRV_GLDistanceOp : SPIRV_GLOp<"Distance", 67, [
Pure,
- AllElementTypesMatch<["p0", "p1"]>,
+ AllTypesMatch<["p0", "p1"]>,
TypesMatchWith<"result type must match operand element type",
"p0", "result",
"::mlir::getElementTypeOrSelf($_self)">
More information about the Mlir-commits
mailing list