[Mlir-commits] [mlir] a2ab6a5 - [mlir][CallOpInterface] Add `setCalleeFromCallable` method
Whitney Tsang
llvmlistbot at llvm.org
Mon May 8 06:07:46 PDT 2023
Author: Whitney Tsang
Date: 2023-05-08T06:07:10-07:00
New Revision: a2ab6a5e2b8d4e10ce29b24db7d6ae18c9acbec1
URL: https://github.com/llvm/llvm-project/commit/a2ab6a5e2b8d4e10ce29b24db7d6ae18c9acbec1
DIFF: https://github.com/llvm/llvm-project/commit/a2ab6a5e2b8d4e10ce29b24db7d6ae18c9acbec1.diff
LOG: [mlir][CallOpInterface] Add `setCalleeFromCallable` method
Currently `CallOpInterface` has a method `getCallableForCallee` that provides a consistent way to get the callee from an operation with `CallOpInterface`, but it lacks a consistent way to set the callee of such an operation.
A method to set the callee is useful for transformations that operate on `CallOpInterface` and change the callee. One example is a function-specialization pass, which clones the callee and changes the `CallOpInterface`'s callee to the cloned version. Without such a method, a transformation would need to understand the implementation of every operation with `CallOpInterface` and use a type switch to handle them.
This review adds a method to set the callee of an operation with `CallOpInterface`.
Reviewed By: gysit, zero9178o
Differential Revision: https://reviews.llvm.org/D149763
Added:
Modified:
flang/include/flang/Optimizer/Dialect/FIROps.td
mlir/docs/Interfaces.md
mlir/docs/Tutorials/Toy/Ch-4.md
mlir/examples/toy/Ch4/mlir/Dialect.cpp
mlir/examples/toy/Ch5/mlir/Dialect.cpp
mlir/examples/toy/Ch6/mlir/Dialect.cpp
mlir/examples/toy/Ch7/mlir/Dialect.cpp
mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
mlir/include/mlir/Dialect/Func/IR/FuncOps.td
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/include/mlir/Interfaces/CallInterfaces.td
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 2bc4ec0401d52..0e07e6fcaac9d 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2357,6 +2357,14 @@ def fir_CallOp : fir_Op<"call",
return calling;
return getOperand(0);
}
+
+    /// Set the callee for this operation. A direct call (one carrying the
+    /// callee attribute) expects a SymbolRefAttr; an indirect call expects the
+    /// callee Value, which is stored in operand 0.
+    void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
+      if ((*this)->getAttrOfType<mlir::SymbolRefAttr>(getCalleeAttrName())) {
+        // Direct call: operand 0 is an ordinary argument, so it must not be
+        // touched (and `callee` holds no Value to put there).
+        (*this)->setAttr(getCalleeAttrName(), callee.get<mlir::SymbolRefAttr>());
+        return;
+      }
+      // Indirect call: the callee Value is the first operand.
+      setOperand(0, callee.get<mlir::Value>());
+    }
}];
}
diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md
index b51adec4fc4f3..a299feb54cba7 100644
--- a/mlir/docs/Interfaces.md
+++ b/mlir/docs/Interfaces.md
@@ -728,6 +728,7 @@ interface section goes as follows:
* `CallOpInterface` - Used to represent operations like 'call'
- `CallInterfaceCallable getCallableForCallee()`
+ - `void setCalleeFromCallable(CallInterfaceCallable)`
* `CallableOpInterface` - Used to represent the target callee of call.
- `Region * getCallableRegion()`
- `ArrayRef<Type> getCallableResults()`
diff --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md
index f462274fa592e..9ca9706644fa5 100644
--- a/mlir/docs/Tutorials/Toy/Ch-4.md
+++ b/mlir/docs/Tutorials/Toy/Ch-4.md
@@ -189,6 +189,12 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
return getAttrOfType<SymbolRefAttr>("callee");
}
+/// Retarget the generic call to a new function. The call interface requires
+/// this hook; `callee` must hold the SymbolRefAttr of the new target.
+void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  SymbolRefAttr newTarget = callee.get<SymbolRefAttr>();
+  (*this)->setAttr("callee", newTarget);
+}
+
/// Get the argument operands to the called function, this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
index 75a517159a6d2..d533e5805081f 100644
--- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
@@ -338,6 +338,12 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}
+/// Retarget the generic call to a new function. The call interface requires
+/// this hook; `callee` must hold the SymbolRefAttr of the new target.
+void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  SymbolRefAttr newTarget = callee.get<SymbolRefAttr>();
+  (*this)->setAttr("callee", newTarget);
+}
+
/// Get the argument operands to the called function, this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
index 98c8eb5dd7989..4f0326682fbd7 100644
--- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
@@ -338,6 +338,12 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}
+/// Retarget the generic call to a new function. The call interface requires
+/// this hook; `callee` must hold the SymbolRefAttr of the new target.
+void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  SymbolRefAttr newTarget = callee.get<SymbolRefAttr>();
+  (*this)->setAttr("callee", newTarget);
+}
+
/// Get the argument operands to the called function, this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
index 98c8eb5dd7989..4f0326682fbd7 100644
--- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
@@ -338,6 +338,12 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}
+/// Retarget the generic call to a new function. The call interface requires
+/// this hook; `callee` must hold the SymbolRefAttr of the new target.
+void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  SymbolRefAttr newTarget = callee.get<SymbolRefAttr>();
+  (*this)->setAttr("callee", newTarget);
+}
+
/// Get the argument operands to the called function, this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index 5fcb0be36c8aa..643240333d92e 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -367,6 +367,12 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}
+/// Retarget the generic call to a new function. The call interface requires
+/// this hook; `callee` must hold the SymbolRefAttr of the new target.
+void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  SymbolRefAttr newTarget = callee.get<SymbolRefAttr>();
+  (*this)->setAttr("callee", newTarget);
+}
+
/// Get the argument operands to the called function, this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index 30147b8b6a309..9824238e7d933 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -271,6 +271,11 @@ def Async_CallOp : Async_Op<"call",
CallInterfaceCallable getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}
+
+    /// Retarget this call: `callee` must hold the SymbolRefAttr naming the new
+    /// function, which replaces the "callee" attribute.
+    void setCalleeFromCallable(CallInterfaceCallable callee) {
+      SymbolRefAttr newCallee = callee.get<SymbolRefAttr>();
+      (*this)->setAttr("callee", newCallee);
+    }
}];
let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
index 4204bc576970f..fb206f1be8175 100644
--- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
+++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
@@ -91,6 +91,11 @@ def CallOp : Func_Op<"call",
CallInterfaceCallable getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}
+
+    /// Retarget this call: `callee` must hold the SymbolRefAttr naming the new
+    /// function, which replaces the "callee" attribute.
+    void setCalleeFromCallable(CallInterfaceCallable callee) {
+      SymbolRefAttr newCallee = callee.get<SymbolRefAttr>();
+      (*this)->setAttr("callee", newCallee);
+    }
}];
let assemblyFormat = [{
@@ -153,6 +158,11 @@ def CallIndirectOp : Func_Op<"call_indirect", [
/// Return the callee of this operation.
CallInterfaceCallable getCallableForCallee() { return getCallee(); }
+
+    /// Point this indirect call at a new callee: `callee` must hold the SSA
+    /// Value of the function to call, which replaces operand 0.
+    void setCalleeFromCallable(CallInterfaceCallable callee) {
+      Value newCallee = callee.get<Value>();
+      (*this)->setOperand(0, newCallee);
+    }
}];
let hasCanonicalizeMethod = 1;
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 53c1c0af54c0f..8154835546bb6 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -372,6 +372,10 @@ def IncludeOp : TransformDialectOp<"include",
return getTarget();
}
+    /// Set the callee for this operation; `callee` must hold the
+    /// SymbolRefAttr of the new target, which replaces the `target` attribute.
+    void setCalleeFromCallable(::mlir::CallInterfaceCallable callee) {
+      setTargetAttr(callee.get<::mlir::SymbolRefAttr>());
+    }
+
::mlir::Operation::operand_range getArgOperands() {
return getOperands();
}
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
index cd37222cbc270..328b3d594325a 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -40,6 +40,15 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
}],
"::mlir::CallInterfaceCallable", "getCallableForCallee"
>,
+    InterfaceMethod<[{
+        Sets the callee of this call-like operation. A `callee` is either a
+        reference to a symbol, via SymbolRefAttr, or a reference to a defined
+        SSA value. `callee` must hold the same alternative that
+        `getCallableForCallee` currently returns for this operation, e.g., a
+        SymbolRefAttr for `func.call` and a Value for `func.call_indirect`.
+      }],
+      "void", "setCalleeFromCallable", (ins "::mlir::CallInterfaceCallable":$callee)
+    >,
InterfaceMethod<[{
Returns the operands within this call that are used as arguments to the
callee.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 9595c18519520..5380ba0666197 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -933,6 +933,16 @@ CallInterfaceCallable CallOp::getCallableForCallee() {
return getOperand(0);
}
+/// Set the callee of this call. For a direct call (callee attribute present)
+/// `callee` must hold a SymbolRefAttr, which the `cast` requires to be a
+/// FlatSymbolRefAttr; for an indirect call it must hold the callee Value.
+void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  // Direct call.
+  if (getCalleeAttr()) {
+    setCalleeAttr(cast<FlatSymbolRefAttr>(callee.get<SymbolRefAttr>()));
+    return;
+  }
+  // Indirect call, callee Value is the first operand.
+  setOperand(0, callee.get<Value>());
+}
+
Operation::operand_range CallOp::getArgOperands() {
return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
}
@@ -1157,6 +1167,16 @@ CallInterfaceCallable InvokeOp::getCallableForCallee() {
return getOperand(0);
}
+/// Set the callee of this invoke. For a direct invoke (callee attribute
+/// present) `callee` must hold a SymbolRefAttr, which the `cast` requires to be
+/// a FlatSymbolRefAttr; for an indirect invoke it must hold the callee Value.
+void InvokeOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  // Direct call.
+  if (getCalleeAttr()) {
+    setCalleeAttr(cast<FlatSymbolRefAttr>(callee.get<SymbolRefAttr>()));
+    return;
+  }
+  // Indirect call, callee Value is the first operand.
+  setOperand(0, callee.get<Value>());
+}
+
Operation::operand_range InvokeOp::getArgOperands() {
return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 181c9e0a23bb7..2ad249773a6fe 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2576,6 +2576,11 @@ CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
}
+/// CallOpInterface hook: replace the symbol stored under `kCallee` with the
+/// SymbolRefAttr carried by `callee`.
+void spirv::FunctionCallOp::setCalleeFromCallable(
+    CallInterfaceCallable callee) {
+  auto calleeSymbol = callee.get<SymbolRefAttr>();
+  (*this)->setAttr(kCallee, calleeSymbol);
+}
+
Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
return getArguments();
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 60faf6dfe0e89..507f4aaaef649 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -495,11 +495,18 @@ def ConversionCallOp : TEST_Op<"conversion_call_op",
let extraClassDeclaration = [{
/// Return the callee of this operation.
::mlir::CallInterfaceCallable getCallableForCallee();
+
+ /// Set the callee for this operation.
+ void setCalleeFromCallable(::mlir::CallInterfaceCallable);
}];
let extraClassDefinition = [{
::mlir::CallInterfaceCallable $cppClass::getCallableForCallee() {
return (*this)->getAttrOfType<::mlir::SymbolRefAttr>("callee");
}
+
+    // Replace the "callee" symbol; `callee` must hold a SymbolRefAttr.
+    void $cppClass::setCalleeFromCallable(::mlir::CallInterfaceCallable callee) {
+      (*this)->setAttr("callee", callee.get<::mlir::SymbolRefAttr>());
+    }
}];
}
More information about the Mlir-commits
mailing list