[Mlir-commits] [mlir] b2bba5b - [mlir][spirv] Support conversion of `CopySignOp` to spirv for 1D vector with 1 element

Guray Ozen llvmlistbot at llvm.org
Thu Dec 8 00:11:34 PST 2022


Author: Guray Ozen
Date: 2022-12-08T09:11:27+01:00
New Revision: b2bba5b65c9f90f9c75da35fcedec08a01640d80

URL: https://github.com/llvm/llvm-project/commit/b2bba5b65c9f90f9c75da35fcedec08a01640d80
DIFF: https://github.com/llvm/llvm-project/commit/b2bba5b65c9f90f9c75da35fcedec08a01640d80.diff

LOG: [mlir][spirv] Support conversion of `CopySignOp` to spirv for 1D vector with 1 element

Conversion of `CopySignOp` to SPIR-V is supported for scalars and vectors, but not for 1-D vectors with a single element (e.g. `vector<1xf32>`). This revision adds support for them by treating them as scalars.
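
The gist of the change can be modelled outside MLIR. The sketch below is a toy model with hypothetical names (`ToyType`, `convertToSpirv`, `maskIntType`), not MLIR's API. It assumes, as the test's `unrealized_conversion_cast` from `vector<1xf16>` to `f16` indicates, that the SPIR-V type converter turns a single-element 1-D vector into its element type. It also assumes that `type` in the patched pattern holds that converted type (its definition lies outside the hunk). Under those assumptions, branching on the converted type rather than on `copySignOp.getType()` keeps the integer mask type scalar, so it matches the scalar operands.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy stand-in for an MLIR type (hypothetical, not MLIR's API):
// an empty shape means a scalar, otherwise a vector of the given shape.
struct ToyType {
  std::vector<int64_t> shape;
  unsigned elementBits;

  bool isVector() const { return !shape.empty(); }
  int64_t numElements() const {
    int64_t n = 1;
    for (int64_t d : shape) n *= d;
    return n;
  }
};

// Hypothetical model of the SPIR-V type conversion: SPIR-V vectors need at
// least two components, so a 1-D single-element vector becomes a scalar.
ToyType convertToSpirv(const ToyType &t) {
  if (t.shape.size() == 1 && t.shape[0] == 1)
    return ToyType{{}, t.elementBits};
  return t;
}

// Integer type used for the bitcast/mask arithmetic, mirroring the
// `if (auto vectorType = ...dyn_cast<VectorType>())` branch in the patch:
// a vector of integers for vector types, a scalar integer otherwise.
ToyType maskIntType(const ToyType &t) {
  if (t.isVector()) {
    assert(t.shape.size() == 1);
    return ToyType{{t.numElements()}, t.elementBits};
  }
  return ToyType{{}, t.elementBits};
}
```

With `ToyType v{{1}, 16}` (standing in for `vector<1xf16>`), `maskIntType(convertToSpirv(v))` is a scalar 16-bit integer. By contrast, `maskIntType(v)`, which branches on the unconverted type, yields a one-element integer vector that no longer matches the scalar `f16` operands.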

An alternative solution would be to allow 0-D or single-element vectors in SPIR-V, but the spec [0] requires a vector to have at least two components:
"Vector: An ordered homogeneous collection of two or more scalars. Vector sizes are quite restrictive and dependent on the execution model."

[0] https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_types

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D139518

Added: 
    

Modified: 
    mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
    mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 5bd06c947e49c..80b22576e61c8 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -151,7 +151,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
     Value valueMask = rewriter.create<spirv::ConstantOp>(
         loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u));
 
-    if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) {
+    if (auto vectorType = type.dyn_cast<VectorType>()) {
       assert(vectorType.getRank() == 1);
       int count = vectorType.getNumElements();
       intType = VectorType::get(count, intType);

diff  --git a/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
index e84b9b0f97717..f0119afa42f69 100644
--- a/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
@@ -65,3 +65,28 @@ func.func @copy_sign_tensor(%value: tensor<3x3xf32>, %sign: tensor<3x3xf32>) ->
 // CHECK-LABEL: func @copy_sign_tensor
 // CHECK-NEXT:    math.copysign {{%.+}}, {{%.+}} : tensor<3x3xf32>
 // CHECK-NEXT:    return
+// -----
+
+module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Float16, Int16], []>, #spirv.resource_limits<>> } {
+
+func.func @copy_sign_vector_0D(%value: vector<1xf16>, %sign: vector<1xf16>) -> vector<1xf16> {
+  %0 = math.copysign %value, %sign : vector<1xf16>
+  return %0: vector<1xf16>
+}
+
+}
+
+// CHECK-LABEL: func @copy_sign_vector_0D
+//  CHECK-SAME: (%[[VALUE:.+]]: vector<1xf16>, %[[SIGN:.+]]: vector<1xf16>)
+//       CHECK:   %[[CASTVAL:.+]] = builtin.unrealized_conversion_cast %[[VALUE]] : vector<1xf16> to f16
+//       CHECK:   %[[CASTSIGN:.+]] = builtin.unrealized_conversion_cast %[[SIGN]] : vector<1xf16> to f16
+//       CHECK:   %[[SMASK:.+]] = spirv.Constant -32768 : i16
+//       CHECK:   %[[VMASK:.+]] = spirv.Constant 32767 : i16
+//       CHECK:   %[[VCAST:.+]] = spirv.Bitcast %[[CASTVAL]] : f16 to i16
+//       CHECK:   %[[SCAST:.+]] = spirv.Bitcast %[[CASTSIGN]] : f16 to i16
+//       CHECK:   %[[VAND:.+]] = spirv.BitwiseAnd %[[VCAST]], %[[VMASK]] : i16
+//       CHECK:   %[[SAND:.+]] = spirv.BitwiseAnd %[[SCAST]], %[[SMASK]] : i16
+//       CHECK:   %[[OR:.+]] = spirv.BitwiseOr %[[VAND]], %[[SAND]] : i16
+//       CHECK:   %[[RESULT:.+]] = spirv.Bitcast %[[OR]] : i16 to f16
+//       CHECK:   %[[CASTRESULT:.+]] = builtin.unrealized_conversion_cast %[[RESULT]] : f16 to vector<1xf16>
+//       CHECK:   return %[[CASTRESULT]]
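
The CHECK lines spell out the scalar lowering. Both operands are bitcast to `i16`. The magnitude of `%value` is kept with mask 32767 (`0x7FFF`), and the sign of `%sign` is kept with mask -32768 (`0x8000` as `i16`). The two results are OR-ed and bitcast back to `f16`. The C++ sketch below performs the same arithmetic on raw IEEE 754 binary16 bit patterns. The helper name is hypothetical, and working directly on `uint16_t` bit patterns stands in for the two `spirv.Bitcast` steps, since standard C++ before C++23 has no half-precision type.

```cpp
#include <cstdint>

// Masks matching the spirv.Constant values in the CHECK lines:
// -32768 as i16 is 0x8000 (sign bit); 32767 is 0x7FFF (exponent + mantissa).
constexpr uint16_t kSignMask = 0x8000;
constexpr uint16_t kValueMask = 0x7FFF;

// copysign on binary16 bit patterns (hypothetical helper): the inputs play
// the role of the results of the two f16 -> i16 bitcasts.
constexpr uint16_t copySignF16Bits(uint16_t value, uint16_t sign) {
  uint16_t vand = static_cast<uint16_t>(value & kValueMask);  // %VAND
  uint16_t sand = static_cast<uint16_t>(sign & kSignMask);    // %SAND
  return static_cast<uint16_t>(vand | sand);                  // %OR
}

// In binary16, 1.0 is 0x3C00, -2.0 is 0xC000 and -1.0 is 0xBC00.
static_assert(copySignF16Bits(0x3C00, 0xC000) == 0xBC00,
              "copysign(1.0, -2.0) == -1.0");
```

Because only the sign bit is transplanted, the same masks work for any payload, including infinities and NaNs.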