[Mlir-commits] [mlir] 97f3bb7 - [mlir][spirv] Support convert complex types

Lei Zhang llvmlistbot at llvm.org
Thu Mar 30 08:49:28 PDT 2023


Author: Lei Zhang
Date: 2023-03-30T08:48:31-07:00
New Revision: 97f3bb73a29a566e99e33ae4338c2c3d9957e561

URL: https://github.com/llvm/llvm-project/commit/97f3bb73a29a566e99e33ae4338c2c3d9957e561
DIFF: https://github.com/llvm/llvm-project/commit/97f3bb73a29a566e99e33ae4338c2c3d9957e561.diff

LOG: [mlir][spirv] Support convert complex types

Complex types are converted to a two-element vector type that holds
the real and imaginary parts.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D147188

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index af7436ee76741..bf74ad5ca139d 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -16,11 +16,14 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
 
 #include <functional>
+#include <optional>
 
 #define DEBUG_TYPE "mlir-spirv-conversion"
 
@@ -149,6 +152,13 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
     return bitWidth / 8;
   }
 
+  if (auto complexType = type.dyn_cast<ComplexType>()) {
+    auto elementSize = getTypeNumBytes(options, complexType.getElementType());
+    if (!elementSize)
+      return std::nullopt;
+    return 2 * *elementSize;
+  }
+
   if (auto vecType = type.dyn_cast<VectorType>()) {
     auto elementSize = getTypeNumBytes(options, vecType.getElementType());
     if (!elementSize)
@@ -299,6 +309,30 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
   return nullptr;
 }
 
+static Type
+convertComplexType(const spirv::TargetEnv &targetEnv,
+                   const SPIRVConversionOptions &options, ComplexType type,
+                   std::optional<spirv::StorageClass> storageClass = {}) {
+  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
+  if (!scalarType) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: cannot convert non-scalar element type\n");
+    return nullptr;
+  }
+
+  auto elementType =
+      convertScalarType(targetEnv, options, scalarType, storageClass);
+  if (!elementType)
+    return nullptr;
+  if (elementType != type.getElementType()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: complex type emulation unsupported\n");
+    return nullptr;
+  }
+
+  return VectorType::get(2, elementType);
+}
+
 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
 ///
 /// Note that this is mainly for lowering constant tensors. In SPIR-V one can
@@ -372,7 +406,6 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
     return nullptr;
   }
 
-
   if (!type.hasStaticShape()) {
     // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
     // to the element.
@@ -419,6 +452,9 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
   if (auto vecType = elementType.dyn_cast<VectorType>()) {
     arrayElemType =
         convertVectorType(targetEnv, options, vecType, storageClass);
+  } else if (auto complexType = elementType.dyn_cast<ComplexType>()) {
+    arrayElemType =
+        convertComplexType(targetEnv, options, complexType, storageClass);
   } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
     arrayElemType =
         convertScalarType(targetEnv, options, scalarType, storageClass);
@@ -443,7 +479,6 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
     return nullptr;
   }
 
-
   if (!type.hasStaticShape()) {
     // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
     // to the element.
@@ -500,6 +535,10 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
     return Type();
   });
 
+  addConversion([this](ComplexType complexType) {
+    return convertComplexType(this->targetEnv, this->options, complexType);
+  });
+
   addConversion([this](VectorType vectorType) {
     return convertVectorType(this->targetEnv, this->options, vectorType);
   });

diff  --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 7880547f72b95..fe6edc132e0e3 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -196,6 +196,61 @@ func.func @bf16_type(%arg0: bf16) { return }
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// Complex types
+//===----------------------------------------------------------------------===//
+
+// Check that capabilities for scalar types affect complex types too: having
+// special capabilities means keeping the element types untouched.
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+    [Float64, StorageUniform16, StorageBuffer16BitAccess],
+    [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: func @complex_types
+// CHECK-SAME: vector<2xf32>
+// CHECK-SAME: vector<2xf64>
+func.func @complex_types(
+    %arg0: complex<f32>,
+    %arg2: complex<f64>
+) { return }
+
+// CHECK-LABEL: func @memref_complex_types_with_cap
+// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<4 x vector<2xf16>, stride=4> [0])>, StorageBuffer>
+// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<16 x vector<2xf16>, stride=4> [0])>, Uniform>
+func.func @memref_complex_types_with_cap(
+    %arg0: memref<4xcomplex<f16>, #spirv.storage_class<StorageBuffer>>,
+    %arg1: memref<2x8xcomplex<f16>, #spirv.storage_class<Uniform>>
+) { return }
+
+} // end module
+
+// -----
+
+// Check that capabilities for scalar types affect complex types too: without
+// special capabilities, element types would need widening to 32-bit.
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+
+// Emulation is unimplemented right now.
+// CHECK-LABEL: func @memref_complex_types_no_cap
+// CHECK-SAME: memref<4xcomplex<f16>, #spirv.storage_class<StorageBuffer>>
+// CHECK-SAME: memref<2x8xcomplex<f16>, #spirv.storage_class<Uniform>>
+// NOEMU-LABEL: func @memref_complex_types_no_cap
+// NOEMU-SAME: memref<4xcomplex<f16>, #spirv.storage_class<StorageBuffer>>
+// NOEMU-SAME: memref<2x8xcomplex<f16>, #spirv.storage_class<Uniform>>
+func.func @memref_complex_types_no_cap(
+    %arg0: memref<4xcomplex<f16>, #spirv.storage_class<StorageBuffer>>,
+    %arg1: memref<2x8xcomplex<f16>, #spirv.storage_class<Uniform>>
+) { return }
+
+} // end module
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // Vector types
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list