[Mlir-commits] [mlir] a5506a3 - [mlir][spirv] Use assemblyFormat to define {InBound}PtrAccessChainOp assembly (#116943)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 25 07:11:11 PST 2024


Author: Yadong Chen
Date: 2024-11-25T10:11:07-05:00
New Revision: a5506a39e0ae8de77136334659b526e5f224850d

URL: https://github.com/llvm/llvm-project/commit/a5506a39e0ae8de77136334659b526e5f224850d
DIFF: https://github.com/llvm/llvm-project/commit/a5506a39e0ae8de77136334659b526e5f224850d.diff

LOG: [mlir][spirv] Use assemblyFormat to define {InBound}PtrAccessChainOp assembly (#116943)

A declarative ODS assemblyFormat is more concise and requires less
boilerplate than hand-written C++ parse/print methods.

Changes:
- Updates PtrAccessChainOp and InBoundsPtrAccessChainOp, defined in
  SPIRVMemoryOps.td, to use assemblyFormat.
- Removes the corresponding hand-written print/parse code from
  MemoryOps.cpp, which is now generated from assemblyFormat.
- Updates the tests to the new format.

Issue: #73359

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
    mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
    mlir/test/Dialect/SPIRV/IR/memory-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
index de7be3f21f3b17..aad50175546a5d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
@@ -46,18 +46,13 @@ def SPIRV_AccessChainOp : SPIRV_Op<"AccessChain", [Pure]> {
     - must be an OpConstant when indexing into a structure.
 
     <!-- End of AutoGen section -->
-    ```
-    access-chain-op ::= ssa-id `=` `spirv.AccessChain` ssa-use
-                        `[` ssa-use (',' ssa-use)* `]`
-                        `:` pointer-type
-    ```
 
     #### Example:
 
     ```mlir
     %0 = "spirv.Constant"() { value = 1: i32} : () -> i32
     %1 = spirv.Variable : !spirv.ptr<!spirv.struct<f32, !spirv.array<4xf32>>, Function>
-    %2 = spirv.AccessChain %1[%0] : !spirv.ptr<!spirv.struct<f32, !spirv.array<4xf32>>, Function>
+    %2 = spirv.AccessChain %1[%0] : !spirv.ptr<!spirv.struct<f32, !spirv.array<4xf32>>, Function> -> !spirv.ptr<!spirv.array<4xf32>, Function>
     %3 = spirv.Load "Function" %2 ["Volatile"] : !spirv.array<4xf32>
     ```
   }];
@@ -149,17 +144,11 @@ def SPIRV_InBoundsPtrAccessChainOp : SPIRV_Op<"InBoundsPtrAccessChain", [Pure]>
 
     <!-- End of AutoGen section -->
 
-    ```
-    access-chain-op ::= ssa-id `=` `spirv.InBoundsPtrAccessChain` ssa-use
-                        `[` ssa-use (',' ssa-use)* `]`
-                        `:` pointer-type
-    ```
-
     #### Example:
 
     ```mlir
     func @inbounds_ptr_access_chain(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
-      %0 = spirv.InBoundsPtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64
+      %0 = spirv.InBoundsPtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64 -> !spirv.ptr<f32, CrossWorkgroup>
       ...
     }
     ```
@@ -183,6 +172,12 @@ def SPIRV_InBoundsPtrAccessChainOp : SPIRV_Op<"InBoundsPtrAccessChain", [Pure]>
   );
 
   let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
+
+  let hasCustomAssemblyFormat = 0;
+
+  let assemblyFormat = [{
+    $base_ptr `[` $element ($indices^)? `]` attr-dict `:` type($base_ptr) `,` type($element) (`,` type($indices)^)? `->` type($result)
+  }];
 }
 
 // -----
@@ -275,17 +270,11 @@ def SPIRV_PtrAccessChainOp : SPIRV_Op<"PtrAccessChain", [Pure]> {
 
     <!-- End of AutoGen section -->
 
-    ```
-    [access-chain-op ::= ssa-id `=` `spirv.PtrAccessChain` ssa-use
-                        `[` ssa-use (',' ssa-use)* `]`
-                        `:` pointer-type
-    ```
-
     #### Example:
 
     ```mlir
     func @ptr_access_chain(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
-      %0 = spirv.PtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64
+      %0 = spirv.PtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64 -> !spirv.ptr<f32, CrossWorkgroup>
       ...
     }
     ```
@@ -311,6 +300,12 @@ def SPIRV_PtrAccessChainOp : SPIRV_Op<"PtrAccessChain", [Pure]> {
   );
 
   let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
+
+  let hasCustomAssemblyFormat = 0;
+
+  let assemblyFormat = [{
+    $base_ptr `[` $element ($indices^)? `]` attr-dict `:` type($base_ptr) `,` type($element) (`,` type($indices)^)? `->` type($result)
+  }];
 }
 
 // -----

diff  --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index 154e955d6057a8..5ae27e5d82bd73 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -543,56 +543,6 @@ LogicalResult CopyMemoryOp::verify() {
   return verifySourceMemoryAccessAttribute(*this);
 }
 
-static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
-                                             OpAsmParser &parser,
-                                             OperationState &state) {
-  OpAsmParser::UnresolvedOperand ptrInfo;
-  SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
-  Type type;
-  auto loc = parser.getCurrentLocation();
-  SmallVector<Type, 4> indicesTypes;
-
-  if (parser.parseOperand(ptrInfo) ||
-      parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
-      parser.parseColonType(type) ||
-      parser.resolveOperand(ptrInfo, type, state.operands))
-    return failure();
-
-  // Check that the provided indices list is not empty before parsing their
-  // type list.
-  if (indicesInfo.empty())
-    return emitError(state.location) << opName << " expected element";
-
-  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
-    return failure();
-
-  // Check that the indices types list is not empty and that it has a one-to-one
-  // mapping to the provided indices.
-  if (indicesTypes.size() != indicesInfo.size())
-    return emitError(state.location)
-           << opName
-           << " indices types' count must be equal to indices info count";
-
-  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
-    return failure();
-
-  auto resultType = getElementPtrType(
-      type, llvm::ArrayRef(state.operands).drop_front(2), state.location);
-  if (!resultType)
-    return failure();
-
-  state.addTypes(resultType);
-  return success();
-}
-
-template <typename Op>
-static auto concatElemAndIndices(Op op) {
-  SmallVector<Value> ret(op.getIndices().size() + 1);
-  ret[0] = op.getElement();
-  llvm::copy(op.getIndices(), ret.begin() + 1);
-  return ret;
-}
-
 //===----------------------------------------------------------------------===//
 // spirv.InBoundsPtrAccessChainOp
 //===----------------------------------------------------------------------===//
@@ -605,16 +555,6 @@ void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, type, basePtr, element, indices);
 }
 
-ParseResult InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
-                                            OperationState &result) {
-  return parsePtrAccessChainOpImpl(
-      spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result);
-}
-
-void InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
-  printAccessChain(*this, concatElemAndIndices(*this), printer);
-}
-
 LogicalResult InBoundsPtrAccessChainOp::verify() {
   return verifyAccessChain(*this, getIndices());
 }
@@ -630,16 +570,6 @@ void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, type, basePtr, element, indices);
 }
 
-ParseResult PtrAccessChainOp::parse(OpAsmParser &parser,
-                                    OperationState &result) {
-  return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
-                                   parser, result);
-}
-
-void PtrAccessChainOp::print(OpAsmPrinter &printer) {
-  printAccessChain(*this, concatElemAndIndices(*this), printer);
-}
-
 LogicalResult PtrAccessChainOp::verify() {
   return verifyAccessChain(*this, getIndices());
 }

diff  --git a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
index 12bfee9fb65119..5aef6135afd97e 100644
--- a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
@@ -699,7 +699,7 @@ func.func @copy_memory_print_maa() {
 // CHECK-SAME:    %[[ARG1:.*]]: i64)
 // CHECK: spirv.PtrAccessChain %[[ARG0]][%[[ARG1]]] : !spirv.ptr<f32, CrossWorkgroup>, i64
 func.func @ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
-  %0 = spirv.PtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64
+  %0 = spirv.PtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64 -> !spirv.ptr<f32, CrossWorkgroup>
   return
 }
 
@@ -714,6 +714,6 @@ func.func @ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64
 // CHECK-SAME:    %[[ARG1:.*]]: i64)
 // CHECK: spirv.InBoundsPtrAccessChain %[[ARG0]][%[[ARG1]]] : !spirv.ptr<f32, CrossWorkgroup>, i64
 func.func @inbounds_ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
-  %0 = spirv.InBoundsPtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64
+  %0 = spirv.InBoundsPtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64 -> !spirv.ptr<f32, CrossWorkgroup>
   return
 }


        


More information about the Mlir-commits mailing list