[Mlir-commits] [mlir] e3f439e - [mlir][Linalg] NFC - Add result and bbArg pretty printing to linalg.reduce

Nicolas Vasilache llvmlistbot at llvm.org
Tue Oct 4 09:27:27 PDT 2022


Author: Nicolas Vasilache
Date: 2022-10-04T09:27:18-07:00
New Revision: e3f439ea20e4a78daf990062983695a21c3c1abd

URL: https://github.com/llvm/llvm-project/commit/e3f439ea20e4a78daf990062983695a21c3c1abd
DIFF: https://github.com/llvm/llvm-project/commit/e3f439ea20e4a78daf990062983695a21c3c1abd.diff

LOG: [mlir][Linalg] NFC - Add result and bbArg pretty printing to linalg.reduce

Differential Revision: https://reviews.llvm.org/D135152

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index d8c1e0ba1b344..268587e312077 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -338,6 +338,24 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return 0;
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the input block arguments of the region.
+      }],
+      /*retTy=*/"Block::BlockArgListType",
+      /*methodName=*/"getRegionInputArgs",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        // MLIR currently does not support dependent interfaces or interface
+        // inheritance. By construction all ops implementing this interface
+        // must implement DestinationStyleOpInterface.
+        // TODO: reevaluate the need for a cast when a better mechanism exists.
+        return getBlock()->getArguments().take_front(
+            cast<DestinationStyleOpInterface>(*this->getOperation())
+                .getNumInputs());
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the output block arguments of the region.

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 6234d33763132..3d1ee2f09b41e 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -19,6 +19,7 @@ include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpAsmInterface.td"
 
 // Base Tablegen class for Linalg ops.
 // Linalg ops that correspond to library calls operate on ShapedType as their
@@ -229,8 +230,10 @@ def TensorOrMemref :
   AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
 
 def ReduceOp : LinalgStructuredBase_Op<"reduce", [
-      SameVariadicOperandSize, SingleBlockImplicitTerminator<"YieldOp">
-    ]> {
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
+    SameVariadicOperandSize,
+    SingleBlockImplicitTerminator<"YieldOp">]> {
   let summary = "Reduce operator";
   let description = [{
     Executes `combiner` on the `dimensions` of `inputs` and returns the

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6289369692535..3741e7d2ab087 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1187,6 +1187,26 @@ LogicalResult GenericOp::fold(ArrayRef<Attribute>,
 // ReduceOp
 //===----------------------------------------------------------------------===//
 
+/// Suggests pretty names for the combiner block arguments: every input
+/// argument is named "in" and every init (output) argument "init". `region`
+/// is not consulted; the arguments come from the LinalgOp region accessors.
+void ReduceOp::getAsmBlockArgumentNames(Region &region,
+                                        OpAsmSetValueNameFn setNameFn) {
+  for (Value inputArg : getRegionInputArgs())
+    setNameFn(inputArg, "in");
+  for (Value initArg : getRegionOutputArgs())
+    setNameFn(initArg, "init");
+}
+
+/// Suggests the pretty name "reduced" for the first result. The result list
+/// may be empty (presumably in the memref/buffer form, since operands are
+/// TensorOrMemref), so guard instead of calling front() unconditionally.
+void ReduceOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  if (!getResults().empty())
+    setNameFn(getResults().front(), "reduced");
+}
+
 ArrayAttr ReduceOp::getIteratorTypes() {
   int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
   SmallVector<StringRef> iteratorTypes(inputRank,


        


More information about the Mlir-commits mailing list