[Mlir-commits] [mlir] 97f3bb7 - [mlir][spirv] Support convert complex types
Lei Zhang
llvmlistbot at llvm.org
Thu Mar 30 08:49:28 PDT 2023
Author: Lei Zhang
Date: 2023-03-30T08:48:31-07:00
New Revision: 97f3bb73a29a566e99e33ae4338c2c3d9957e561
URL: https://github.com/llvm/llvm-project/commit/97f3bb73a29a566e99e33ae4338c2c3d9957e561
DIFF: https://github.com/llvm/llvm-project/commit/97f3bb73a29a566e99e33ae4338c2c3d9957e561.diff
LOG: [mlir][spirv] Support convert complex types
Complex types are converted to two-element vector types holding
the real and imaginary parts.
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D147188
Added:
Modified:
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index af7436ee76741..bf74ad5ca139d 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -16,11 +16,14 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#include <functional>
+#include <optional>
#define DEBUG_TYPE "mlir-spirv-conversion"
@@ -149,6 +152,13 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
return bitWidth / 8;
}
+ if (auto complexType = type.dyn_cast<ComplexType>()) {
+ auto elementSize = getTypeNumBytes(options, complexType.getElementType());
+ if (!elementSize)
+ return std::nullopt;
+ return 2 * *elementSize;
+ }
+
if (auto vecType = type.dyn_cast<VectorType>()) {
auto elementSize = getTypeNumBytes(options, vecType.getElementType());
if (!elementSize)
@@ -299,6 +309,30 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
return nullptr;
}
+static Type // complex<T> -> vector<2xT>, or nullptr if T is unsupported.
+convertComplexType(const spirv::TargetEnv &targetEnv,
+ const SPIRVConversionOptions &options, ComplexType type,
+ std::optional<spirv::StorageClass> storageClass = {}) { // only forwarded to convertScalarType
+ auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType()); // element must be a spirv::ScalarType
+ if (!scalarType) {
+ LLVM_DEBUG(llvm::dbgs()
+ << type << " illegal: cannot convert non-scalar element type\n");
+ return nullptr;
+ }
+
+ auto elementType =
+ convertScalarType(targetEnv, options, scalarType, storageClass);
+ if (!elementType) // element type not convertible under targetEnv
+ return nullptr;
+ if (elementType != type.getElementType()) { // element would need emulation (e.g. widening); not supported for complex
+ LLVM_DEBUG(llvm::dbgs()
+ << type << " illegal: complex type emulation unsupported\n");
+ return nullptr;
+ }
+
+ return VectorType::get(2, elementType); // two lanes holding the real and imaginary parts
+}
+
/// Converts a tensor `type` to a suitable type under the given `targetEnv`.
///
/// Note that this is mainly for lowering constant tensors. In SPIR-V one can
@@ -372,7 +406,6 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
return nullptr;
}
-
if (!type.hasStaticShape()) {
// For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
// to the element.
@@ -419,6 +452,9 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
if (auto vecType = elementType.dyn_cast<VectorType>()) {
arrayElemType =
convertVectorType(targetEnv, options, vecType, storageClass);
+ } else if (auto complexType = elementType.dyn_cast<ComplexType>()) {
+ arrayElemType =
+ convertComplexType(targetEnv, options, complexType, storageClass);
} else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
arrayElemType =
convertScalarType(targetEnv, options, scalarType, storageClass);
@@ -443,7 +479,6 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
return nullptr;
}
-
if (!type.hasStaticShape()) {
// For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
// to the element.
@@ -500,6 +535,10 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
return Type();
});
+ addConversion([this](ComplexType complexType) {
+ return convertComplexType(this->targetEnv, this->options, complexType);
+ });
+
addConversion([this](VectorType vectorType) {
return convertVectorType(this->targetEnv, this->options, vectorType);
});
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 7880547f72b95..fe6edc132e0e3 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -196,6 +196,61 @@ func.func @bf16_type(%arg0: bf16) { return }
// -----
+//===----------------------------------------------------------------------===//
+// Complex types
+//===----------------------------------------------------------------------===//
+
+// Check that capabilities for scalar types affect complex types too: having
+// special capabilities means complex element types are kept untouched.
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+ [Float64, StorageUniform16, StorageBuffer16BitAccess],
+ [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: func @complex_types
+// CHECK-SAME: vector<2xf32>
+// CHECK-SAME: vector<2xf64>
+func.func @complex_types(
+ %arg0: complex<f32>,
+ %arg2: complex<f64>
+) { return }
+
+// CHECK-LABEL: func @memref_complex_types_with_cap
+// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<4 x vector<2xf16>, stride=4> [0])>, StorageBuffer>
+// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<16 x vector<2xf16>, stride=4> [0])>, Uniform>
+func.func @memref_complex_types_with_cap(
+ %arg0: memref<4xcomplex<f16>, #spirv.storage_class<StorageBuffer>>,
+ %arg1: memref<2x8xcomplex<f16>, #spirv.storage_class<Uniform>>
+) { return }
+
+} // end module
+
+// -----
+
+// Check that capabilities for scalar types affect complex types too: without
+// special capabilities, element types would need widening to 32-bit.
+
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+
+// Complex type emulation is unimplemented, so these memrefs stay unconverted.
+// CHECK-LABEL: func @memref_complex_types_no_cap
+// CHECK-SAME: memref<4xcomplex<f16>, #spirv.storage_class<StorageBuffer>>
+// CHECK-SAME: memref<2x8xcomplex<f16>, #spirv.storage_class<Uniform>>
+// NOEMU-LABEL: func @memref_complex_types_no_cap
+// NOEMU-SAME: memref<4xcomplex<f16>, #spirv.storage_class<StorageBuffer>>
+// NOEMU-SAME: memref<2x8xcomplex<f16>, #spirv.storage_class<Uniform>>
+func.func @memref_complex_types_no_cap(
+ %arg0: memref<4xcomplex<f16>, #spirv.storage_class<StorageBuffer>>,
+ %arg1: memref<2x8xcomplex<f16>, #spirv.storage_class<Uniform>>
+) { return }
+
+} // end module
+
+// -----
+
//===----------------------------------------------------------------------===//
// Vector types
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list