[Mlir-commits] [mlir] ddc3d51 - [mlir][spirv] Add (InBounds)PtrAccessChain ops
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Aug 18 08:03:35 PDT 2021
Author: Butygin
Date: 2021-08-18T17:59:21+03:00
New Revision: ddc3d51d5880df4253b98b356c96c1a9fea6f971
URL: https://github.com/llvm/llvm-project/commit/ddc3d51d5880df4253b98b356c96c1a9fea6f971
DIFF: https://github.com/llvm/llvm-project/commit/ddc3d51d5880df4253b98b356c96c1a9fea6f971.diff
LOG: [mlir][spirv] Add (InBounds)PtrAccessChain ops
Differential Revision: https://reviews.llvm.org/D108070
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 47fa1b2b75275..bb32d1b837add 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3194,6 +3194,8 @@ def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>;
def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>;
def SPV_OC_OpCopyMemory : I32EnumAttrCase<"OpCopyMemory", 63>;
def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>;
+// Opcode cases for OpPtrAccessChain (67) and OpInBoundsPtrAccessChain (70).
+def SPV_OC_OpPtrAccessChain : I32EnumAttrCase<"OpPtrAccessChain", 67>;
+def SPV_OC_OpInBoundsPtrAccessChain : I32EnumAttrCase<"OpInBoundsPtrAccessChain", 70>;
def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>;
def SPV_OC_OpVectorExtractDynamic : I32EnumAttrCase<"OpVectorExtractDynamic", 77>;
@@ -3340,10 +3342,10 @@ def SPV_OpcodeAttr :
SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpSpecConstantOp,
SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd,
SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore,
- SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
- SPV_OC_OpMemberDecorate, SPV_OC_OpVectorExtractDynamic,
- SPV_OC_OpVectorInsertDynamic, SPV_OC_OpVectorShuffle,
- SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
+ SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpPtrAccessChain,
+ SPV_OC_OpInBoundsPtrAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate,
+ SPV_OC_OpVectorExtractDynamic, SPV_OC_OpVectorInsertDynamic,
+ SPV_OC_OpVectorShuffle, SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpImageDrefGather,
SPV_OC_OpImage, SPV_OC_OpImageQuerySize, SPV_OC_OpConvertFToU,
SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
index 77fa63fc492ff..63cd3fe0213da 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
@@ -137,6 +137,55 @@ def SPV_CopyMemoryOp : SPV_Op<"CopyMemory", []> {
// -----
+def SPV_InBoundsPtrAccessChainOp : SPV_Op<"InBoundsPtrAccessChain", [NoSideEffect]> {
+ let summary = [{
+ Has the same semantics as OpPtrAccessChain, with the addition that the
+ resulting pointer is known to point within the base object.
+ }];
+
+ let description = [{
+
+
+ <!-- End of AutoGen section -->
+
+ The first bracketed operand is the Element operand, which offsets the base
+ pointer as in spv.PtrAccessChain; any remaining operands are indices that
+ walk into the pointee type as in spv.AccessChain. The result pointer type is
+ inferred from the base pointer type and those remaining indices only.
+
+ ```
+ access-chain-op ::= ssa-id `=` `spv.InBoundsPtrAccessChain` ssa-use
+ `[` ssa-use (',' ssa-use)* `]`
+ `:` pointer-type (',' integer-type)+
+ ```
+
+ #### Example:
+
+ ```mlir
+ func @inbounds_ptr_access_chain(%arg0: !spv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
+ %0 = spv.InBoundsPtrAccessChain %arg0[%arg1] : !spv.ptr<f32, CrossWorkgroup>, i64
+ ...
+ }
+ ```
+ }];
+
+ let availability = [
+ MinVersion<SPV_V_1_0>,
+ MaxVersion<SPV_V_1_5>,
+ Extension<[]>,
+ Capability<[SPV_C_Addresses]>
+ ];
+
+ let arguments = (ins
+ SPV_AnyPtr:$base_ptr,
+ SPV_Integer:$element,
+ Variadic<SPV_Integer>:$indices
+ );
+
+ let results = (outs
+ SPV_AnyPtr:$result
+ );
+
+ let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
+}
+
+// -----
+
def SPV_LoadOp : SPV_Op<"Load", []> {
let summary = "Load through a pointer.";
@@ -191,6 +240,78 @@ def SPV_LoadOp : SPV_Op<"Load", []> {
// -----
+def SPV_PtrAccessChainOp : SPV_Op<"PtrAccessChain", [NoSideEffect]> {
+ let summary = [{
+ Has the same semantics as OpAccessChain, with the addition of the
+ Element operand.
+ }];
+
+ let description = [{
+ Element is used to do an initial dereference of Base: Base is treated as
+ the address of an element in an array, and a new element address is
+ computed from Base and Element to become the OpAccessChain Base to
+ dereference as per OpAccessChain. This computed Base has the same type
+ as the originating Base.
+
+ To compute the new element address, Element is treated as a signed count
+ of elements E, relative to the original Base element B, and the address
+ of element B + E is computed using enough precision to avoid overflow
+ and underflow. For objects in the Uniform, StorageBuffer, or
+ PushConstant storage classes, the element's address or location is
+ calculated using a stride, which will be the Base-type's Array Stride if
+ the Base type is decorated with ArrayStride. For all other objects, the
+ implementation calculates the element's address or location.
+
+ With one exception, undefined behavior results when B + E is not an
+ element in the same array (same innermost array, if array types are
+ nested) as B. The exception being when B + E = L, where L is the length
+ of the array: the address computation for element L is done with the
+ same stride as any other B + E computation that stays within the array.
+
+ Note: If Base is typed to be a pointer to an array and the desired
+ operation is to select an element of that array, OpAccessChain should be
+ directly used, as its first Index selects the array element.
+
+ <!-- End of AutoGen section -->
+
+ In the custom assembly form the first bracketed operand is Element and any
+ remaining operands are the OpAccessChain indices. The result pointer type
+ is inferred from the Base type and those remaining indices; Element does
+ not change it.
+
+ ```
+ access-chain-op ::= ssa-id `=` `spv.PtrAccessChain` ssa-use
+ `[` ssa-use (',' ssa-use)* `]`
+ `:` pointer-type (',' integer-type)+
+ ```
+
+ #### Example:
+
+ ```mlir
+ func @ptr_access_chain(%arg0: !spv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
+ %0 = spv.PtrAccessChain %arg0[%arg1] : !spv.ptr<f32, CrossWorkgroup>, i64
+ ...
+ }
+ ```
+ }];
+
+ let availability = [
+ MinVersion<SPV_V_1_0>,
+ MaxVersion<SPV_V_1_5>,
+ Extension<[]>,
+ Capability<[SPV_C_Addresses, SPV_C_PhysicalStorageBufferAddresses, SPV_C_VariablePointers, SPV_C_VariablePointersStorageBuffer]>
+ ];
+
+ let arguments = (ins
+ SPV_AnyPtr:$base_ptr,
+ SPV_Integer:$element,
+ Variadic<SPV_Integer>:$indices
+ );
+
+ let results = (outs
+ SPV_AnyPtr:$result
+ );
+
+ let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
+}
+
+// -----
+
def SPV_StoreOp : SPV_Op<"Store", []> {
let summary = "Store through a pointer.";
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index c03388d266b5a..fc18fdd78b3cc 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1019,37 +1019,41 @@ static ParseResult parseAccessChainOp(OpAsmParser &parser,
return success();
}
+/// Prints an access-chain-like op as
+/// `<op-name> %base[%idx0, ...] : <base-ptr-type>, <idx-types>`.
+template <typename Op>
+static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
+ Value basePtr = op.base_ptr();
+ printer << Op::getOperationName() << ' ' << basePtr << '[' << indices
+ << "] : " << basePtr.getType() << ", " << indices.getTypes();
+}
+
+/// Prints spv.AccessChain via the shared printAccessChain helper.
static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
- printer << spirv::AccessChainOp::getOperationName() << ' ' << op.base_ptr()
- << '[' << op.indices() << "] : " << op.base_ptr().getType() << ", "
- << op.indices().getTypes();
+ printAccessChain(op, op.indices(), printer);
}
-static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
- SmallVector<Value, 4> indices(accessChainOp.indices().begin(),
- accessChainOp.indices().end());
+/// Verifies an access-chain-like op: the result type must be a pointer and
+/// must equal the pointer type obtained by applying `indices` to the base
+/// pointer's type.
+template <typename Op>
+static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(),
indices, accessChainOp.getLoc());
- if (!resultType) {
+ if (!resultType)
return failure();
- }
auto providedResultType =
- accessChainOp.getType().dyn_cast<spirv::PointerType>();
- if (!providedResultType) {
+ accessChainOp.getType().template dyn_cast<spirv::PointerType>();
+ if (!providedResultType)
+ // providedResultType is null on this path, so report the actual type.
return accessChainOp.emitOpError(
-"result type must be a pointer, but provided")
-<< providedResultType;
+ "result type must be a pointer, but provided ")
+ << accessChainOp.getType();
- }
- if (resultType != providedResultType) {
+ if (resultType != providedResultType)
return accessChainOp.emitOpError("invalid result type: expected ")
<< resultType << ", but provided " << providedResultType;
- }
return success();
}
+/// Verifies spv.AccessChain; every index participates in computing the
+/// expected result pointer type.
+static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
+ ValueRange chainIndices = accessChainOp.indices();
+ return verifyAccessChain(accessChainOp, chainIndices);
+}
+
//===----------------------------------------------------------------------===//
// spv.mlir.addressof
//===----------------------------------------------------------------------===//
@@ -3770,6 +3774,109 @@ static LogicalResult verify(spirv::ImageQuerySizeOp imageQuerySizeOp) {
return success();
}
+/// Shared parser for spv.PtrAccessChain and spv.InBoundsPtrAccessChain:
+///
+/// ssa-use `[` ssa-use (`,` ssa-use)* `]` `:` pointer-type (`,` type)+
+///
+/// The first bracketed operand is the Element; the result type is derived
+/// from the base pointer type and the remaining indices only.
+static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
+ OpAsmParser &parser,
+ OperationState &state) {
+ OpAsmParser::OperandType basePtrOperand;
+ SmallVector<OpAsmParser::OperandType, 4> chainOperands;
+ SmallVector<Type, 4> chainOperandTypes;
+ Type basePtrType;
+ auto operandsLoc = parser.getCurrentLocation();
+
+ // Base pointer, bracketed operand list, then the base pointer's type.
+ if (parser.parseOperand(basePtrOperand))
+ return failure();
+ if (parser.parseOperandList(chainOperands, OpAsmParser::Delimiter::Square))
+ return failure();
+ if (parser.parseColonType(basePtrType))
+ return failure();
+ if (parser.resolveOperand(basePtrOperand, basePtrType, state.operands))
+ return failure();
+
+ // The Element operand is mandatory; reject before reading the type list.
+ if (chainOperands.empty())
+ return emitError(state.location) << opName << " expected element";
+
+ if (parser.parseComma())
+ return failure();
+ if (parser.parseTypeList(chainOperandTypes))
+ return failure();
+
+ // Every bracketed operand needs exactly one type.
+ if (chainOperandTypes.size() != chainOperands.size())
+ return emitError(state.location)
+ << opName
+ << " indices types' count must be equal to indices info count";
+
+ if (parser.resolveOperands(chainOperands, chainOperandTypes, operandsLoc,
+ state.operands))
+ return failure();
+
+ // state.operands is [base, element, indices...]; the element does not
+ // affect the pointee type, so skip the first two operands.
+ auto typeIndices = llvm::makeArrayRef(state.operands).drop_front(2);
+ auto resultType =
+ getElementPtrType(basePtrType, typeIndices, state.location);
+ if (!resultType)
+ return failure();
+
+ state.addTypes(resultType);
+ return success();
+}
+
+/// Returns the Element operand followed by the indices, i.e. the operands
+/// that appear inside the brackets of the custom assembly form.
+template <typename Op>
+static auto concatElemAndIndices(Op op) {
+ SmallVector<Value> bracketOperands;
+ bracketOperands.reserve(op.indices().size() + 1);
+ bracketOperands.push_back(op.element());
+ bracketOperands.append(op.indices().begin(), op.indices().end());
+ return bracketOperands;
+}
+
+//===----------------------------------------------------------------------===//
+// spv.InBoundsPtrAccessChainOp
+//===----------------------------------------------------------------------===//
+
+void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder,
+ OperationState &state,
+ Value basePtr, Value element,
+ ValueRange indices) {
+ auto type = getElementPtrType(basePtr.getType(), indices, state.location);
+ assert(type && "Unable to deduce return type based on basePtr and indices");
+ build(builder, state, type, basePtr, element, indices);
+}
+
+static ParseResult parseInBoundsPtrAccessChainOp(OpAsmParser &parser,
+ OperationState &state) {
+ return parsePtrAccessChainOpImpl(
+ spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, state);
+}
+
+static void print(spirv::InBoundsPtrAccessChainOp op, OpAsmPrinter &printer) {
+ printAccessChain(op, concatElemAndIndices(op), printer);
+}
+
+static LogicalResult verify(spirv::InBoundsPtrAccessChainOp accessChainOp) {
+ return verifyAccessChain(accessChainOp, accessChainOp.indices());
+}
+
+//===----------------------------------------------------------------------===//
+// spv.PtrAccessChainOp
+//===----------------------------------------------------------------------===//
+
+void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
+ Value basePtr, Value element,
+ ValueRange indices) {
+ auto type = getElementPtrType(basePtr.getType(), indices, state.location);
+ assert(type && "Unable to deduce return type based on basePtr and indices");
+ build(builder, state, type, basePtr, element, indices);
+}
+
+static ParseResult parsePtrAccessChainOp(OpAsmParser &parser,
+ OperationState &state) {
+ return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
+ parser, state);
+}
+
+static void print(spirv::PtrAccessChainOp op, OpAsmPrinter &printer) {
+ printAccessChain(op, concatElemAndIndices(op), printer);
+}
+
+static LogicalResult verify(spirv::PtrAccessChainOp accessChainOp) {
+ return verifyAccessChain(accessChainOp, accessChainOp.indices());
+}
+
namespace mlir {
namespace spirv {
diff --git a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
index 62a43b0f8f28e..1ebc62286bd25 100644
--- a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
@@ -628,3 +628,33 @@ func @copy_memory_print_maa() {
spv.Return
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.PtrAccessChain
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @ptr_access_chain1(
+// CHECK-SAME: %[[ARG0:.*]]: !spv.ptr<f32, CrossWorkgroup>,
+// CHECK-SAME: %[[ARG1:.*]]: i64)
+// CHECK: spv.PtrAccessChain %[[ARG0]][%[[ARG1]]] : !spv.ptr<f32, CrossWorkgroup>, i64
+func @ptr_access_chain1(%arg0: !spv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
+ %0 = spv.PtrAccessChain %arg0[%arg1] : !spv.ptr<f32, CrossWorkgroup>, i64
+ return
+}
+
+// Element followed by a trailing index that walks into the array pointee,
+// exercising the variadic `indices` operand.
+// CHECK-LABEL: func @ptr_access_chain2(
+// CHECK-SAME: %[[ARG0:.*]]: !spv.ptr<!spv.array<4 x f32>, CrossWorkgroup>,
+// CHECK-SAME: %[[ARG1:.*]]: i64, %[[ARG2:.*]]: i32)
+// CHECK: spv.PtrAccessChain %[[ARG0]][%[[ARG1]], %[[ARG2]]] : !spv.ptr<!spv.array<4 x f32>, CrossWorkgroup>, i64, i32
+func @ptr_access_chain2(%arg0: !spv.ptr<!spv.array<4 x f32>, CrossWorkgroup>, %arg1 : i64, %arg2 : i32) -> () {
+ %0 = spv.PtrAccessChain %arg0[%arg1, %arg2] : !spv.ptr<!spv.array<4 x f32>, CrossWorkgroup>, i64, i32
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.InBoundsPtrAccessChain
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @inbounds_ptr_access_chain1(
+// CHECK-SAME: %[[ARG0:.*]]: !spv.ptr<f32, CrossWorkgroup>,
+// CHECK-SAME: %[[ARG1:.*]]: i64)
+// CHECK: spv.InBoundsPtrAccessChain %[[ARG0]][%[[ARG1]]] : !spv.ptr<f32, CrossWorkgroup>, i64
+func @inbounds_ptr_access_chain1(%arg0: !spv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
+ %0 = spv.InBoundsPtrAccessChain %arg0[%arg1] : !spv.ptr<f32, CrossWorkgroup>, i64
+ return
+}
+
+// Element followed by a trailing index that walks into the array pointee,
+// exercising the variadic `indices` operand.
+// CHECK-LABEL: func @inbounds_ptr_access_chain2(
+// CHECK-SAME: %[[ARG0:.*]]: !spv.ptr<!spv.array<4 x f32>, CrossWorkgroup>,
+// CHECK-SAME: %[[ARG1:.*]]: i64, %[[ARG2:.*]]: i32)
+// CHECK: spv.InBoundsPtrAccessChain %[[ARG0]][%[[ARG1]], %[[ARG2]]] : !spv.ptr<!spv.array<4 x f32>, CrossWorkgroup>, i64, i32
+func @inbounds_ptr_access_chain2(%arg0: !spv.ptr<!spv.array<4 x f32>, CrossWorkgroup>, %arg1 : i64, %arg2 : i32) -> () {
+ %0 = spv.InBoundsPtrAccessChain %arg0[%arg1, %arg2] : !spv.ptr<!spv.array<4 x f32>, CrossWorkgroup>, i64, i32
+ return
+}
More information about the Mlir-commits mailing list