[Mlir-commits] [mlir] [mlir][spirv] Use assemblyFormat to define {InBound}PtrAccessChainOp assembly (PR #116943)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 20 01:38:03 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-spirv
Author: Yadong Chen (hahacyd)
<details>
<summary>Changes</summary>
See #73359.
The declarative ODS assemblyFormat is more concise and requires less boilerplate than hand-written C++ parse/print implementations.
Changes:
Updates PtrAccessChainOp and InBoundsPtrAccessChainOp, defined in SPIRVMemoryOps.td, to use assemblyFormat, and removes the corresponding hand-written print/parse code from MemoryOps.cpp, since it is now generated from assemblyFormat.
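Updates the tests to the new format, which now spells out the result type explicitly (e.g. `-> !spirv.ptr<f32, CrossWorkgroup>`) instead of inferring it.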
---
Full diff: https://github.com/llvm/llvm-project/pull/116943.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td (+12)
- (modified) mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp (-70)
- (modified) mlir/test/Dialect/SPIRV/IR/memory-ops.mlir (+2-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
index de7be3f21f3b17..878bfaa21e606b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
@@ -183,6 +183,12 @@ def SPIRV_InBoundsPtrAccessChainOp : SPIRV_Op<"InBoundsPtrAccessChain", [Pure]>
);
let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
+
+ let hasCustomAssemblyFormat = 0;
+
+ let assemblyFormat = [{
+ $base_ptr `[` $element ($indices^)? `]` attr-dict `:` type($base_ptr) `,` type($element) (`,` type($indices)^)? `->` type($result)
+ }];
}
// -----
@@ -311,6 +317,12 @@ def SPIRV_PtrAccessChainOp : SPIRV_Op<"PtrAccessChain", [Pure]> {
);
let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
+
+ let hasCustomAssemblyFormat = 0;
+
+ let assemblyFormat = [{
+ $base_ptr `[` $element ($indices^)? `]` attr-dict `:` type($base_ptr) `,` type($element) (`,` type($indices)^)? `->` type($result)
+ }];
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index 154e955d6057a8..5ae27e5d82bd73 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -543,56 +543,6 @@ LogicalResult CopyMemoryOp::verify() {
return verifySourceMemoryAccessAttribute(*this);
}
-static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
- OpAsmParser &parser,
- OperationState &state) {
- OpAsmParser::UnresolvedOperand ptrInfo;
- SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
- Type type;
- auto loc = parser.getCurrentLocation();
- SmallVector<Type, 4> indicesTypes;
-
- if (parser.parseOperand(ptrInfo) ||
- parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
- parser.parseColonType(type) ||
- parser.resolveOperand(ptrInfo, type, state.operands))
- return failure();
-
- // Check that the provided indices list is not empty before parsing their
- // type list.
- if (indicesInfo.empty())
- return emitError(state.location) << opName << " expected element";
-
- if (parser.parseComma() || parser.parseTypeList(indicesTypes))
- return failure();
-
- // Check that the indices types list is not empty and that it has a one-to-one
- // mapping to the provided indices.
- if (indicesTypes.size() != indicesInfo.size())
- return emitError(state.location)
- << opName
- << " indices types' count must be equal to indices info count";
-
- if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
- return failure();
-
- auto resultType = getElementPtrType(
- type, llvm::ArrayRef(state.operands).drop_front(2), state.location);
- if (!resultType)
- return failure();
-
- state.addTypes(resultType);
- return success();
-}
-
-template <typename Op>
-static auto concatElemAndIndices(Op op) {
- SmallVector<Value> ret(op.getIndices().size() + 1);
- ret[0] = op.getElement();
- llvm::copy(op.getIndices(), ret.begin() + 1);
- return ret;
-}
-
//===----------------------------------------------------------------------===//
// spirv.InBoundsPtrAccessChainOp
//===----------------------------------------------------------------------===//
@@ -605,16 +555,6 @@ void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, type, basePtr, element, indices);
}
-ParseResult InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parsePtrAccessChainOpImpl(
- spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result);
-}
-
-void InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
- printAccessChain(*this, concatElemAndIndices(*this), printer);
-}
-
LogicalResult InBoundsPtrAccessChainOp::verify() {
return verifyAccessChain(*this, getIndices());
}
@@ -630,16 +570,6 @@ void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, type, basePtr, element, indices);
}
-ParseResult PtrAccessChainOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
- parser, result);
-}
-
-void PtrAccessChainOp::print(OpAsmPrinter &printer) {
- printAccessChain(*this, concatElemAndIndices(*this), printer);
-}
-
LogicalResult PtrAccessChainOp::verify() {
return verifyAccessChain(*this, getIndices());
}
diff --git a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
index 12bfee9fb65119..5aef6135afd97e 100644
--- a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
@@ -699,7 +699,7 @@ func.func @copy_memory_print_maa() {
// CHECK-SAME: %[[ARG1:.*]]: i64)
// CHECK: spirv.PtrAccessChain %[[ARG0]][%[[ARG1]]] : !spirv.ptr<f32, CrossWorkgroup>, i64
func.func @ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
- %0 = spirv.PtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64
+ %0 = spirv.PtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64 -> !spirv.ptr<f32, CrossWorkgroup>
return
}
@@ -714,6 +714,6 @@ func.func @ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64
// CHECK-SAME: %[[ARG1:.*]]: i64)
// CHECK: spirv.InBoundsPtrAccessChain %[[ARG0]][%[[ARG1]]] : !spirv.ptr<f32, CrossWorkgroup>, i64
func.func @inbounds_ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
- %0 = spirv.InBoundsPtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64
+ %0 = spirv.InBoundsPtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64 -> !spirv.ptr<f32, CrossWorkgroup>
return
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/116943
More information about the Mlir-commits
mailing list