[Mlir-commits] [mlir] ddc3d51 - [mlir][spirv] Add (InBounds)PtrAccessChain ops

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 18 08:03:35 PDT 2021

Author: Butygin
Date: 2021-08-18T17:59:21+03:00
New Revision: ddc3d51d5880df4253b98b356c96c1a9fea6f971

URL: https://github.com/llvm/llvm-project/commit/ddc3d51d5880df4253b98b356c96c1a9fea6f971
DIFF: https://github.com/llvm/llvm-project/commit/ddc3d51d5880df4253b98b356c96c1a9fea6f971.diff

LOG: [mlir][spirv] Add (InBounds)PtrAccessChain ops

Differential Revision: https://reviews.llvm.org/D108070




diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 47fa1b2b75275..bb32d1b837add 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3194,6 +3194,8 @@ def SPV_OC_OpLoad                      : I32EnumAttrCase<"OpLoad", 61>;
 def SPV_OC_OpStore                     : I32EnumAttrCase<"OpStore", 62>;
 def SPV_OC_OpCopyMemory                : I32EnumAttrCase<"OpCopyMemory", 63>;
 def SPV_OC_OpAccessChain               : I32EnumAttrCase<"OpAccessChain", 65>;
+def SPV_OC_OpPtrAccessChain            : I32EnumAttrCase<"OpPtrAccessChain", 67>;
+def SPV_OC_OpInBoundsPtrAccessChain    : I32EnumAttrCase<"OpInBoundsPtrAccessChain", 70>;
 def SPV_OC_OpDecorate                  : I32EnumAttrCase<"OpDecorate", 71>;
 def SPV_OC_OpMemberDecorate            : I32EnumAttrCase<"OpMemberDecorate", 72>;
 def SPV_OC_OpVectorExtractDynamic      : I32EnumAttrCase<"OpVectorExtractDynamic", 77>;
@@ -3340,10 +3342,10 @@ def SPV_OpcodeAttr :
       SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpSpecConstantOp,
       SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd,
       SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore,
-      SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
-      SPV_OC_OpMemberDecorate, SPV_OC_OpVectorExtractDynamic,
-      SPV_OC_OpVectorInsertDynamic, SPV_OC_OpVectorShuffle,
-      SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
+      SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpPtrAccessChain,
+      SPV_OC_OpInBoundsPtrAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate,
+      SPV_OC_OpVectorExtractDynamic, SPV_OC_OpVectorInsertDynamic,
+      SPV_OC_OpVectorShuffle, SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
       SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpImageDrefGather,
       SPV_OC_OpImage, SPV_OC_OpImageQuerySize, SPV_OC_OpConvertFToU,
       SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
index 77fa63fc492ff..63cd3fe0213da 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
@@ -137,6 +137,55 @@ def SPV_CopyMemoryOp : SPV_Op<"CopyMemory", []> {
 // -----
+def SPV_InBoundsPtrAccessChainOp : SPV_Op<"InBoundsPtrAccessChain", [NoSideEffect]> {
+  let summary = [{
+    Has the same semantics as OpPtrAccessChain, with the addition that the
+    resulting pointer is known to point within the base object.
+  }];
+  let description = [{
+    <!-- End of AutoGen section -->
+    ```
+    access-chain-op ::= ssa-id `=` `spv.InBoundsPtrAccessChain` ssa-use
+                        `[` ssa-use (',' ssa-use)* `]`
+                        `:` pointer-type (',' integer-type)+
+    ```
+    #### Example:
+    ```mlir
+    func @inbounds_ptr_access_chain(%arg0: !spv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
+      %0 = spv.InBoundsPtrAccessChain %arg0[%arg1] : !spv.ptr<f32, CrossWorkgroup>, i64
+      ...
+    }
+    ```
+  }];
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[]>,
+    Capability<[SPV_C_Addresses]>
+  ];
+  let arguments = (ins
+    SPV_AnyPtr:$base_ptr,
+    SPV_Integer:$element,
+    Variadic<SPV_Integer>:$indices
+  );
+  let results = (outs
+    SPV_AnyPtr:$result
+  );
+  let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
+// -----
 def SPV_LoadOp : SPV_Op<"Load", []> {
   let summary = "Load through a pointer.";
@@ -191,6 +240,78 @@ def SPV_LoadOp : SPV_Op<"Load", []> {
 // -----
+def SPV_PtrAccessChainOp : SPV_Op<"PtrAccessChain", [NoSideEffect]> {
+  let summary = [{
+    Has the same semantics as OpAccessChain, with the addition of the
+    Element operand.
+  }];
+  let description = [{
+    Element is used to do an initial dereference of Base: Base is treated as
+    the address of an element in an array, and a new element address is
+    computed from Base and Element to become the OpAccessChain Base to
+    dereference as per OpAccessChain. This computed Base has the same type
+    as the originating Base.
+    To compute the new element address, Element is treated as a signed count
+    of elements E, relative to the original Base element B, and the address
+    of element B + E is computed using enough precision to avoid overflow
+    and underflow. For objects in the Uniform, StorageBuffer, or
+    PushConstant storage classes, the element's address or location is
+    calculated using a stride, which will be the Base-type's Array Stride if
+    the Base type is decorated with ArrayStride. For all other objects, the
+    implementation calculates the element's address or location.
+    With one exception, undefined behavior results when B + E is not an
+    element in the same array (same innermost array, if array types are
+    nested) as B. The exception being when B + E = L, where L is the length
+    of the array: the address computation for element L is done with the
+    same stride as any other B + E computation that stays within the array.
+    Note: If Base is typed to be a pointer to an array and the desired
+    operation is to select an element of that array, OpAccessChain should be
+    directly used, as its first Index selects the array element.
+    <!-- End of AutoGen section -->
+    ```
+    access-chain-op ::= ssa-id `=` `spv.PtrAccessChain` ssa-use
+                        `[` ssa-use (',' ssa-use)* `]`
+                        `:` pointer-type (',' integer-type)+
+    ```
+    #### Example:
+    ```mlir
+    func @ptr_access_chain(%arg0: !spv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
+      %0 = spv.PtrAccessChain %arg0[%arg1] : !spv.ptr<f32, CrossWorkgroup>, i64
+      ...
+    }
+    ```
+  }];
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[]>,
+    Capability<[SPV_C_Addresses, SPV_C_PhysicalStorageBufferAddresses, SPV_C_VariablePointers, SPV_C_VariablePointersStorageBuffer]>
+  ];
+  let arguments = (ins
+    SPV_AnyPtr:$base_ptr,
+    SPV_Integer:$element,
+    Variadic<SPV_Integer>:$indices
+  );
+  let results = (outs
+    SPV_AnyPtr:$result
+  );
+  let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
+// -----
 def SPV_StoreOp : SPV_Op<"Store", []> {
   let summary = "Store through a pointer.";

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index c03388d266b5a..fc18fdd78b3cc 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1019,37 +1019,41 @@ static ParseResult parseAccessChainOp(OpAsmParser &parser,
   return success();
+template <typename Op>
+static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
+  printer << Op::getOperationName() << ' ' << op.base_ptr() << '[' << indices
+          << "] : " << op.base_ptr().getType() << ", " << indices.getTypes();
 static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
-  printer << spirv::AccessChainOp::getOperationName() << ' ' << op.base_ptr()
-          << '[' << op.indices() << "] : " << op.base_ptr().getType() << ", "
-          << op.indices().getTypes();
+  printAccessChain(op, op.indices(), printer);
-static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
-  SmallVector<Value, 4> indices(accessChainOp.indices().begin(),
-                                accessChainOp.indices().end());
+template <typename Op>
+static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
   auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(),
                                       indices, accessChainOp.getLoc());
-  if (!resultType) {
+  if (!resultType)
     return failure();
-  }
   auto providedResultType =
-      accessChainOp.getType().dyn_cast<spirv::PointerType>();
-  if (!providedResultType) {
+      accessChainOp.getType().template dyn_cast<spirv::PointerType>();
+  if (!providedResultType)
     return accessChainOp.emitOpError(
                "result type must be a pointer, but provided")
            << providedResultType;
-  }
-  if (resultType != providedResultType) {
+  if (resultType != providedResultType)
     return accessChainOp.emitOpError("invalid result type: expected ")
            << resultType << ", but provided " << providedResultType;
-  }
   return success();
+static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
+  return verifyAccessChain(accessChainOp, accessChainOp.indices());
 // spv.mlir.addressof
@@ -3770,6 +3774,109 @@ static LogicalResult verify(spirv::ImageQuerySizeOp imageQuerySizeOp) {
   return success();
+static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
+                                             OpAsmParser &parser,
+                                             OperationState &state) {
+  OpAsmParser::OperandType ptrInfo;
+  SmallVector<OpAsmParser::OperandType, 4> indicesInfo;
+  Type type;
+  auto loc = parser.getCurrentLocation();
+  SmallVector<Type, 4> indicesTypes;
+  if (parser.parseOperand(ptrInfo) ||
+      parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
+      parser.parseColonType(type) ||
+      parser.resolveOperand(ptrInfo, type, state.operands))
+    return failure();
+  // Check that the provided indices list is not empty before parsing their
+  // type list.
+  if (indicesInfo.empty())
+    return emitError(state.location) << opName << " expected element";
+  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
+    return failure();
+  // Check that the indices types list is not empty and that it has a one-to-one
+  // mapping to the provided indices.
+  if (indicesTypes.size() != indicesInfo.size())
+    return emitError(state.location)
+           << opName
+           << " indices types' count must be equal to indices info count";
+  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
+    return failure();
+  auto resultType = getElementPtrType(
+      type, llvm::makeArrayRef(state.operands).drop_front(2), state.location);
+  if (!resultType)
+    return failure();
+  state.addTypes(resultType);
+  return success();
+template <typename Op>
+static auto concatElemAndIndices(Op op) {
+  SmallVector<Value> ret(op.indices().size() + 1);
+  ret[0] = op.element();
+  llvm::copy(op.indices(), ret.begin() + 1);
+  return ret;
+// spv.InBoundsPtrAccessChainOp
+void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder,
+                                            OperationState &state,
+                                            Value basePtr, Value element,
+                                            ValueRange indices) {
+  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
+  assert(type && "Unable to deduce return type based on basePtr and indices");
+  build(builder, state, type, basePtr, element, indices);
+static ParseResult parseInBoundsPtrAccessChainOp(OpAsmParser &parser,
+                                                 OperationState &state) {
+  return parsePtrAccessChainOpImpl(
+      spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, state);
+static void print(spirv::InBoundsPtrAccessChainOp op, OpAsmPrinter &printer) {
+  printAccessChain(op, concatElemAndIndices(op), printer);
+static LogicalResult verify(spirv::InBoundsPtrAccessChainOp accessChainOp) {
+  return verifyAccessChain(accessChainOp, accessChainOp.indices());
+// spv.PtrAccessChainOp
+void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
+                                    Value basePtr, Value element,
+                                    ValueRange indices) {
+  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
+  assert(type && "Unable to deduce return type based on basePtr and indices");
+  build(builder, state, type, basePtr, element, indices);
+static ParseResult parsePtrAccessChainOp(OpAsmParser &parser,
+                                         OperationState &state) {
+  return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
+                                   parser, state);
+static void print(spirv::PtrAccessChainOp op, OpAsmPrinter &printer) {
+  printAccessChain(op, concatElemAndIndices(op), printer);
+static LogicalResult verify(spirv::PtrAccessChainOp accessChainOp) {
+  return verifyAccessChain(accessChainOp, accessChainOp.indices());
 namespace mlir {
 namespace spirv {

diff  --git a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
index 62a43b0f8f28e..1ebc62286bd25 100644
--- a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
@@ -628,3 +628,33 @@ func @copy_memory_print_maa() {
+// -----
+// spv.PtrAccessChain
+// CHECK-LABEL:   func @ptr_access_chain1(
+// CHECK-SAME:    %[[ARG0:.*]]: !spv.ptr<f32, CrossWorkgroup>,
+// CHECK-SAME:    %[[ARG1:.*]]: i64)
+// CHECK: spv.PtrAccessChain %[[ARG0]][%[[ARG1]]] : !spv.ptr<f32, CrossWorkgroup>, i64
+func @ptr_access_chain1(%arg0: !spv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
+  %0 = spv.PtrAccessChain %arg0[%arg1] : !spv.ptr<f32, CrossWorkgroup>, i64
+  return
+// -----
+// spv.InBoundsPtrAccessChain
+// CHECK-LABEL:   func @inbounds_ptr_access_chain1(
+// CHECK-SAME:    %[[ARG0:.*]]: !spv.ptr<f32, CrossWorkgroup>,
+// CHECK-SAME:    %[[ARG1:.*]]: i64)
+// CHECK: spv.InBoundsPtrAccessChain %[[ARG0]][%[[ARG1]]] : !spv.ptr<f32, CrossWorkgroup>, i64
+func @inbounds_ptr_access_chain1(%arg0: !spv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
+  %0 = spv.InBoundsPtrAccessChain %arg0[%arg1] : !spv.ptr<f32, CrossWorkgroup>, i64
+  return


More information about the Mlir-commits mailing list