[Mlir-commits] [mlir] c29fc69 - [mlir][spirv] Support shaped types with index element
Jakub Kuderski
llvmlistbot at llvm.org
Wed Mar 1 11:04:26 PST 2023
Author: Jakub Kuderski
Date: 2023-03-01T14:04:19-05:00
New Revision: c29fc69e35bf3071621e561d9c1b40b630d24a91
URL: https://github.com/llvm/llvm-project/commit/c29fc69e35bf3071621e561d9c1b40b630d24a91
DIFF: https://github.com/llvm/llvm-project/commit/c29fc69e35bf3071621e561d9c1b40b630d24a91.diff
LOG: [mlir][spirv] Support shaped types with index element
This makes the SPIR-V type converter first convert `index` element types of
vector, tensor, and memref types to the configured integer index type (i32 by
default, i64 when `use64bitIndex` is set).
Fixes: https://github.com/llvm/llvm-project/issues/61054
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D145090
Added:
Modified:
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 0a1c887733d04..af7436ee76741 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -17,7 +17,6 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
@@ -115,8 +114,14 @@ wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
// Type Conversion
//===----------------------------------------------------------------------===//
+static spirv::ScalarType getIndexType(MLIRContext *ctx,
+ const SPIRVConversionOptions &options) {
+ unsigned indexBitwidth = options.use64bitIndex ? 64 : 32;
+ return cast<spirv::ScalarType>(IntegerType::get(ctx, indexBitwidth));
+}
+
+/// Returns the integer type used for `index`; delegates to the file-local
+/// helper so the converter and the shaped-type conversions agree on the width.
Type SPIRVTypeConverter::getIndexType() const {
- return IntegerType::get(getContext(), options.use64bitIndex ? 64 : 32);
+ return ::getIndexType(getContext(), options);
}
MLIRContext *SPIRVTypeConverter::getContext() const {
@@ -242,12 +247,32 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
intType.getSignedness());
}
+/// Returns a type with the same shape but with any index element type converted
+/// to the matching integer type (i64 when `options.use64bitIndex` is set, i32
+/// otherwise). This is a noop when the element type is not the index type.
+static ShapedType
+convertIndexElementType(ShapedType type,
+ const SPIRVConversionOptions &options) {
+ // Only a type test is needed here; no IndexType value is used afterwards.
+ if (!isa<IndexType>(type.getElementType()))
+ return type;
+
+ return type.clone(getIndexType(type.getContext(), options));
+}
+
/// Converts a vector `type` to a suitable type under the given `targetEnv`.
static Type
convertVectorType(const spirv::TargetEnv &targetEnv,
const SPIRVConversionOptions &options, VectorType type,
std::optional<spirv::StorageClass> storageClass = {}) {
- auto scalarType = type.getElementType().cast<spirv::ScalarType>();
+ type = cast<VectorType>(convertIndexElementType(type, options));
+ auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
+ if (!scalarType) {
+ LLVM_DEBUG(llvm::dbgs()
+ << type << " illegal: cannot convert non-scalar element type\n");
+ return nullptr;
+ }
+
if (type.getRank() <= 1 && type.getNumElements() == 1)
return convertScalarType(targetEnv, options, scalarType, storageClass);
@@ -290,7 +315,8 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
return nullptr;
}
- auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
+ type = cast<TensorType>(convertIndexElementType(type, options));
+ auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot convert non-scalar element type\n");
@@ -396,6 +422,9 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
} else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
arrayElemType =
convertScalarType(targetEnv, options, scalarType, storageClass);
+ } else if (auto indexType = elementType.dyn_cast<IndexType>()) {
+ type = convertIndexElementType(type, options).cast<MemRefType>();
+ arrayElemType = type.getElementType();
} else {
LLVM_DEBUG(
llvm::dbgs()
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 37ea7bf264b02..7880547f72b95 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -115,7 +115,8 @@ func.func @integer42(%arg0: i42) { return }
// Index type
//===----------------------------------------------------------------------===//
-// The index type is always converted into i32.
+// The index type is always converted into i32 or i64, with i32 being the
+// default.
module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
} {
@@ -223,6 +224,10 @@ func.func @float_vector(
// CHECK-SAME: %{{.+}}: i32
func.func @one_element_vector(%arg0: vector<1xi8>) { return }
+// A vector of `index` keeps its shape; only the element type is converted
+// (to i32 here).
+// CHECK-LABEL: spirv.func @index_vector
+// CHECK-SAME: %{{.*}}: vector<4xi32>
+func.func @index_vector(%arg0: vector<4xindex>) { return }
+
} // end module
// -----
@@ -313,6 +318,14 @@ func.func @memref_1bit_type(
%arg1: memref<4x8xi1, #spirv.storage_class<Function>>
) { return }
+// `index` memref elements become i32. The StorageBuffer memref gets an array
+// stride and a struct member offset; the Function memref gets neither.
+// CHECK-LABEL: func @memref_index_type
+// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>
+// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32>)>, Function>
+func.func @memref_index_type(
+ %arg0: memref<4xindex, #spirv.storage_class<StorageBuffer>>,
+ %arg1: memref<4xindex, #spirv.storage_class<Function>>
+) { return }
+
} // end module
// -----
@@ -819,6 +832,11 @@ func.func @float_tensor_types(
%arg2: tensor<8x4xf16>
) { return }
+
+// A tensor of `index` flattens to a SPIR-V array of i32 (4 x 5 = 20 elements).
+// CHECK-LABEL: spirv.func @index_tensor_type
+// CHECK-SAME: %{{.*}}: !spirv.array<20 x i32>
+func.func @index_tensor_type(%arg0: tensor<4x5xindex>) { return }
+
} // end module
// -----
More information about the Mlir-commits
mailing list