[Mlir-commits] [mlir] e3f439e - [mlir][Linalg] NFC - Add result and bbArg pretty printing to linalg.reduce
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Oct 4 09:27:27 PDT 2022
Author: Nicolas Vasilache
Date: 2022-10-04T09:27:18-07:00
New Revision: e3f439ea20e4a78daf990062983695a21c3c1abd
URL: https://github.com/llvm/llvm-project/commit/e3f439ea20e4a78daf990062983695a21c3c1abd
DIFF: https://github.com/llvm/llvm-project/commit/e3f439ea20e4a78daf990062983695a21c3c1abd.diff
LOG: [mlir][Linalg] NFC - Add result and bbArg pretty printing to linalg.reduce
Differential Revision: https://reviews.llvm.org/D135152
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index d8c1e0ba1b344..268587e312077 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -338,6 +338,24 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
return 0;
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the input block arguments of the region.
+ }],
+ /*retTy=*/"Block::BlockArgListType",
+ /*methodName=*/"getRegionInputArgs",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ // MLIR currently does not support dependent interfaces or interface
+ // inheritance. By construction all ops with StructuredOpInterface must
+ // implement DestinationStyleOpInterface.
+ // TODO: reevaluate the need for a cast when a better mechanism exists.
+ // The leading block arguments correspond to the op's inputs, so the first
+ // getNumInputs() arguments of the block are returned.
+ return getBlock()->getArguments().take_front(
+ cast<DestinationStyleOpInterface>(*this->getOperation())
+ .getNumInputs());
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Return the output block arguments of the region.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 6234d33763132..3d1ee2f09b41e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -19,6 +19,7 @@ include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpAsmInterface.td"
// Base Tablegen class for Linalg ops.
// Linalg ops that correspond to library calls operate on ShapedType as their
@@ -229,8 +230,10 @@ def TensorOrMemref :
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
def ReduceOp : LinalgStructuredBase_Op<"reduce", [
- SameVariadicOperandSize, SingleBlockImplicitTerminator<"YieldOp">
- ]> {
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
+ SameVariadicOperandSize,
+ SingleBlockImplicitTerminator<"YieldOp">]> {
let summary = "Reduce operator";
let description = [{
Executes `combiner` on the `dimensions` of `inputs` and returns the
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6289369692535..3741e7d2ab087 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1187,6 +1187,19 @@ LogicalResult GenericOp::fold(ArrayRef<Attribute>,
// ReduceOp
//===----------------------------------------------------------------------===//
+/// Suggests pretty-printed names for the combiner region's block arguments:
+/// input block arguments are named "in" and output (init) block arguments
+/// are named "init". `region` is required by the OpAsmOpInterface signature.
+void ReduceOp::getAsmBlockArgumentNames(Region &region,
+                                        OpAsmSetValueNameFn setNameFn) {
+  for (Value v : getRegionInputArgs())
+    setNameFn(v, "in");
+  for (Value v : getRegionOutputArgs())
+    setNameFn(v, "init");
+}
+
+/// Suggests the pretty-printed name "reduced" for the op's first result.
+/// The op may have no results (e.g. when it operates on buffers rather than
+/// tensors), so check for an empty result range before calling front().
+void ReduceOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  if (!getResults().empty())
+    setNameFn(getResults().front(), "reduced");
+}
+
ArrayAttr ReduceOp::getIteratorTypes() {
int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
SmallVector<StringRef> iteratorTypes(inputRank,
More information about the Mlir-commits
mailing list