[Mlir-commits] [mlir] [mlir][spirv] Use assemblyFormat to define {InBound}PtrAccessChainOp assembly (PR #116943)

Yadong Chen llvmlistbot at llvm.org
Wed Nov 20 01:37:25 PST 2024


https://github.com/hahacyd created https://github.com/llvm/llvm-project/pull/116943

see #73359

A declarative ODS assemblyFormat is more concise and requires less boilerplate than hand-writing the C++ parse/print methods.

Changes:
Updates PtrAccessChainOp and InBoundsPtrAccessChainOp, defined in SPIRVMemoryOps.td, to use assemblyFormat. Removes the corresponding hand-written parse/print code from MemoryOps.cpp, which is now generated from assemblyFormat.
Updates the tests to the new format, in which the result type is written explicitly after `->`.

>From 2496ea48cd1063648d1f8ee7dd7bc40ba1911cc1 Mon Sep 17 00:00:00 2001
From: Chen Yadong <cyd.matt at qq.com>
Date: Wed, 20 Nov 2024 17:21:37 +0800
Subject: [PATCH] [mlir][spirv] Use assemblyFormat to define PtrAccessChainOp
 and InBoundPtrAccessChainOp assembly

see #73359

A declarative ODS assemblyFormat is more concise and requires less boilerplate than hand-writing the C++ parse/print methods.

Changes:
Updates PtrAccessChainOp and InBoundsPtrAccessChainOp, defined in SPIRVMemoryOps.td, to use assemblyFormat.
Removes the corresponding hand-written parse/print code from MemoryOps.cpp, which is now generated from assemblyFormat.
Updates the tests to the new format, in which the result type is written explicitly after `->`.
---
 .../mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td   | 12 ++++
 mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp       | 70 -------------------
 mlir/test/Dialect/SPIRV/IR/memory-ops.mlir    |  4 +-
 3 files changed, 14 insertions(+), 72 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
index de7be3f21f3b17..878bfaa21e606b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
@@ -183,6 +183,12 @@ def SPIRV_InBoundsPtrAccessChainOp : SPIRV_Op<"InBoundsPtrAccessChain", [Pure]>
   );
 
   let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
+
+  let hasCustomAssemblyFormat = 0;
+
+  let assemblyFormat = [{
+    $base_ptr `[` $element ($indices^)? `]` attr-dict `:` type($base_ptr) `,` type($element) (`,` type($indices)^)? `->` type($result)
+  }];
 }
 
 // -----
@@ -311,6 +317,12 @@ def SPIRV_PtrAccessChainOp : SPIRV_Op<"PtrAccessChain", [Pure]> {
   );
 
   let builders = [OpBuilder<(ins "Value":$basePtr, "Value":$element, "ValueRange":$indices)>];
+
+  let hasCustomAssemblyFormat = 0;
+
+  let assemblyFormat = [{
+    $base_ptr `[` $element ($indices^)? `]` attr-dict `:` type($base_ptr) `,` type($element) (`,` type($indices)^)? `->` type($result)
+  }];
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index 154e955d6057a8..5ae27e5d82bd73 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -543,56 +543,6 @@ LogicalResult CopyMemoryOp::verify() {
   return verifySourceMemoryAccessAttribute(*this);
 }
 
-static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
-                                             OpAsmParser &parser,
-                                             OperationState &state) {
-  OpAsmParser::UnresolvedOperand ptrInfo;
-  SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
-  Type type;
-  auto loc = parser.getCurrentLocation();
-  SmallVector<Type, 4> indicesTypes;
-
-  if (parser.parseOperand(ptrInfo) ||
-      parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
-      parser.parseColonType(type) ||
-      parser.resolveOperand(ptrInfo, type, state.operands))
-    return failure();
-
-  // Check that the provided indices list is not empty before parsing their
-  // type list.
-  if (indicesInfo.empty())
-    return emitError(state.location) << opName << " expected element";
-
-  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
-    return failure();
-
-  // Check that the indices types list is not empty and that it has a one-to-one
-  // mapping to the provided indices.
-  if (indicesTypes.size() != indicesInfo.size())
-    return emitError(state.location)
-           << opName
-           << " indices types' count must be equal to indices info count";
-
-  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
-    return failure();
-
-  auto resultType = getElementPtrType(
-      type, llvm::ArrayRef(state.operands).drop_front(2), state.location);
-  if (!resultType)
-    return failure();
-
-  state.addTypes(resultType);
-  return success();
-}
-
-template <typename Op>
-static auto concatElemAndIndices(Op op) {
-  SmallVector<Value> ret(op.getIndices().size() + 1);
-  ret[0] = op.getElement();
-  llvm::copy(op.getIndices(), ret.begin() + 1);
-  return ret;
-}
-
 //===----------------------------------------------------------------------===//
 // spirv.InBoundsPtrAccessChainOp
 //===----------------------------------------------------------------------===//
@@ -605,16 +555,6 @@ void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, type, basePtr, element, indices);
 }
 
-ParseResult InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
-                                            OperationState &result) {
-  return parsePtrAccessChainOpImpl(
-      spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result);
-}
-
-void InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
-  printAccessChain(*this, concatElemAndIndices(*this), printer);
-}
-
 LogicalResult InBoundsPtrAccessChainOp::verify() {
   return verifyAccessChain(*this, getIndices());
 }
@@ -630,16 +570,6 @@ void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, type, basePtr, element, indices);
 }
 
-ParseResult PtrAccessChainOp::parse(OpAsmParser &parser,
-                                    OperationState &result) {
-  return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
-                                   parser, result);
-}
-
-void PtrAccessChainOp::print(OpAsmPrinter &printer) {
-  printAccessChain(*this, concatElemAndIndices(*this), printer);
-}
-
 LogicalResult PtrAccessChainOp::verify() {
   return verifyAccessChain(*this, getIndices());
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
index 12bfee9fb65119..5aef6135afd97e 100644
--- a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
@@ -699,7 +699,7 @@ func.func @copy_memory_print_maa() {
 // CHECK-SAME:    %[[ARG1:.*]]: i64)
 // CHECK: spirv.PtrAccessChain %[[ARG0]][%[[ARG1]]] : !spirv.ptr<f32, CrossWorkgroup>, i64
 func.func @ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
-  %0 = spirv.PtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64
+  %0 = spirv.PtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64 -> !spirv.ptr<f32, CrossWorkgroup>
   return
 }
 
@@ -714,6 +714,6 @@ func.func @ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64
 // CHECK-SAME:    %[[ARG1:.*]]: i64)
 // CHECK: spirv.InBoundsPtrAccessChain %[[ARG0]][%[[ARG1]]] : !spirv.ptr<f32, CrossWorkgroup>, i64
 func.func @inbounds_ptr_access_chain1(%arg0: !spirv.ptr<f32, CrossWorkgroup>, %arg1 : i64) -> () {
-  %0 = spirv.InBoundsPtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64
+  %0 = spirv.InBoundsPtrAccessChain %arg0[%arg1] : !spirv.ptr<f32, CrossWorkgroup>, i64 -> !spirv.ptr<f32, CrossWorkgroup>
   return
 }



More information about the Mlir-commits mailing list