[Mlir-commits] [mlir] 350d686 - [mlir] Print bbArgs of linalg.map/reduce/transpose on the next line.

Alexander Belyaev llvmlistbot at llvm.org
Thu Oct 27 01:19:26 PDT 2022


Author: Alexander Belyaev
Date: 2022-10-27T10:19:04+02:00
New Revision: 350d68644445f53551df1a4ddd69bd4f54f09fff

URL: https://github.com/llvm/llvm-project/commit/350d68644445f53551df1a4ddd69bd4f54f09fff
DIFF: https://github.com/llvm/llvm-project/commit/350d68644445f53551df1a4ddd69bd4f54f09fff.diff

LOG: [mlir] Print bbArgs of linalg.map/reduce/transpose on the next line.

```
%mapped = linalg.map
  ins(%arg0 : tensor<64xf32>)
  outs(%arg1 : tensor<64xf32>)
  (%in: f32) {
    %0 = math.absf %in : f32
    linalg.yield %0 : f32
  }
%reduced = linalg.reduce
  ins(%arg0 : tensor<16x32x64xf32>)
  outs(%arg1 : tensor<16x64xf32>)
  dimensions = [1]
  (%in: f32, %init: f32) {
    %0 = arith.addf %in, %init : f32
    linalg.yield %0 : f32
  }
%transposed = linalg.transpose
  ins(%arg0 : tensor<16x32x64xf32>)
  outs(%arg1 : tensor<32x64x16xf32>)
  permutation = [1, 2, 0]
```

Differential Revision: https://reviews.llvm.org/D136818

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpImplementation.h
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 474e9955bdcce..524c72d239cf6 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -334,6 +334,12 @@ class OpAsmPrinter : public AsmPrinter {
   /// operation.
   virtual void printNewline() = 0;
 
+  /// Increase indentation.
+  virtual void increaseIndent() = 0;
+
+  /// Decrease indentation.
+  virtual void decreaseIndent() = 0;
+
   /// Print a block argument in the usual format of:
   ///   %ssaName : type {attr1=42} loc("here")
   /// where location printing is controlled by the standard internal option.

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5d6dd379b2e40..a7e1938e2418a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -173,6 +173,16 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
     p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
 }
 
+static void printCommonStructuredOpPartsWithNewLine(OpAsmPrinter &p,
+                                                    ValueRange inputs,
+                                                    ValueRange outputs) {
+  p.printNewline();
+  if (!inputs.empty())
+    p << "ins(" << inputs << " : " << inputs.getTypes() << ")";
+  p.printNewline();
+  if (!outputs.empty())
+    p << "outs(" << outputs << " : " << outputs.getTypes() << ")";
+}
 //===----------------------------------------------------------------------===//
 // Specific parsing and printing for named structured ops created by ods-gen.
 //===----------------------------------------------------------------------===//
@@ -1335,16 +1345,20 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 void MapOp::print(OpAsmPrinter &p) {
-  printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
-                               SmallVector<Value>(getOutputOperands()));
+  p.increaseIndent();
+  printCommonStructuredOpPartsWithNewLine(
+      p, SmallVector<Value>(getInputOperands()),
+      SmallVector<Value>(getOutputOperands()));
   p.printOptionalAttrDict((*this)->getAttrs());
 
+  p.printNewline();
   p << "(";
   llvm::interleaveComma(getMapper().getArguments(), p,
                         [&](auto arg) { p.printRegionArgument(arg); });
   p << ") ";
 
   p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
+  p.decreaseIndent();
 }
 
 LogicalResult MapOp::verify() {
@@ -1481,21 +1495,26 @@ ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
 
 static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
                                    ArrayRef<int64_t> attributeValue) {
-  p << " " << attributeName << " = [" << attributeValue << "] ";
+  p << attributeName << " = [" << attributeValue << "] ";
 }
 
 void ReduceOp::print(OpAsmPrinter &p) {
-  printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
-                               SmallVector<Value>(getOutputOperands()));
+  p.increaseIndent();
+  printCommonStructuredOpPartsWithNewLine(
+      p, SmallVector<Value>(getInputOperands()),
+      SmallVector<Value>(getOutputOperands()));
+  p.printNewline();
   printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
   p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
 
+  p.printNewline();
   p << "(";
   llvm::interleaveComma(getCombiner().getArguments(), p,
                         [&](auto arg) { p.printRegionArgument(arg); });
   p << ") ";
 
   p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
+  p.decreaseIndent();
 }
 
 LogicalResult ReduceOp::verify() {
@@ -1657,10 +1676,14 @@ void TransposeOp::getAsmResultNames(
 }
 
 void TransposeOp::print(OpAsmPrinter &p) {
-  printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
-                               SmallVector<Value>(getOutputOperands()));
+  p.increaseIndent();
+  printCommonStructuredOpPartsWithNewLine(
+      p, SmallVector<Value>(getInputOperands()),
+      SmallVector<Value>(getOutputOperands()));
+  p.printNewline();
   printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
   p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
+  p.decreaseIndent();
 }
 
 LogicalResult TransposeOp::verify() {

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 676f8133a8771..9a3d3e031dc31 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -716,6 +716,8 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
   void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
   void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
   void printNewline() override {}
+  void increaseIndent() override {}
+  void decreaseIndent() override {}
   void printOperand(Value) override {}
   void printOperand(Value, raw_ostream &os) override {
     // Users expect the output string to have at least the prefixed % to signal
@@ -2768,6 +2770,12 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
     os.indent(currentIndent);
   }
 
+  /// Increase indentation.
+  void increaseIndent() override { currentIndent += indentWidth; }
+
+  /// Decrease indentation.
+  void decreaseIndent() override { currentIndent -= indentWidth; }
+
   /// Print a block argument in the usual format of:
   ///   %ssaName : type {attr1=42} loc("here")
   /// where location printing is controlled by the standard internal option.

diff  --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
index e71f566c307c8..58dec2be2373a 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -341,7 +341,7 @@ func.func @op_is_reading_but_following_ops_are_not(
 func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
                       %init: tensor<64xf32>) -> tensor<64xf32> {
    // CHECK:      linalg.map
-   // CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<64xf32
+   // CHECK-NEXT: ins(%[[LHS]], %[[RHS]] : memref<64xf32
    %add = linalg.map
           ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
           outs(%init:tensor<64xf32>)
@@ -359,7 +359,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
 func.func @reduce(%input: tensor<16x32x64xf32>,
                   %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
   // CHECK:     linalg.reduce
-  // CHECK-SAME: ins(%[[INPUT]] : memref<16x32x64xf32
+  // CHECK-NEXT: ins(%[[INPUT]] : memref<16x32x64xf32
   %reduce = linalg.reduce
       ins(%input:tensor<16x32x64xf32>)
       outs(%init:tensor<16x64xf32>)
@@ -378,7 +378,7 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
 func.func @transpose(%input: tensor<16x32x64xf32>,
                      %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
   // CHECK:      linalg.transpose
-  // CHECK-SAME: ins(%[[ARG0]] : memref<16x32x64xf32
+  // CHECK-NEXT: ins(%[[ARG0]] : memref<16x32x64xf32
   %transpose = linalg.transpose
       ins(%input:tensor<16x32x64xf32>)
       outs(%init:tensor<32x64x16xf32>)

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 4bea3f6d38376..6e1c26634c417 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -338,7 +338,13 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
   func.return %add : tensor<64xf32>
 }
 // CHECK-LABEL: func @map_binary
-//       CHECK:     linalg.map
+//       CHECK:   linalg.map
+//  CHECK-NEXT:   ins
+//  CHECK-NEXT:   outs
+//  CHECK-NEXT:   (%{{.*}}: f32, %{{.*}}: f32) {
+//  CHECK-NEXT:     arith.addf
+//  CHECK-NEXT:     linalg.yield
+//  CHECK-NEXT:   }
 
 // -----
 
@@ -401,7 +407,14 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
   func.return %reduce : tensor<16x64xf32>
 }
 // CHECK-LABEL: func @reduce
-//       CHECK:     linalg.reduce
+//       CHECK:   linalg.reduce
+//  CHECK-NEXT:   ins
+//  CHECK-NEXT:   outs
+//  CHECK-NEXT:   dimensions = [1]
+//  CHECK-NEXT:   (%{{.*}}: f32, %{{.*}}: f32) {
+//  CHECK-NEXT:     arith.addf
+//  CHECK-NEXT:     linalg.yield
+//  CHECK-NEXT:   }
 
 // -----
 
@@ -469,6 +482,10 @@ func.func @transpose(%input: tensor<16x32x64xf32>,
   func.return %transpose : tensor<32x64x16xf32>
 }
 // CHECK-LABEL: func @transpose
+//      CHECK:    linalg.transpose
+// CHECK-NEXT:    ins
+// CHECK-NEXT:    outs
+// CHECK-NEXT:    permutation
 
 // -----
 


        


More information about the Mlir-commits mailing list