[flang-commits] [flang] d790a21 - [mlir] Add getArgOperandsMutable method to CallOpInterface

Martin Erhart via flang-commits flang-commits at lists.llvm.org
Wed Aug 2 01:35:26 PDT 2023


Author: Martin Erhart
Date: 2023-08-02T08:08:18Z
New Revision: d790a217a746ff63190d30669674fce1bc0a4723

URL: https://github.com/llvm/llvm-project/commit/d790a217a746ff63190d30669674fce1bc0a4723
DIFF: https://github.com/llvm/llvm-project/commit/d790a217a746ff63190d30669674fce1bc0a4723.diff

LOG: [mlir] Add getArgOperandsMutable method to CallOpInterface

Add a method to the CallOpInterface to get a mutable operand range over
the function arguments. This allows adding, removing, or changing the
types of call arguments in a generic manner, without having to assume that
the argument operand range is at the end of the operand list or having to
type-switch on all supported concrete operation kinds.

Alternatively, a new OpInterface could be added which inherits from
CallOpInterface and extends it with mutable variants of the base
interface's methods.

There will initially be two users of this new method:
(1) A few passes in the Arc dialect in CIRCT already use a downstream
implementation of the alternative mentioned above: https://github.com/llvm/circt/blob/main/include/circt/Dialect/Arc/ArcInterfaces.td#L15
(2) The BufferDeallocation pass will be modified so that it can pass
ownership of memrefs to called private functions when the caller no
longer needs them, by appending one boolean value per memref to the
function argument list. This enables earlier deallocation of the
memref, which can lower peak memory usage.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D156675

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Dialect/FIROps.td
    mlir/examples/toy/Ch4/mlir/Dialect.cpp
    mlir/examples/toy/Ch5/mlir/Dialect.cpp
    mlir/examples/toy/Ch6/mlir/Dialect.cpp
    mlir/examples/toy/Ch7/mlir/Dialect.cpp
    mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
    mlir/include/mlir/Dialect/Func/IR/FuncOps.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/include/mlir/Interfaces/CallInterfaces.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 8b05c973606078..f07e8009cf2c24 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2347,6 +2347,12 @@ def fir_CallOp : fir_Op<"call",
       return {arg_operand_begin() + 1, arg_operand_end()};
     }
 
+    mlir::MutableOperandRange getArgOperandsMutable() {
+      if ((*this)->getAttrOfType<mlir::SymbolRefAttr>(getCalleeAttrName()))
+        return getArgsMutable(); // Direct call: every arg is an argument.
+      return mlir::MutableOperandRange(*this, 1, getArgs().size() - 1);
+    } // No callee symbol: operand 0 holds the callee and is skipped.
+
     operand_iterator arg_operand_begin() { return operand_begin(); }
     operand_iterator arg_operand_end() { return operand_end(); }
 

diff  --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
index e84151884ad44b..330567792412a5 100644
--- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
@@ -348,6 +348,12 @@ void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
 /// call interface.
 Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
 
+/// Get the argument operands to the called function as a mutable range; this is
+/// required by the call interface.
+MutableOperandRange GenericCallOp::getArgOperandsMutable() {
+  return getInputsMutable(); // Mirrors getArgOperands().
+}
+
 //===----------------------------------------------------------------------===//
 // MulOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
index c2a99aa2921b8f..ebd4344ff3cc72 100644
--- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
@@ -348,6 +348,12 @@ void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
 /// call interface.
 Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
 
+/// Get the argument operands to the called function as a mutable range; this is
+/// required by the call interface.
+MutableOperandRange GenericCallOp::getArgOperandsMutable() {
+  return getInputsMutable(); // Mirrors getArgOperands().
+}
+
 //===----------------------------------------------------------------------===//
 // MulOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
index c2a99aa2921b8f..ebd4344ff3cc72 100644
--- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
@@ -348,6 +348,12 @@ void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
 /// call interface.
 Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
 
+/// Get the argument operands to the called function as a mutable range; this is
+/// required by the call interface.
+MutableOperandRange GenericCallOp::getArgOperandsMutable() {
+  return getInputsMutable(); // Mirrors getArgOperands().
+}
+
 //===----------------------------------------------------------------------===//
 // MulOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index 1b77f8ce6d8a4b..35aaa435644d24 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -377,6 +377,12 @@ void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
 /// call interface.
 Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
 
+/// Get the argument operands to the called function as a mutable range; this is
+/// required by the call interface.
+MutableOperandRange GenericCallOp::getArgOperandsMutable() {
+  return getInputsMutable(); // Mirrors getArgOperands().
+}
+
 //===----------------------------------------------------------------------===//
 // MulOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index 6ca4925a452345..54ad3c63189c8d 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -264,6 +264,10 @@ def Async_CallOp : Async_Op<"call",
       return {arg_operand_begin(), arg_operand_end()};
     }
 
+    MutableOperandRange getArgOperandsMutable() {
+      return getOperandsMutable(); // Every operand is a call argument.
+    }
+
     operand_iterator arg_operand_begin() { return operand_begin(); }
     operand_iterator arg_operand_end() { return operand_end(); }
 

diff  --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
index c5ecab4df9ddf2..7897cf4dcdb043 100644
--- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
+++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
@@ -83,6 +83,10 @@ def CallOp : Func_Op<"call",
       return {arg_operand_begin(), arg_operand_end()};
     }
 
+    MutableOperandRange getArgOperandsMutable() {
+      return getOperandsMutable(); // Every operand is a call argument.
+    }
+
     operand_iterator arg_operand_begin() { return operand_begin(); }
     operand_iterator arg_operand_end() { return operand_end(); }
 
@@ -152,6 +156,10 @@ def CallIndirectOp : Func_Op<"call_indirect", [
       return {arg_operand_begin(), arg_operand_end()};
     }
 
+    MutableOperandRange getArgOperandsMutable() {
+      return getCalleeOperandsMutable(); // Excludes the callee (operand 0).
+    }
+
     operand_iterator arg_operand_begin() { return ++operand_begin(); }
     operand_iterator arg_operand_end() { return operand_end(); }
 

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 6fb422fea5b6d8..dfc80afc8ff9fa 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -616,7 +616,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
   }];
 
   dag args = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
-                  Variadic<LLVM_Type>,
+                  Variadic<LLVM_Type>:$callee_operands, // Named for getCalleeOperands().
                   DefaultValuedAttr<LLVM_FastmathFlagsAttr,
                                    "{}">:$fastmathFlags,
                   OptionalAttr<DenseI32ArrayAttr>:$branch_weights);

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 7c37416b09adbe..8a30205ee17680 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -632,6 +632,10 @@ def IncludeOp : TransformDialectOp<"include",
     ::mlir::Operation::operand_range getArgOperands() {
       return getOperands();
     }
+
+    ::mlir::MutableOperandRange getArgOperandsMutable() {
+      return getOperandsMutable(); // Mirrors getArgOperands().
+    }
   }];
 }
 

diff  --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
index 328b3d594325ac..499ccefad2bcce 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -55,6 +55,11 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
       }],
       "::mlir::Operation::operand_range", "getArgOperands"
     >,
+    InterfaceMethod<[{
+        Returns a mutable range over the operands within this call that are
+        used as arguments to the callee.
+      }],
+      "::mlir::MutableOperandRange", "getArgOperandsMutable">,
   ];
 
   let extraClassDeclaration = [{

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 9b76358399a76c..eaad4c7a3af5fb 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1003,6 +1003,15 @@ Operation::operand_range CallOp::getArgOperands() {
   return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
 }
 
+/// Returns the call arguments as a mutable range. Without a `callee` symbol
+/// (indirect call) the first callee operand is the called value, so it is
+/// skipped and must also be subtracted from the length of the range.
+MutableOperandRange CallOp::getArgOperandsMutable() {
+  unsigned argStart = getCallee().has_value() ? 0 : 1;
+  return getCalleeOperandsMutable().slice(
+      argStart, getCalleeOperands().size() - argStart);
+}
+
 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (getNumResults() > 1)
     return emitOpError("must have 0 or 1 result");
@@ -1237,6 +1246,18 @@ Operation::operand_range InvokeOp::getArgOperands() {
   return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
 }
 
+/// Returns the call arguments as a mutable range, skipping the called value
+/// of indirect invokes. Slicing getCalleeOperandsMutable() instead of building
+/// a raw MutableOperandRange lets the generated accessor handle any operand
+/// segment bookkeeping when the range is mutated.
+/// NOTE(review): getArgOperands() above slices getOperands(), which may also
+/// include successor operands; confirm getCalleeOperands() is intended there.
+MutableOperandRange InvokeOp::getArgOperandsMutable() {
+  unsigned argStart = getCallee().has_value() ? 0 : 1;
+  return getCalleeOperandsMutable().slice(
+      argStart, getCalleeOperands().size() - argStart);
+}
+
 LogicalResult InvokeOp::verify() {
   if (getNumResults() > 1)
     return emitOpError("must have 0 or 1 result");

diff  --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index e169cb5b65322e..081f8b601f41f0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -208,6 +208,10 @@ Operation::operand_range FunctionCallOp::getArgOperands() {
   return getArguments();
 }
 
+MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
+  return getArgumentsMutable(); // Mirrors getArgOperands().
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.loop
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 1edd4aa40da043..485d21823eb67a 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1263,6 +1263,10 @@ Operation::operand_range TestCallAndStoreOp::getArgOperands() {
   return getCalleeOperands();
 }
 
+MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() {
+  return getCalleeOperandsMutable(); // Mirrors getArgOperands().
+}
+
 void TestStoreWithARegion::getSuccessorRegions(
     std::optional<unsigned> index, ArrayRef<Attribute> operands,
     SmallVectorImpl<RegionSuccessor> &regions) {


        


More information about the flang-commits mailing list