[Mlir-commits] [mlir] 23b8264 - [mlir][spirv] Fix runtime array stride when emulating bitwidth

Lei Zhang llvmlistbot at llvm.org
Mon Apr 12 14:13:40 PDT 2021


Author: Lei Zhang
Date: 2021-04-12T17:13:33-04:00
New Revision: 23b8264b5255efdc7c87c189feab04520a1979d5

URL: https://github.com/llvm/llvm-project/commit/23b8264b5255efdc7c87c189feab04520a1979d5
DIFF: https://github.com/llvm/llvm-project/commit/23b8264b5255efdc7c87c189feab04520a1979d5.diff

LOG: [mlir][spirv] Fix runtime array stride when emulating bitwidth

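The stride should be calculated with the converted array element
type, not the original input type. For example, when i8 is emulated
with i32, a memref<?xi8> must map to !spv.rtarray<i32, stride=4>;
deriving the stride from the original i8 element would give a wrong,
smaller stride.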

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D100337

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 9063b5f6cd67..ea703112a2da 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -442,8 +442,16 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
     return nullptr;
   }
 
+  Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
+  if (!arrayElemSize) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: cannot deduce converted element size\n");
+    return nullptr;
+  }
+
   if (!type.hasStaticShape()) {
-    auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, *elementSize);
+    auto arrayType =
+        spirv::RuntimeArrayType::get(arrayElemType, *arrayElemSize);
     // Wrap in a struct to satisfy Vulkan interface requirements.
     auto structType = spirv::StructType::get(arrayType, 0);
     return spirv::PointerType::get(structType, *storageClass);
@@ -458,12 +466,6 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
 
   auto arrayElemCount = *memrefSize / *elementSize;
 
-  Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
-  if (!arrayElemSize) {
-    LLVM_DEBUG(llvm::dbgs()
-               << type << " illegal: cannot deduce converted element size\n");
-    return nullptr;
-  }
 
   auto arrayType =
       spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
index 58513124907a..1574ba24383e 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
@@ -296,6 +296,8 @@ module attributes {
 // An i1 is store in 8-bit, so 5xi1 has 40 bits, which is stored in 2xi32.
 // CHECK-LABEL: spv.func @memref_1bit_type
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<2 x i32, stride=4> [0])>, StorageBuffer>
+// NOEMU-LABEL: func @memref_1bit_type
+// NOEMU-SAME: memref<5xi1>
 func @memref_1bit_type(%arg0: memref<5xi1>) { return }
 
 // CHECK-LABEL: spv.func @memref_8bit_StorageBuffer
@@ -509,12 +511,68 @@ module attributes {
 // CHECK-SAME: memref<*xi32>
 func @unranked_memref(%arg0: memref<*xi32>) { return }
 
+// Check that dynamic dims on i1 are not supported.
+// CHECK-LABEL: func @memref_1bit_type
+// CHECK-SAME: memref<?xi1>
+func @memref_1bit_type(%arg0: memref<?xi1>) { return }
+
 // CHECK-LABEL: func @dynamic_dim_memref
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
 func @dynamic_dim_memref(%arg0: memref<8x?xi32>,
-                         %arg1: memref<?x?xf32>)
-{ return }
+                         %arg1: memref<?x?xf32>) { return }
+
+// Check that using non-32-bit scalar types in interface storage classes
+// requires special capability and extension: convert them to 32-bit if not
+// satisfied.
+
+// CHECK-LABEL: spv.func @memref_8bit_StorageBuffer
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+// NOEMU-LABEL: func @memref_8bit_StorageBuffer
+// NOEMU-SAME: memref<?xi8>
+func @memref_8bit_StorageBuffer(%arg0: memref<?xi8, 0>) { return }
+
+// CHECK-LABEL: spv.func @memref_8bit_Uniform
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<si32, stride=4> [0])>, Uniform>
+// NOEMU-LABEL: func @memref_8bit_Uniform
+// NOEMU-SAME: memref<?xsi8, 4>
+func @memref_8bit_Uniform(%arg0: memref<?xsi8, 4>) { return }
+
+// CHECK-LABEL: spv.func @memref_8bit_PushConstant
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<ui32, stride=4> [0])>, PushConstant>
+// NOEMU-LABEL: func @memref_8bit_PushConstant
+// NOEMU-SAME: memref<?xui8, 7>
+func @memref_8bit_PushConstant(%arg0: memref<?xui8, 7>) { return }
+
+// CHECK-LABEL: spv.func @memref_16bit_StorageBuffer
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+// NOEMU-LABEL: func @memref_16bit_StorageBuffer
+// NOEMU-SAME: memref<?xi16>
+func @memref_16bit_StorageBuffer(%arg0: memref<?xi16, 0>) { return }
+
+// CHECK-LABEL: spv.func @memref_16bit_Uniform
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<si32, stride=4> [0])>, Uniform>
+// NOEMU-LABEL: func @memref_16bit_Uniform
+// NOEMU-SAME: memref<?xsi16, 4>
+func @memref_16bit_Uniform(%arg0: memref<?xsi16, 4>) { return }
+
+// CHECK-LABEL: spv.func @memref_16bit_PushConstant
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<ui32, stride=4> [0])>, PushConstant>
+// NOEMU-LABEL: func @memref_16bit_PushConstant
+// NOEMU-SAME: memref<?xui16, 7>
+func @memref_16bit_PushConstant(%arg0: memref<?xui16, 7>) { return }
+
+// CHECK-LABEL: spv.func @memref_16bit_Input
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, Input>
+// NOEMU-LABEL: func @memref_16bit_Input
+// NOEMU-SAME: memref<?xf16, 9>
+func @memref_16bit_Input(%arg3: memref<?xf16, 9>) { return }
+
+// CHECK-LABEL: spv.func @memref_16bit_Output
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, Output>
+// NOEMU-LABEL: func @memref_16bit_Output
+// NOEMU-SAME: memref<?xf16, 10>
+func @memref_16bit_Output(%arg4: memref<?xf16, 10>) { return }
 
 } // end module
 


        


More information about the Mlir-commits mailing list