[Mlir-commits] [mlir] 02022ff - [mlir][spirv] Enhance AccessChainOp index type handling

Lei Zhang llvmlistbot at llvm.org
Mon Jun 22 07:11:47 PDT 2020


Author: HazemAbdelhafez
Date: 2020-06-22T10:11:33-04:00
New Revision: 02022ff2e3f2f4d39d5926e93ecfa77395a1e9ea

URL: https://github.com/llvm/llvm-project/commit/02022ff2e3f2f4d39d5926e93ecfa77395a1e9ea
DIFF: https://github.com/llvm/llvm-project/commit/02022ff2e3f2f4d39d5926e93ecfa77395a1e9ea.diff

LOG: [mlir][spirv] Enhance AccessChainOp index type handling

This patch extends the AccessChainOp index type handling to be able to deal with
all Integer type indices (i.e., all bit-widths and signedness semantics).

There were two ways of achieving this:
1- Backward compatible: The new way of handling the indices will assume that
   an index type is i32 by default if not specified in the assembly format;
   this way all the old tests would pass correctly.
2- Enforce the format: This unifies the spv.AccessChain Op format, but all the
   old tests have to be updated to reflect this change or else they fail.

I picked option-2 to unify the Op format and avoid having optional index-type fields
that can lead to a somewhat confusing test format and multiple representations for
the same Op, with the undocumented assumption that an index is i32 unless stated.
Nonetheless, reverting to option-1 should be straightforward if preferred or needed.

Differential Revision: https://reviews.llvm.org/D81763

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir
    mlir/test/Dialect/SPIRV/Serialization/array.mlir
    mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
    mlir/test/Dialect/SPIRV/Serialization/debug.mlir
    mlir/test/Dialect/SPIRV/Serialization/global-variable.mlir
    mlir/test/Dialect/SPIRV/Serialization/loop.mlir
    mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
    mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
    mlir/test/Dialect/SPIRV/Serialization/undef.mlir
    mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
    mlir/test/Dialect/SPIRV/Transforms/inlining.mlir
    mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir
    mlir/test/Dialect/SPIRV/canonicalize.mlir
    mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
    mlir/test/Dialect/SPIRV/ops.mlir
    mlir/test/Dialect/SPIRV/structure-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 87456f000edc..3742bab414ec 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -723,12 +723,6 @@ static LogicalResult verifyShiftOp(Operation *op) {
 //===----------------------------------------------------------------------===//
 
 static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
-  if (indices.empty()) {
-    emitError(baseLoc, "'spv.AccessChain' op expected at least "
-                       "one index ");
-    return nullptr;
-  }
-
   auto ptrType = type.dyn_cast<spirv::PointerType>();
   if (!ptrType) {
     emitError(baseLoc, "'spv.AccessChain' op expected a pointer "
@@ -791,19 +785,37 @@ static ParseResult parseAccessChainOp(OpAsmParser &parser,
   OpAsmParser::OperandType ptrInfo;
   SmallVector<OpAsmParser::OperandType, 4> indicesInfo;
   Type type;
-  // TODO(denis0x0D): regarding to the spec an index must be any integer type,
-  // figure out how to use resolveOperand with a range of types and do not
-  // fail on first attempt.
-  Type indicesType = parser.getBuilder().getIntegerType(32);
+  auto loc = parser.getCurrentLocation();
+  SmallVector<Type, 4> indicesTypes;
 
   if (parser.parseOperand(ptrInfo) ||
       parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
       parser.parseColonType(type) ||
-      parser.resolveOperand(ptrInfo, type, state.operands) ||
-      parser.resolveOperands(indicesInfo, indicesType, state.operands)) {
+      parser.resolveOperand(ptrInfo, type, state.operands)) {
     return failure();
   }
 
+  // Check that the provided indices list is not empty before parsing their
+  // type list.
+  if (indicesInfo.empty()) {
+    return emitError(state.location, "'spv.AccessChain' op expected at "
+                                     "least one index ");
+  }
+
+  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
+    return failure();
+
+  // Check that the indices types list is not empty and that it has a one-to-one
+  // mapping to the provided indices.
+  if (indicesTypes.size() != indicesInfo.size()) {
+    return emitError(state.location, "'spv.AccessChain' op indices "
+                                     "types' count must be equal to indices "
+                                     "info count");
+  }
+
+  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
+    return failure();
+
   auto resultType = getElementPtrType(
       type, llvm::makeArrayRef(state.operands).drop_front(), state.location);
   if (!resultType) {
@@ -816,7 +828,8 @@ static ParseResult parseAccessChainOp(OpAsmParser &parser,
 
 static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
   printer << spirv::AccessChainOp::getOperationName() << ' ' << op.base_ptr()
-          << '[' << op.indices() << "] : " << op.base_ptr().getType();
+          << '[' << op.indices() << "] : " << op.base_ptr().getType() << ", "
+          << op.indices().getTypes();
 }
 
 static LogicalResult verify(spirv::AccessChainOp accessChainOp) {

diff  --git a/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir b/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir
index 726b276010ef..43da9c5be429 100644
--- a/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir
+++ b/mlir/test/Conversion/GPUToVulkan/lower-gpu-launch-vulkan-launch.mlir
@@ -11,7 +11,7 @@ module attributes {gpu.container_module} {
       %0 = spv._address_of @kernel_arg_0 : !spv.ptr<!spv.struct<!spv.array<12 x f32, stride=4> [0]>, StorageBuffer>
       %2 = spv.constant 0 : i32
       %3 = spv._address_of @kernel_arg_0 : !spv.ptr<!spv.struct<!spv.array<12 x f32, stride=4> [0]>, StorageBuffer>
-      %4 = spv.AccessChain %0[%2, %2] : !spv.ptr<!spv.struct<!spv.array<12 x f32, stride=4> [0]>, StorageBuffer>
+      %4 = spv.AccessChain %0[%2, %2] : !spv.ptr<!spv.struct<!spv.array<12 x f32, stride=4> [0]>, StorageBuffer>, i32, i32
       %5 = spv.Load "StorageBuffer" %4 : f32
       spv.Return
     }

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/array.mlir b/mlir/test/Dialect/SPIRV/Serialization/array.mlir
index 0d14cbf9d3b6..68986741b32d 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/array.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/array.mlir
@@ -2,8 +2,8 @@
 
 spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
   spv.func @array_stride(%arg0 : !spv.ptr<!spv.array<4x!spv.array<4xf32, stride=4>, stride=128>, StorageBuffer>, %arg1 : i32, %arg2 : i32) "None" {
-    // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32, stride=4>, stride=128>, StorageBuffer>
-    %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr<!spv.array<4x!spv.array<4xf32, stride=4>, stride=128>, StorageBuffer>
+    // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32, stride=4>, stride=128>, StorageBuffer>, i32, i32
+    %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr<!spv.array<4x!spv.array<4xf32, stride=4>, stride=128>, StorageBuffer>, i32, i32
     spv.Return
   }
 }

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
index 0d58fea18a11..ad913dfb1624 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
@@ -95,8 +95,8 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [CooperativeMatrixNV], [SPV_N
   // CHECK-LABEL: @cooperative_matrix_access_chain
   spv.func @cooperative_matrix_access_chain(%a : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, Function>) -> !spv.ptr<f32, Function> "None" {
     %0 = spv.constant 0: i32
-    // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, Function>
-    %1 = spv.AccessChain %a[%0] : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, Function>
+    // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, Function>, i32
+    %1 = spv.AccessChain %a[%0] : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, Function>, i32
     spv.ReturnValue %1 : !spv.ptr<f32, Function>
   }
 }

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/debug.mlir b/mlir/test/Dialect/SPIRV/Serialization/debug.mlir
index aa9653da4f2e..d83030d25298 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/debug.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/debug.mlir
@@ -58,7 +58,7 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
 
   spv.func @memory_accesses(%arg0 : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, StorageBuffer>, %arg1 : i32, %arg2 : i32) "None" {
     // CHECK: loc({{".*debug.mlir"}}:61:10)
-    %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, StorageBuffer>
+    %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, StorageBuffer>, i32, i32
     // CHECK: loc({{".*debug.mlir"}}:63:10)
     %3 = spv.Load "StorageBuffer" %2 : f32
     // CHECK: loc({{.*debug.mlir"}}:65:5)

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/global-variable.mlir b/mlir/test/Dialect/SPIRV/Serialization/global-variable.mlir
index faa371ea9016..0365fef03fbb 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/global-variable.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/global-variable.mlir
@@ -30,7 +30,7 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     %0 = spv._address_of @globalInvocationID : !spv.ptr<vector<3xi32>, Input>
     %1 = spv.constant 0: i32
     // CHECK: spv.AccessChain %[[ADDR]]
-    %2 = spv.AccessChain %0[%1] : !spv.ptr<vector<3xi32>, Input>
+    %2 = spv.AccessChain %0[%1] : !spv.ptr<vector<3xi32>, Input>, i32
     spv.Return
   }
 }

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir
index 1b041a4aa604..d6e2090f02bb 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/loop.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/loop.mlir
@@ -65,9 +65,9 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
   spv.func @loop_kernel() "None" {
     %0 = spv._address_of @GV1 : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
     %1 = spv.constant 0 : i32
-    %2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
+    %2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>, i32
     %3 = spv._address_of @GV2 : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
-    %5 = spv.AccessChain %3[%1] : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
+    %5 = spv.AccessChain %3[%1] : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>, i32
     %6 = spv.constant 4 : i32
     %7 = spv.constant 42 : i32
     %8 = spv.constant 2 : i32
@@ -84,9 +84,9 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
       spv.BranchConditional %10, ^body, ^merge
 // CHECK-NEXT:   ^bb2:     // pred: ^bb1
     ^body:
-      %11 = spv.AccessChain %2[%9] : !spv.ptr<!spv.array<10 x f32, stride=4>, StorageBuffer>
+      %11 = spv.AccessChain %2[%9] : !spv.ptr<!spv.array<10 x f32, stride=4>, StorageBuffer>, i32
       %12 = spv.Load "StorageBuffer" %11 : f32
-      %13 = spv.AccessChain %5[%9] : !spv.ptr<!spv.array<10 x f32, stride=4>, StorageBuffer>
+      %13 = spv.AccessChain %5[%9] : !spv.ptr<!spv.array<10 x f32, stride=4>, StorageBuffer>, i32
       spv.Store "StorageBuffer" %13, %12 : f32
 // CHECK:          %[[ADD:.*]] = spv.IAdd
       %14 = spv.IAdd %9, %8 : i32

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
index 8dc90cb504e8..e10bfc88afb0 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
@@ -4,7 +4,7 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
   // CHECK-LABEL: @matrix_access_chain
   spv.func @matrix_access_chain(%arg0 : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, Function>, %arg1 : i32) -> !spv.ptr<vector<3xf32>, Function> "None" {
     // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, Function>
-    %0 = spv.AccessChain %arg0[%arg1] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>, Function>
+    %0 = spv.AccessChain %arg0[%arg1] : !spv.ptr<!spv.matrix<3 x vector<3xf32>>,Function>, i32
     spv.ReturnValue %0 : !spv.ptr<vector<3xf32>, Function>
   }
 
@@ -20,6 +20,7 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     // CHECK: {{%.*}} = spv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf16>>, f16 -> !spv.matrix<3 x vector<3xf16>>
     %result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf16>>, f16 -> !spv.matrix<3 x vector<3xf16>>
     spv.ReturnValue %result : !spv.matrix<3 x vector<3xf16>>
+
   }
 }
 

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
index fbe45fa87d20..26584a479dec 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
@@ -18,8 +18,8 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
   spv.func @access_chain(%arg0 : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>, %arg1 : i32, %arg2 : i32) "None" {
     // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
     // CHECK-NEXT: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
-    %1 = spv.AccessChain %arg0[%arg1] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
-    %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+    %1 = spv.AccessChain %arg0[%arg1] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>, i32
+    %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>, i32, i32
     spv.Return
   }
 }
@@ -31,13 +31,13 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x f32, stride=4> [0]>
     // CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : f32
     %0 = spv.constant 0 : i32
-    %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr<!spv.struct<!spv.array<1 x f32, stride=4> [0]>, StorageBuffer>
+    %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr<!spv.struct<!spv.array<1 x f32, stride=4> [0]>, StorageBuffer>, i32, i32
     %2 = spv.Load "StorageBuffer" %1 : f32
 
     // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x f32, stride=4> [0]>
     // CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : f32
     %3 = spv.constant 0 : i32
-    %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr<!spv.struct<!spv.array<1 x f32, stride=4> [0]>, StorageBuffer>
+    %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr<!spv.struct<!spv.array<1 x f32, stride=4> [0]>, StorageBuffer>, i32, i32
     spv.Store "StorageBuffer" %4, %2 : f32
     spv.Return
   }
@@ -46,13 +46,13 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     // CHECK: [[LOAD_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x i32, stride=4> [0]>
     // CHECK-NEXT: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOAD_PTR]] : i32
     %0 = spv.constant 0 : i32
-    %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr<!spv.struct<!spv.array<1 x i32, stride=4> [0]>, StorageBuffer>
+    %1 = spv.AccessChain %arg0[%0, %0] : !spv.ptr<!spv.struct<!spv.array<1 x i32, stride=4> [0]>, StorageBuffer>, i32, i32
     %2 = spv.Load "StorageBuffer" %1 : i32
 
     // CHECK: [[STORE_PTR:%.*]] = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<1 x i32, stride=4> [0]>
     // CHECK-NEXT: spv.Store "StorageBuffer" [[STORE_PTR]], [[VAL]] : i32
     %3 = spv.constant 0 : i32
-    %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr<!spv.struct<!spv.array<1 x i32, stride=4> [0]>, StorageBuffer>
+    %4 = spv.AccessChain %arg1[%3, %3] : !spv.ptr<!spv.struct<!spv.array<1 x i32, stride=4> [0]>, StorageBuffer>, i32, i32
     spv.Store "StorageBuffer" %4, %2 : i32
     spv.Return
   }

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/undef.mlir b/mlir/test/Dialect/SPIRV/Serialization/undef.mlir
index 6998930911db..d19812f48257 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/undef.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/undef.mlir
@@ -16,7 +16,7 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     // CHECK: {{%.*}} = spv.undef : !spv.ptr<!spv.struct<f32>, StorageBuffer>
     %7 = spv.undef : !spv.ptr<!spv.struct<f32>, StorageBuffer>
     %8 = spv.constant 0 : i32
-    %9 = spv.AccessChain %7[%8] : !spv.ptr<!spv.struct<f32>, StorageBuffer>
+    %9 = spv.AccessChain %7[%8] : !spv.ptr<!spv.struct<f32>, StorageBuffer>, i32
     spv.Return
   }
 }

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
index 075ef3398d83..4e4bf06e6f73 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
@@ -102,14 +102,14 @@ spv.module Logical GLSL450 {
     %37 = spv.IAdd %arg4, %11 : i32
     // CHECK: spv.AccessChain [[ARG0]]
     %c0 = spv.constant 0 : i32
-    %38 = spv.AccessChain %arg0[%c0, %36, %37] : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
+    %38 = spv.AccessChain %arg0[%c0, %36, %37] : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>, i32, i32, i32
     %39 = spv.Load "StorageBuffer" %38 : f32
     // CHECK: spv.AccessChain [[ARG1]]
-    %40 = spv.AccessChain %arg1[%c0, %36, %37] : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
+    %40 = spv.AccessChain %arg1[%c0, %36, %37] : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>, i32, i32, i32
     %41 = spv.Load "StorageBuffer" %40 : f32
     %42 = spv.FAdd %39, %41 : f32
     // CHECK: spv.AccessChain [[ARG2]]
-    %43 = spv.AccessChain %arg2[%c0, %36, %37] : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
+    %43 = spv.AccessChain %arg2[%c0, %36, %37] : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>, i32, i32, i32
     spv.Store "StorageBuffer" %43, %42 : f32
     spv.Return
   }

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir b/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir
index 24b22c5dbf6f..8c9408ab089f 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir
@@ -37,7 +37,7 @@ spv.module Logical GLSL450 {
   spv.func @callee() "None" {
     %0 = spv._address_of @data : !spv.ptr<!spv.struct<!spv.rtarray<i32> [0]>, StorageBuffer>
     %1 = spv.constant 0: i32
-    %2 = spv.AccessChain %0[%1, %1] : !spv.ptr<!spv.struct<!spv.rtarray<i32> [0]>, StorageBuffer>
+    %2 = spv.AccessChain %0[%1, %1] : !spv.ptr<!spv.struct<!spv.rtarray<i32> [0]>, StorageBuffer>, i32, i32
     spv.Branch ^next
 
   ^next:
@@ -196,7 +196,7 @@ spv.module Logical GLSL450 {
     // CHECK: [[VAL:%.*]] = spv.Load "StorageBuffer" [[LOADPTR]]
     %2 = spv._address_of @arg_0 : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
     %3 = spv._address_of @arg_1 : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
-    %4 = spv.AccessChain %2[%1] : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
+    %4 = spv.AccessChain %2[%1] : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>, i32
     %5 = spv.Load "StorageBuffer" %4 : i32
     %6 = spv.SGreaterThan %5, %1 : i32
     // CHECK: spv.selection
@@ -204,7 +204,7 @@ spv.module Logical GLSL450 {
       spv.BranchConditional %6, ^bb1, ^bb2
     ^bb1: // pred: ^bb0
       // CHECK: [[STOREPTR:%.*]] = spv.AccessChain [[ADDRESS_ARG1]]
-      %7 = spv.AccessChain %3[%1] : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
+      %7 = spv.AccessChain %3[%1] : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>, i32
       // CHECK-NOT: spv.FunctionCall
       // CHECK: spv.AtomicIAdd "Device" "AcquireRelease" [[STOREPTR]], [[VAL]]
       // CHECK: spv.Branch

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir b/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir
index 975012d3a26a..f54d1910be22 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/layout-decoration.mlir
@@ -24,7 +24,7 @@ spv.module Logical GLSL450 {
     // CHECK: {{%.*}} = spv._address_of @var0 : !spv.ptr<!spv.struct<i32 [0], !spv.struct<f32 [0], i32 [4]> [4], f32 [12]>, Uniform>
     %0 = spv._address_of @var0 : !spv.ptr<!spv.struct<i32, !spv.struct<f32, i32>, f32>, Uniform>
     // CHECK:  {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.struct<i32 [0], !spv.struct<f32 [0], i32 [4]> [4], f32 [12]>, Uniform>
-    %1 = spv.AccessChain %0[%c0] : !spv.ptr<!spv.struct<i32, !spv.struct<f32, i32>, f32>, Uniform>
+    %1 = spv.AccessChain %0[%c0] : !spv.ptr<!spv.struct<i32, !spv.struct<f32, i32>, f32>, Uniform>, i32
     spv.Return
   }
 }

diff  --git a/mlir/test/Dialect/SPIRV/canonicalize.mlir b/mlir/test/Dialect/SPIRV/canonicalize.mlir
index 20ed6e96be8d..2b719fd7219d 100644
--- a/mlir/test/Dialect/SPIRV/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/canonicalize.mlir
@@ -11,8 +11,8 @@ func @combine_full_access_chain() -> f32 {
   // CHECK-NEXT: spv.Load "Function" %[[PTR]]
   %c0 = spv.constant 0: i32
   %0 = spv.Variable : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
-  %1 = spv.AccessChain %0[%c0] : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
-  %2 = spv.AccessChain %1[%c0, %c0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  %1 = spv.AccessChain %0[%c0] : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>, i32
+  %2 = spv.AccessChain %1[%c0, %c0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>, i32, i32
   %3 = spv.Load "Function" %2 : f32
   spv.ReturnValue %3 : f32
 }
@@ -28,9 +28,9 @@ func @combine_access_chain_multi_use() -> !spv.array<4xf32> {
   // CHECK-NEXT: spv.Load "Function" %[[PTR_1]]
   %c0 = spv.constant 0: i32
   %0 = spv.Variable : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
-  %1 = spv.AccessChain %0[%c0] : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
-  %2 = spv.AccessChain %1[%c0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
-  %3 = spv.AccessChain %2[%c0] : !spv.ptr<!spv.array<4xf32>, Function>
+  %1 = spv.AccessChain %0[%c0] : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>, i32
+  %2 = spv.AccessChain %1[%c0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>, i32
+  %3 = spv.AccessChain %2[%c0] : !spv.ptr<!spv.array<4xf32>, Function>, i32
   %4 = spv.Load "Function" %2 : !spv.array<4xf32>
   %5 = spv.Load "Function" %3 : f32
   spv.ReturnValue %4: !spv.array<4xf32>
@@ -49,8 +49,8 @@ func @dont_combine_access_chain_without_common_base() -> !spv.array<4xi32> {
   %c1 = spv.constant 1: i32
   %0 = spv.Variable : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
   %1 = spv.Variable : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
-  %2 = spv.AccessChain %0[%c1] : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
-  %3 = spv.AccessChain %1[%c1] : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
+  %2 = spv.AccessChain %0[%c1] : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>, i32
+  %3 = spv.AccessChain %1[%c1] : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>, i32
   %4 = spv.Load "Function" %2 : !spv.array<4xi32>
   %5 = spv.Load "Function" %3 : !spv.array<4xi32>
   spv.ReturnValue %4 : !spv.array<4xi32>

diff  --git a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
index a2dafaddfa21..523bd6bfb030 100644
--- a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
@@ -97,8 +97,8 @@ spv.func @cooperative_matrix_fdiv(%a : !spv.coopmatrix<8x16xf32, Subgroup>, %b :
 // CHECK-LABEL: @cooperative_matrix_access_chain
 spv.func @cooperative_matrix_access_chain(%a : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, Function>) -> !spv.ptr<f32, Function> "None" {
   %0 = spv.constant 0: i32
-  // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, Function>
-  %1 = spv.AccessChain %a[%0] : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, Function>
+  // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, Function>, i32
+  %1 = spv.AccessChain %a[%0] : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, Function>, i32
   spv.ReturnValue %1 : !spv.ptr<f32, Function>
 }
 

diff  --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir
index 14e1fa10735e..7dea7942d426 100644
--- a/mlir/test/Dialect/SPIRV/ops.mlir
+++ b/mlir/test/Dialect/SPIRV/ops.mlir
@@ -8,21 +8,21 @@ func @access_chain_struct() -> () {
   %0 = spv.constant 1: i32
   %1 = spv.Variable : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
   // CHECK: spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr<!spv.struct<f32, !spv.array<4 x f32>>, Function>
-  %2 = spv.AccessChain %1[%0, %0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+  %2 = spv.AccessChain %1[%0, %0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>, i32, i32
   return
 }
 
 func @access_chain_1D_array(%arg0 : i32) -> () {
   %0 = spv.Variable : !spv.ptr<!spv.array<4xf32>, Function>
   // CHECK: spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr<!spv.array<4 x f32>, Function>
-  %1 = spv.AccessChain %0[%arg0] : !spv.ptr<!spv.array<4xf32>, Function>
+  %1 = spv.AccessChain %0[%arg0] : !spv.ptr<!spv.array<4xf32>, Function>, i32
   return
 }
 
 func @access_chain_2D_array_1(%arg0 : i32) -> () {
   %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
   // CHECK: spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
-  %1 = spv.AccessChain %0[%arg0, %arg0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  %1 = spv.AccessChain %0[%arg0, %arg0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>, i32, i32
   %2 = spv.Load "Function" %1 ["Volatile"] : f32
   return
 }
@@ -30,7 +30,7 @@ func @access_chain_2D_array_1(%arg0 : i32) -> () {
 func @access_chain_2D_array_2(%arg0 : i32) -> () {
   %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
   // CHECK: spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
-  %1 = spv.AccessChain %0[%arg0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  %1 = spv.AccessChain %0[%arg0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>, i32
   %2 = spv.Load "Function" %1 ["Volatile"] : !spv.array<4xf32>
   return
 }
@@ -38,7 +38,7 @@ func @access_chain_2D_array_2(%arg0 : i32) -> () {
 func @access_chain_rtarray(%arg0 : i32) -> () {
   %0 = spv.Variable : !spv.ptr<!spv.rtarray<f32>, Function>
   // CHECK: spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr<!spv.rtarray<f32>, Function>
-  %1 = spv.AccessChain %0[%arg0] : !spv.ptr<!spv.rtarray<f32>, Function>
+  %1 = spv.AccessChain %0[%arg0] : !spv.ptr<!spv.rtarray<f32>, Function>, i32
   %2 = spv.Load "Function" %1 ["Volatile"] : f32
   return
 }
@@ -49,7 +49,7 @@ func @access_chain_non_composite() -> () {
   %0 = spv.constant 1: i32
   %1 = spv.Variable : !spv.ptr<f32, Function>
   // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 0}}
-  %2 = spv.AccessChain %1[%0] : !spv.ptr<f32, Function>
+  %2 = spv.AccessChain %1[%0] : !spv.ptr<f32, Function>, i32
   return
 }
 
@@ -58,7 +58,34 @@ func @access_chain_non_composite() -> () {
 func @access_chain_no_indices(%index0 : i32) -> () {
   %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
   // expected-error @+1 {{expected at least one index}}
-  %1 = spv.AccessChain %0[] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  %1 = spv.AccessChain %0[] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>, i32
+  return
+}
+
+// -----
+
+func @access_chain_missing_comma(%index0 : i32) -> () {
+  %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  // expected-error @+1 {{expected ','}}
+  %1 = spv.AccessChain %0[%index0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function> i32
+  return
+}
+
+// -----
+
+func @access_chain_invalid_indices_types_count(%index0 : i32) -> () {
+  %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  // expected-error @+1 {{'spv.AccessChain' op indices types' count must be equal to indices info count}}
+  %1 = spv.AccessChain %0[%index0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>, i32, i32
+  return
+}
+
+// -----
+
+func @access_chain_missing_indices_type(%index0 : i32) -> () {
+  %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  // expected-error @+1 {{'spv.AccessChain' op indices types' count must be equal to indices info count}}
+  %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>, i32
   return
 }
 
@@ -68,7 +95,7 @@ func @access_chain_invalid_type(%index0 : i32) -> () {
   %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
   %1 = spv.Load "Function" %0 ["Volatile"] : !spv.array<4x!spv.array<4xf32>>
   // expected-error @+1 {{expected a pointer to composite type, but provided '!spv.array<4 x !spv.array<4 x f32>>'}}
-  %2 = spv.AccessChain %1[%index0] : !spv.array<4x!spv.array<4xf32>>
+  %2 = spv.AccessChain %1[%index0] : !spv.array<4x!spv.array<4xf32>>, i32
   return
 }
 
@@ -77,7 +104,7 @@ func @access_chain_invalid_type(%index0 : i32) -> () {
 func @access_chain_invalid_index_1(%index0 : i32) -> () {
    %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
   // expected-error @+1 {{expected SSA operand}}
-  %1 = spv.AccessChain %0[%index, 4] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  %1 = spv.AccessChain %0[%index, 4] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>, i32, i32
   return
 }
 
@@ -86,7 +113,7 @@ func @access_chain_invalid_index_1(%index0 : i32) -> () {
 func @access_chain_invalid_index_2(%index0 : i32) -> () {
   %0 = spv.Variable : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
   // expected-error @+1 {{index must be an integer spv.constant to access element of spv.struct}}
-  %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+  %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>, i32, i32
   return
 }
 
@@ -96,7 +123,7 @@ func @access_chain_invalid_constant_type_1() -> () {
   %0 = std.constant 1: i32
   %1 = spv.Variable : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
   // expected-error @+1 {{index must be an integer spv.constant to access element of spv.struct, but provided std.constant}}
-  %2 = spv.AccessChain %1[%0, %0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+  %2 = spv.AccessChain %1[%0, %0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>, i32, i32
   return
 }
 
@@ -106,7 +133,7 @@ func @access_chain_out_of_bounds() -> () {
   %index0 = "spv.constant"() { value = 12: i32} : () -> i32
   %0 = spv.Variable : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
   // expected-error @+1 {{'spv.AccessChain' op index 12 out of bounds for '!spv.struct<f32, !spv.array<4 x f32>>'}}
-  %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>
+  %1 = spv.AccessChain %0[%index0, %index0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Function>, i32, i32
   return
 }
 
@@ -115,7 +142,7 @@ func @access_chain_out_of_bounds() -> () {
 func @access_chain_invalid_accessing_type(%index0 : i32) -> () {
   %0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
   // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 0}}
-  %1 = spv.AccessChain %0[%index, %index0, %index0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+  %1 = spv.AccessChain %0[%index, %index0, %index0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>, i32, i32, i32
   return
 
 // -----

diff  --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir
index e24ebcd6e656..93df070f0a2a 100644
--- a/mlir/test/Dialect/SPIRV/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir
@@ -11,7 +11,7 @@ spv.module Logical GLSL450 {
     // CHECK: [[VAR1:%.*]] = spv._address_of @var1 : !spv.ptr<!spv.struct<f32, !spv.array<4 x f32>>, Input>
     // CHECK-NEXT: spv.AccessChain [[VAR1]][{{.*}}, {{.*}}] : !spv.ptr<!spv.struct<f32, !spv.array<4 x f32>>, Input>
     %1 = spv._address_of @var1 : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>
-    %2 = spv.AccessChain %1[%0, %0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>
+    %2 = spv.AccessChain %1[%0, %0] : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>, i32, i32
     spv.Return
   }
 }


        


More information about the Mlir-commits mailing list