[Mlir-commits] [mlir] c29fc69 - [mlir][spirv] Support shaped types with index element

Jakub Kuderski llvmlistbot at llvm.org
Wed Mar 1 11:04:26 PST 2023


Author: Jakub Kuderski
Date: 2023-03-01T14:04:19-05:00
New Revision: c29fc69e35bf3071621e561d9c1b40b630d24a91

URL: https://github.com/llvm/llvm-project/commit/c29fc69e35bf3071621e561d9c1b40b630d24a91
DIFF: https://github.com/llvm/llvm-project/commit/c29fc69e35bf3071621e561d9c1b40b630d24a91.diff

LOG: [mlir][spirv] Support shaped types with index element

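This makes the SPIR-V type converter first convert `index` element types
of shaped types (vectors, tensors, and memrefs) to the matching integer
type (i32 by default, or i64 when `use64bitIndex` is set) before
converting the shaped type itself.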

Fixes: https://github.com/llvm/llvm-project/issues/61054

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D145090

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 0a1c887733d04..af7436ee76741 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -17,7 +17,6 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
 
@@ -115,8 +114,14 @@ wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
 // Type Conversion
 //===----------------------------------------------------------------------===//
 
+static spirv::ScalarType getIndexType(MLIRContext *ctx,
+                                      const SPIRVConversionOptions &options) {
+  return cast<spirv::ScalarType>(
+      IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
+}
+
 Type SPIRVTypeConverter::getIndexType() const {
-  return IntegerType::get(getContext(), options.use64bitIndex ? 64 : 32);
+  return ::getIndexType(getContext(), options);
 }
 
 MLIRContext *SPIRVTypeConverter::getContext() const {
@@ -242,12 +247,32 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
                           intType.getSignedness());
 }
 
+/// Returns a type with the same shape but with any index element type converted
+/// to the matching integer type. This is a noop when the element type is not
+/// the index type.
+static ShapedType
+convertIndexElementType(ShapedType type,
+                        const SPIRVConversionOptions &options) {
+  Type indexType = dyn_cast<IndexType>(type.getElementType());
+  if (!indexType)
+    return type;
+
+  return type.clone(getIndexType(type.getContext(), options));
+}
+
 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
 static Type
 convertVectorType(const spirv::TargetEnv &targetEnv,
                   const SPIRVConversionOptions &options, VectorType type,
                   std::optional<spirv::StorageClass> storageClass = {}) {
-  auto scalarType = type.getElementType().cast<spirv::ScalarType>();
+  type = cast<VectorType>(convertIndexElementType(type, options));
+  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
+  if (!scalarType) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: cannot convert non-scalar element type\n");
+    return nullptr;
+  }
+
   if (type.getRank() <= 1 && type.getNumElements() == 1)
     return convertScalarType(targetEnv, options, scalarType, storageClass);
 
@@ -290,7 +315,8 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
     return nullptr;
   }
 
-  auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
+  type = cast<TensorType>(convertIndexElementType(type, options));
+  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
   if (!scalarType) {
     LLVM_DEBUG(llvm::dbgs()
                << type << " illegal: cannot convert non-scalar element type\n");
@@ -396,6 +422,9 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
   } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
     arrayElemType =
         convertScalarType(targetEnv, options, scalarType, storageClass);
+  } else if (auto indexType = elementType.dyn_cast<IndexType>()) {
+    type = convertIndexElementType(type, options).cast<MemRefType>();
+    arrayElemType = type.getElementType();
   } else {
     LLVM_DEBUG(
         llvm::dbgs()

diff  --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 37ea7bf264b02..7880547f72b95 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -115,7 +115,8 @@ func.func @integer42(%arg0: i42) { return }
 // Index type
 //===----------------------------------------------------------------------===//
 
-// The index type is always converted into i32.
+// The index type is always converted into i32 or i64, with i32 being the
+// default.
 module attributes {
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
 } {
@@ -223,6 +224,10 @@ func.func @float_vector(
 // CHECK-SAME: %{{.+}}: i32
 func.func @one_element_vector(%arg0: vector<1xi8>) { return }
 
+// CHECK-LABEL: spirv.func @index_vector
+// CHECK-SAME: %{{.*}}: vector<4xi32>
+func.func @index_vector(%arg0: vector<4xindex>) { return }
+
 } // end module
 
 // -----
@@ -313,6 +318,14 @@ func.func @memref_1bit_type(
     %arg1: memref<4x8xi1, #spirv.storage_class<Function>>
 ) { return }
 
+// CHECK-LABEL: func @memref_index_type
+// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>
+// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32>)>, Function>
+func.func @memref_index_type(
+    %arg0: memref<4xindex, #spirv.storage_class<StorageBuffer>>,
+    %arg1: memref<4xindex, #spirv.storage_class<Function>>
+) { return }
+
 } // end module
 
 // -----
@@ -819,6 +832,11 @@ func.func @float_tensor_types(
   %arg2: tensor<8x4xf16>
 ) { return }
 
+
+// CHECK-LABEL: spirv.func @index_tensor_type
+// CHECK-SAME: %{{.*}}: !spirv.array<20 x i32>
+func.func @index_tensor_type(%arg0: tensor<4x5xindex>) { return }
+
 } // end module
 
 // -----


        


More information about the Mlir-commits mailing list