[Mlir-commits] [mlir] bbffece - [mlir][spirv] Remove layout decoration on unneeded storage classes

Lei Zhang llvmlistbot at llvm.org
Thu Apr 28 05:22:00 PDT 2022


Author: Lei Zhang
Date: 2022-04-28T08:18:23-04:00
New Revision: bbffece3835d57ec09a1b62071ee8f4b17dd3c27

URL: https://github.com/llvm/llvm-project/commit/bbffece3835d57ec09a1b62071ee8f4b17dd3c27
DIFF: https://github.com/llvm/llvm-project/commit/bbffece3835d57ec09a1b62071ee8f4b17dd3c27.diff

LOG: [mlir][spirv] Remove layout decoration on unneeded storage classes

Per SPIR-V validation rules, explicit layout decorations are only
needed for the StorageBuffer, PhysicalStorageBuffer, Uniform, and
PushConstant storage classes (and even then, only under the Shader
capability). So we don't need such decorations on the other storage
classes.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D124543

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
    mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
    mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir
    mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
    mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
    mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 5174700bebcd..68cb4ee22554 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -370,7 +370,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
     return nullptr;
   }
 
-  return spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
+  return spirv::ArrayType::get(arrayElemType, arrayElemCount);
 }
 
 static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
@@ -407,15 +407,15 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
   }
 
   if (!type.hasStaticShape()) {
-    auto arrayType =
-        spirv::RuntimeArrayType::get(arrayElemType, *arrayElemSize);
+    int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
+    auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
     return wrapInStructAndGetPointer(arrayType, *storageClass);
   }
 
   int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8;
   auto arrayElemCount = (memrefSize + *arrayElemSize - 1) / *arrayElemSize;
-  auto arrayType =
-      spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
+  int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
+  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
 
   return wrapInStructAndGetPointer(arrayType, *storageClass);
 }
@@ -470,8 +470,8 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
   }
 
   if (!type.hasStaticShape()) {
-    auto arrayType =
-        spirv::RuntimeArrayType::get(arrayElemType, *arrayElemSize);
+    int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
+    auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
     return wrapInStructAndGetPointer(arrayType, *storageClass);
   }
 
@@ -483,10 +483,8 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
   }
 
   auto arrayElemCount = *memrefSize / *elementSize;
-
-
-  auto arrayType =
-      spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
+  int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
+  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
 
   return wrapInStructAndGetPointer(arrayType, *storageClass);
 }

diff  --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
index b40d01cb0b8f..2f44420b7a00 100644
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -435,17 +435,17 @@ func.func @constant() {
   %3 = arith.constant dense<[2, 3]> : vector<2xi32>
   // CHECK: spv.Constant 1 : i32
   %4 = arith.constant 1 : index
-  // CHECK: spv.Constant dense<1> : tensor<6xi32> : !spv.array<6 x i32, stride=4>
+  // CHECK: spv.Constant dense<1> : tensor<6xi32> : !spv.array<6 x i32>
   %5 = arith.constant dense<1> : tensor<2x3xi32>
-  // CHECK: spv.Constant dense<1.000000e+00> : tensor<6xf32> : !spv.array<6 x f32, stride=4>
+  // CHECK: spv.Constant dense<1.000000e+00> : tensor<6xf32> : !spv.array<6 x f32>
   %6 = arith.constant dense<1.0> : tensor<2x3xf32>
-  // CHECK: spv.Constant dense<{{\[}}1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf32> : !spv.array<6 x f32, stride=4>
+  // CHECK: spv.Constant dense<{{\[}}1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf32> : !spv.array<6 x f32>
   %7 = arith.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
-  // CHECK: spv.Constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32, stride=4>
+  // CHECK: spv.Constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32>
   %8 = arith.constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
-  // CHECK: spv.Constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32, stride=4>
+  // CHECK: spv.Constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32>
   %9 =  arith.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
-  // CHECK: spv.Constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32, stride=4>
+  // CHECK: spv.Constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32>
   %10 =  arith.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
   return
 }
@@ -458,7 +458,7 @@ func.func @constant_16bit() {
   %1 = arith.constant 5.0 : f16
   // CHECK: spv.Constant dense<[2, 3]> : vector<2xi16>
   %2 = arith.constant dense<[2, 3]> : vector<2xi16>
-  // CHECK: spv.Constant dense<4.000000e+00> : tensor<5xf16> : !spv.array<5 x f16, stride=2>
+  // CHECK: spv.Constant dense<4.000000e+00> : tensor<5xf16> : !spv.array<5 x f16>
   %3 = arith.constant dense<4.0> : tensor<5xf16>
   return
 }
@@ -471,7 +471,7 @@ func.func @constant_64bit() {
   %1 = arith.constant 5.0 : f64
   // CHECK: spv.Constant dense<[2, 3]> : vector<2xi64>
   %2 = arith.constant dense<[2, 3]> : vector<2xi64>
-  // CHECK: spv.Constant dense<4.000000e+00> : tensor<5xf64> : !spv.array<5 x f64, stride=8>
+  // CHECK: spv.Constant dense<4.000000e+00> : tensor<5xf64> : !spv.array<5 x f64>
   %3 = arith.constant dense<4.0> : tensor<5xf64>
   return
 }
@@ -504,9 +504,9 @@ func.func @constant_16bit() {
   %1 = arith.constant 5.0 : f16
   // CHECK: spv.Constant dense<[2, 3]> : vector<2xi32>
   %2 = arith.constant dense<[2, 3]> : vector<2xi16>
-  // CHECK: spv.Constant dense<4.000000e+00> : tensor<5xf32> : !spv.array<5 x f32, stride=4>
+  // CHECK: spv.Constant dense<4.000000e+00> : tensor<5xf32> : !spv.array<5 x f32>
   %3 = arith.constant dense<4.0> : tensor<5xf16>
-  // CHECK: spv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spv.array<4 x f32, stride=4>
+  // CHECK: spv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spv.array<4 x f32>
   %4 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf16>
   return
 }
@@ -519,9 +519,9 @@ func.func @constant_64bit() {
   %1 = arith.constant 5.0 : f64
   // CHECK: spv.Constant dense<[2, 3]> : vector<2xi32>
   %2 = arith.constant dense<[2, 3]> : vector<2xi64>
-  // CHECK: spv.Constant dense<4.000000e+00> : tensor<5xf32> : !spv.array<5 x f32, stride=4>
+  // CHECK: spv.Constant dense<4.000000e+00> : tensor<5xf32> : !spv.array<5 x f32>
   %3 = arith.constant dense<4.0> : tensor<5xf64>
-  // CHECK: spv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spv.array<4 x f32, stride=4>
+  // CHECK: spv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spv.array<4 x f32>
   %4 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf16>
   return
 }
@@ -1360,17 +1360,17 @@ func.func @constant() {
   %3 = arith.constant dense<[2, 3]> : vector<2xi32>
   // CHECK: spv.Constant 1 : i32
   %4 = arith.constant 1 : index
-  // CHECK: spv.Constant dense<1> : tensor<6xi32> : !spv.array<6 x i32, stride=4>
+  // CHECK: spv.Constant dense<1> : tensor<6xi32> : !spv.array<6 x i32>
   %5 = arith.constant dense<1> : tensor<2x3xi32>
-  // CHECK: spv.Constant dense<1.000000e+00> : tensor<6xf32> : !spv.array<6 x f32, stride=4>
+  // CHECK: spv.Constant dense<1.000000e+00> : tensor<6xf32> : !spv.array<6 x f32>
   %6 = arith.constant dense<1.0> : tensor<2x3xf32>
-  // CHECK: spv.Constant dense<{{\[}}1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf32> : !spv.array<6 x f32, stride=4>
+  // CHECK: spv.Constant dense<{{\[}}1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf32> : !spv.array<6 x f32>
   %7 = arith.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
-  // CHECK: spv.Constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32, stride=4>
+  // CHECK: spv.Constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32>
   %8 = arith.constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
-  // CHECK: spv.Constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32, stride=4>
+  // CHECK: spv.Constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32>
   %9 = arith.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
-  // CHECK: spv.Constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32, stride=4>
+  // CHECK: spv.Constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32>
   %10 = arith.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
   return
 }
@@ -1383,7 +1383,7 @@ func.func @constant_16bit() {
   %1 = arith.constant 5.0 : f16
   // CHECK: spv.Constant dense<[2, 3]> : vector<2xi16>
   %2 = arith.constant dense<[2, 3]> : vector<2xi16>
-  // CHECK: spv.Constant dense<4.000000e+00> : tensor<5xf16> : !spv.array<5 x f16, stride=2>
+  // CHECK: spv.Constant dense<4.000000e+00> : tensor<5xf16> : !spv.array<5 x f16>
   %3 = arith.constant dense<4.0> : tensor<5xf16>
   return
 }
@@ -1396,7 +1396,7 @@ func.func @constant_64bit() {
   %1 = arith.constant 5.0 : f64
   // CHECK: spv.Constant dense<[2, 3]> : vector<2xi64>
   %2 = arith.constant dense<[2, 3]> : vector<2xi64>
-  // CHECK: spv.Constant dense<4.000000e+00> : tensor<5xf64> : !spv.array<5 x f64, stride=8>
+  // CHECK: spv.Constant dense<4.000000e+00> : tensor<5xf64> : !spv.array<5 x f64>
   %3 = arith.constant dense<4.0> : tensor<5xf64>
   return
 }
@@ -1418,9 +1418,9 @@ func.func @constant_16bit() {
   %1 = arith.constant 5.0 : f16
   // CHECK: spv.Constant dense<[2, 3]> : vector<2xi32>
   %2 = arith.constant dense<[2, 3]> : vector<2xi16>
-  // CHECK: spv.Constant dense<4.000000e+00> : tensor<5xf32> : !spv.array<5 x f32, stride=4>
+  // CHECK: spv.Constant dense<4.000000e+00> : tensor<5xf32> : !spv.array<5 x f32>
   %3 = arith.constant dense<4.0> : tensor<5xf16>
-  // CHECK: spv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spv.array<4 x f32, stride=4>
+  // CHECK: spv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spv.array<4 x f32>
   %4 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf16>
   return
 }
@@ -1433,9 +1433,9 @@ func.func @constant_64bit() {
   %1 = arith.constant 5.0 : f64
   // CHECK: spv.Constant dense<[2, 3]> : vector<2xi32>
   %2 = arith.constant dense<[2, 3]> : vector<2xi64>
-  // CHECK: spv.Constant dense<4.000000e+00> : tensor<5xf32> : !spv.array<5 x f32, stride=4>
+  // CHECK: spv.Constant dense<4.000000e+00> : tensor<5xf32> : !spv.array<5 x f32>
   %3 = arith.constant dense<4.0> : tensor<5xf64>
-  // CHECK: spv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spv.array<4 x f32, stride=4>
+  // CHECK: spv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spv.array<4 x f32>
   %4 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf16>
   return
 }

diff  --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 3fb5d7254704..cce9e2749de3 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -282,6 +282,17 @@ func.func @memref_mem_space(
     %arg5: memref<4xf32, 6>
 ) { return }
 
+// CHECK-LABEL: func @memref_1bit_type
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x i32, stride=4> [0])>, StorageBuffer>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x i32>)>, Function>
+// NOEMU-LABEL: func @memref_1bit_type
+// NOEMU-SAME: memref<4x8xi1>
+// NOEMU-SAME: memref<4x8xi1, 6>
+func.func @memref_1bit_type(
+    %arg0: memref<4x8xi1, 0>,
+    %arg1: memref<4x8xi1, 6>
+) { return }
+
 } // end module
 
 // -----
@@ -337,13 +348,13 @@ func.func @memref_16bit_Uniform(%arg0: memref<16xsi16, 4>) { return }
 func.func @memref_16bit_PushConstant(%arg0: memref<16xui16, 7>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Input
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4>)>, Input>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32>)>, Input>
 // NOEMU-LABEL: func @memref_16bit_Input
 // NOEMU-SAME: memref<16xf16, 9>
 func.func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Output
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4>)>, Output>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32>)>, Output>
 // NOEMU-LABEL: func @memref_16bit_Output
 // NOEMU-SAME: memref<16xf16, 10>
 func.func @memref_16bit_Output(%arg4: memref<16xf16, 10>) { return }
@@ -451,15 +462,15 @@ module attributes {
 } {
 
 // CHECK-LABEL: spv.func @memref_16bit_Input
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2>)>, Input>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16>)>, Input>
 // NOEMU-LABEL: spv.func @memref_16bit_Input
-// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2>)>, Input>
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16>)>, Input>
 func.func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Output
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2>)>, Output>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16>)>, Output>
 // NOEMU-LABEL: spv.func @memref_16bit_Output
-// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2>)>, Output>
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16>)>, Output>
 func.func @memref_16bit_Output(%arg4: memref<16xi16, 10>) { return }
 
 } // end module
@@ -564,13 +575,13 @@ func.func @memref_16bit_Uniform(%arg0: memref<?xsi16, 4>) { return }
 func.func @memref_16bit_PushConstant(%arg0: memref<?xui16, 7>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Input
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4>)>, Input>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32>)>, Input>
 // NOEMU-LABEL: func @memref_16bit_Input
 // NOEMU-SAME: memref<?xf16, 9>
 func.func @memref_16bit_Input(%arg3: memref<?xf16, 9>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Output
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4>)>, Output>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32>)>, Output>
 // NOEMU-LABEL: func @memref_16bit_Output
 // NOEMU-SAME: memref<?xf16, 10>
 func.func @memref_16bit_Output(%arg4: memref<?xf16, 10>) { return }
@@ -629,10 +640,10 @@ module attributes {
 } {
 
 // CHECK-LABEL: spv.func @int_tensor_types
-// CHECK-SAME: !spv.array<32 x i64, stride=8>
-// CHECK-SAME: !spv.array<32 x i32, stride=4>
-// CHECK-SAME: !spv.array<32 x i16, stride=2>
-// CHECK-SAME: !spv.array<32 x i8, stride=1>
+// CHECK-SAME: !spv.array<32 x i64>
+// CHECK-SAME: !spv.array<32 x i32>
+// CHECK-SAME: !spv.array<32 x i16>
+// CHECK-SAME: !spv.array<32 x i8>
 func.func @int_tensor_types(
   %arg0: tensor<8x4xi64>,
   %arg1: tensor<8x4xi32>,
@@ -641,9 +652,9 @@ func.func @int_tensor_types(
 ) { return }
 
 // CHECK-LABEL: spv.func @float_tensor_types
-// CHECK-SAME: !spv.array<32 x f64, stride=8>
-// CHECK-SAME: !spv.array<32 x f32, stride=4>
-// CHECK-SAME: !spv.array<32 x f16, stride=2>
+// CHECK-SAME: !spv.array<32 x f64>
+// CHECK-SAME: !spv.array<32 x f32>
+// CHECK-SAME: !spv.array<32 x f16>
 func.func @float_tensor_types(
   %arg0: tensor<8x4xf64>,
   %arg1: tensor<8x4xf32>,
@@ -660,10 +671,10 @@ module attributes {
 } {
 
 // CHECK-LABEL: spv.func @int_tensor_types
-// CHECK-SAME: !spv.array<32 x i32, stride=4>
-// CHECK-SAME: !spv.array<32 x i32, stride=4>
-// CHECK-SAME: !spv.array<32 x i32, stride=4>
-// CHECK-SAME: !spv.array<32 x i32, stride=4>
+// CHECK-SAME: !spv.array<32 x i32>
+// CHECK-SAME: !spv.array<32 x i32>
+// CHECK-SAME: !spv.array<32 x i32>
+// CHECK-SAME: !spv.array<32 x i32>
 func.func @int_tensor_types(
   %arg0: tensor<8x4xi64>,
   %arg1: tensor<8x4xi32>,
@@ -672,9 +683,9 @@ func.func @int_tensor_types(
 ) { return }
 
 // CHECK-LABEL: spv.func @float_tensor_types
-// CHECK-SAME: !spv.array<32 x f32, stride=4>
-// CHECK-SAME: !spv.array<32 x f32, stride=4>
-// CHECK-SAME: !spv.array<32 x f32, stride=4>
+// CHECK-SAME: !spv.array<32 x f32>
+// CHECK-SAME: !spv.array<32 x f32>
+// CHECK-SAME: !spv.array<32 x f32>
 func.func @float_tensor_types(
   %arg0: tensor<8x4xf64>,
   %arg1: tensor<8x4xf32>,

diff  --git a/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir b/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir
index 6d022d70250b..3c80acfad14d 100644
--- a/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir
@@ -9,7 +9,7 @@ module attributes {
     //       CHECK:   spv.func
     //  CHECK-SAME:     {{%.*}}: f32
     //   CHECK-NOT:     spv.interface_var_abi
-    //  CHECK-SAME:     {{%.*}}: !spv.ptr<!spv.struct<(!spv.array<12 x f32, stride=4>)>, CrossWorkgroup>
+    //  CHECK-SAME:     {{%.*}}: !spv.ptr<!spv.struct<(!spv.array<12 x f32>)>, CrossWorkgroup>
     //   CHECK-NOT:     spv.interface_var_abi
     //  CHECK-SAME:     spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}
     gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, 11>) kernel

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
index 598e03fed55f..434eec1aaf30 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
@@ -13,7 +13,7 @@ module attributes {
     return
   }
 }
-//     CHECK: spv.GlobalVariable @[[VAR:.+]] : !spv.ptr<!spv.struct<(!spv.array<20 x f32, stride=4>)>, Workgroup>
+//     CHECK: spv.GlobalVariable @[[VAR:.+]] : !spv.ptr<!spv.struct<(!spv.array<20 x f32>)>, Workgroup>
 //     CHECK: func @alloc_dealloc_workgroup_mem
 // CHECK-NOT:   memref.alloc
 //     CHECK:   %[[PTR:.+]] = spv.mlir.addressof @[[VAR]]
@@ -40,7 +40,7 @@ module attributes {
 }
 
 //       CHECK: spv.GlobalVariable @__workgroup_mem__{{[0-9]+}}
-//  CHECK-SAME:   !spv.ptr<!spv.struct<(!spv.array<20 x i32, stride=4>)>, Workgroup>
+//  CHECK-SAME:   !spv.ptr<!spv.struct<(!spv.array<20 x i32>)>, Workgroup>
 // CHECK_LABEL: spv.func @alloc_dealloc_workgroup_mem
 //       CHECK:   %[[VAR:.+]] = spv.mlir.addressof @__workgroup_mem__0
 //       CHECK:   %[[LOC:.+]] = spv.SDiv
@@ -67,9 +67,9 @@ module attributes {
 }
 
 //  CHECK-DAG: spv.GlobalVariable @__workgroup_mem__{{[0-9]+}}
-// CHECK-SAME:   !spv.ptr<!spv.struct<(!spv.array<6 x i32, stride=4>)>, Workgroup>
+// CHECK-SAME:   !spv.ptr<!spv.struct<(!spv.array<6 x i32>)>, Workgroup>
 //  CHECK-DAG: spv.GlobalVariable @__workgroup_mem__{{[0-9]+}}
-// CHECK-SAME:   !spv.ptr<!spv.struct<(!spv.array<20 x f32, stride=4>)>, Workgroup>
+// CHECK-SAME:   !spv.ptr<!spv.struct<(!spv.array<20 x f32>)>, Workgroup>
 //      CHECK: func @two_allocs()
 
 // -----
@@ -87,9 +87,9 @@ module attributes {
 }
 
 //  CHECK-DAG: spv.GlobalVariable @__workgroup_mem__{{[0-9]+}}
-// CHECK-SAME:   !spv.ptr<!spv.struct<(!spv.array<2 x vector<2xi32>, stride=8>)>, Workgroup>
+// CHECK-SAME:   !spv.ptr<!spv.struct<(!spv.array<2 x vector<2xi32>>)>, Workgroup>
 //  CHECK-DAG: spv.GlobalVariable @__workgroup_mem__{{[0-9]+}}
-// CHECK-SAME:   !spv.ptr<!spv.struct<(!spv.array<4 x vector<4xf32>, stride=16>)>, Workgroup>
+// CHECK-SAME:   !spv.ptr<!spv.struct<(!spv.array<4 x vector<4xf32>>)>, Workgroup>
 //      CHECK: func @two_allocs_vector()
 
 

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
index 2aabeed2fd8c..5d11235568b8 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
@@ -10,7 +10,7 @@ module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>
 }
 
 // CHECK-LABEL: func @alloc_function_variable
-//       CHECK:   %[[VAR:.+]] = spv.Variable : !spv.ptr<!spv.struct<(!spv.array<20 x f32, stride=4>)>, Function>
+//       CHECK:   %[[VAR:.+]] = spv.Variable : !spv.ptr<!spv.struct<(!spv.array<20 x f32>)>, Function>
 //       CHECK:   %[[LOADPTR:.+]] = spv.AccessChain %[[VAR]]
 //       CHECK:   %[[VAL:.+]] = spv.Load "Function" %[[LOADPTR]] : f32
 //       CHECK:   %[[STOREPTR:.+]] = spv.AccessChain %[[VAR]]
@@ -28,8 +28,8 @@ module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>
 }
 
 // CHECK-LABEL: func @two_allocs
-//   CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<6 x i32, stride=4>)>, Function>
-//   CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<20 x f32, stride=4>)>, Function>
+//   CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<6 x i32>)>, Function>
+//   CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<20 x f32>)>, Function>
 
 // -----
 
@@ -42,8 +42,8 @@ module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>
 }
 
 // CHECK-LABEL: func @two_allocs_vector
-//   CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<2 x vector<2xi32>, stride=8>)>, Function>
-//   CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<4 x vector<4xf32>, stride=16>)>, Function>
+//   CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<2 x vector<2xi32>>)>, Function>
+//   CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<4 x vector<4xf32>>)>, Function>
 
 
 // -----

diff  --git a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
index 68963d2f61e1..cbec0f5af994 100644
--- a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
@@ -9,7 +9,7 @@
 func.func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
   // CHECK: %[[CST:.+]] = spv.Constant dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]>
   %cst = arith.constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : tensor<2x2x3xi32>
-  // CHECK: %[[VAR:.+]] = spv.Variable init(%[[CST]]) : !spv.ptr<!spv.array<12 x i32, stride=4>, Function>
+  // CHECK: %[[VAR:.+]] = spv.Variable init(%[[CST]]) : !spv.ptr<!spv.array<12 x i32>, Function>
   // CHECK: %[[C0:.+]] = spv.Constant 0 : i32
   // CHECK: %[[C6:.+]] = spv.Constant 6 : i32
   // CHECK: %[[MUL0:.+]] = spv.IMul %[[C6]], %[[A]] : i32


        


More information about the Mlir-commits mailing list