[Mlir-commits] [mlir] 11175b5 - [mlir][linalg] Print broadcast, map, reduce, transpose ins/outs on one line.

Alexander Belyaev llvmlistbot at llvm.org
Thu Dec 8 10:17:01 PST 2022


Author: Alexander Belyaev
Date: 2022-12-08T19:16:36+01:00
New Revision: 11175b55072418e148cec4bf0a6e858b2873f58f

URL: https://github.com/llvm/llvm-project/commit/11175b55072418e148cec4bf0a6e858b2873f58f
DIFF: https://github.com/llvm/llvm-project/commit/11175b55072418e148cec4bf0a6e858b2873f58f.diff

LOG: [mlir][linalg] Print broadcast, map, reduce, transpose ins/outs on one line.

Differential Revision: https://reviews.llvm.org/D139650
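
With operand names borrowed from the roundtrip tests below (exact indentation
is illustrative), the printed form of linalg.map goes from

    %add = linalg.map
        ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
        outs(%init : tensor<64xf32>)
        (%in1: f32, %in2: f32) {
          %0 = arith.addf %in1, %in2 : f32
          linalg.yield %0 : f32
        }

to

    %add = linalg.map ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>) outs(%init : tensor<64xf32>)
        (%in1: f32, %in2: f32) {
          %0 = arith.addf %in1, %in2 : f32
          linalg.yield %0 : f32
        }

linalg.reduce, linalg.transpose, and linalg.broadcast change the same way,
with their dimensions/permutation attribute also continuing on that line.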

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 98b1406d98482..8b0540e10d01b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -178,12 +178,10 @@ static void printCommonStructuredOpPartsWithNewLine(OpAsmPrinter &p,
                                                     ValueRange inputs,
                                                     ValueRange outputs) {
   if (!inputs.empty()) {
-    p.printNewline();
-    p << "ins(" << inputs << " : " << inputs.getTypes() << ")";
+    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
   }
   if (!outputs.empty()) {
-    p.printNewline();
-    p << "outs(" << outputs << " : " << outputs.getTypes() << ")";
+    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
   }
 }
 //===----------------------------------------------------------------------===//
@@ -1041,12 +1039,12 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 void MapOp::print(OpAsmPrinter &p) {
-  p.increaseIndent();
   printCommonStructuredOpPartsWithNewLine(
       p, SmallVector<Value>(getDpsInputOperands()),
       SmallVector<Value>(getDpsInitOperands()));
   p.printOptionalAttrDict((*this)->getAttrs());
 
+  p.increaseIndent();
   p.printNewline();
   p << "(";
   llvm::interleaveComma(getMapper().getArguments(), p,
@@ -1210,19 +1208,18 @@ ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
 
 static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
                                    ArrayRef<int64_t> attributeValue) {
-  p << attributeName << " = [" << attributeValue << "] ";
+  p << ' ' << attributeName << " = [" << attributeValue << "] ";
 }
 
 void ReduceOp::print(OpAsmPrinter &p) {
-  p.increaseIndent();
   printCommonStructuredOpPartsWithNewLine(
       p, SmallVector<Value>(getDpsInputOperands()),
       SmallVector<Value>(getDpsInitOperands()));
-  p.printNewline();
 
   printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
   p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
 
+  p.increaseIndent();
   p.printNewline();
   p << "(";
   llvm::interleaveComma(getCombiner().getArguments(), p,
@@ -1379,15 +1376,11 @@ void TransposeOp::getAsmResultNames(
 }
 
 void TransposeOp::print(OpAsmPrinter &p) {
-  p.increaseIndent();
   printCommonStructuredOpPartsWithNewLine(
       p, SmallVector<Value>(getDpsInputOperands()),
       SmallVector<Value>(getDpsInitOperands()));
-  p.printNewline();
-
   printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
   p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
-  p.decreaseIndent();
 }
 
 LogicalResult TransposeOp::verify() {
@@ -1498,15 +1491,11 @@ void BroadcastOp::getAsmResultNames(
 }
 
 void BroadcastOp::print(OpAsmPrinter &p) {
-  p.increaseIndent();
   printCommonStructuredOpPartsWithNewLine(
       p, SmallVector<Value>(getDpsInputOperands()),
       SmallVector<Value>(getDpsInitOperands()));
-  p.printNewline();
-
   printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
   p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
-  p.decreaseIndent();
 }
 
 LogicalResult BroadcastOp::verify() {

diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
index ef2d218db6438..d418a92775cfd 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -340,8 +340,7 @@ func.func @op_is_reading_but_following_ops_are_not(
 // CHECK-SAME:  %[[RHS:[0-9a-zA-Z]*]]: memref<64xf32
 func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
                       %init: tensor<64xf32>) -> tensor<64xf32> {
-   // CHECK:      linalg.map
-   // CHECK-NEXT: ins(%[[LHS]], %[[RHS]] : memref<64xf32
+   // CHECK:      linalg.map ins(%[[LHS]], %[[RHS]] : memref<64xf32
    %add = linalg.map
           ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
           outs(%init:tensor<64xf32>)
@@ -358,8 +357,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
 // CHECK-SAME:  %[[INPUT:.*]]: memref<16x32x64xf32
 func.func @reduce(%input: tensor<16x32x64xf32>,
                   %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
-  // CHECK:     linalg.reduce
-  // CHECK-NEXT: ins(%[[INPUT]] : memref<16x32x64xf32
+  // CHECK:     linalg.reduce ins(%[[INPUT]] : memref<16x32x64xf32
   %reduce = linalg.reduce
       ins(%input:tensor<16x32x64xf32>)
       outs(%init:tensor<16x64xf32>)
@@ -377,8 +375,7 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
 // CHECK-SAME:  %[[ARG0:.*]]: memref<16x32x64xf32
 func.func @transpose(%input: tensor<16x32x64xf32>,
                      %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
-  // CHECK:      linalg.transpose
-  // CHECK-NEXT: ins(%[[ARG0]] : memref<16x32x64xf32
+  // CHECK:      linalg.transpose ins(%[[ARG0]] : memref<16x32x64xf32
   %transpose = linalg.transpose
       ins(%input:tensor<16x32x64xf32>)
       outs(%init:tensor<32x64x16xf32>)

diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 8f0c83fe202e1..b1a614fb768a1 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -336,8 +336,7 @@ func.func @map_no_inputs(%init: tensor<64xf32>) -> tensor<64xf32> {
   func.return %add : tensor<64xf32>
 }
 // CHECK-LABEL: func @map_no_inputs
-//       CHECK:   linalg.map
-//  CHECK-NEXT:   outs
+//       CHECK:   linalg.map outs
 //  CHECK-NEXT:   () {
 //  CHECK-NEXT:     arith.constant
 //  CHECK-NEXT:     linalg.yield
@@ -357,9 +356,8 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
   func.return %add : tensor<64xf32>
 }
 // CHECK-LABEL: func @map_binary
-//       CHECK:   linalg.map
-//  CHECK-NEXT:   ins
-//  CHECK-NEXT:   outs
+//       CHECK:   linalg.map ins
+//  CHECK-SAME:   outs
 //  CHECK-NEXT:   (%{{.*}}: f32, %{{.*}}: f32) {
 //  CHECK-NEXT:     arith.addf
 //  CHECK-NEXT:     linalg.yield
@@ -426,10 +424,9 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
   func.return %reduce : tensor<16x64xf32>
 }
 // CHECK-LABEL: func @reduce
-//       CHECK:   linalg.reduce
-//  CHECK-NEXT:   ins
-//  CHECK-NEXT:   outs
-//  CHECK-NEXT:   dimensions = [1]
+//       CHECK:   linalg.reduce ins
+//  CHECK-SAME:   outs
+//  CHECK-SAME:   dimensions = [1]
 //  CHECK-NEXT:   (%{{.*}}: f32, %{{.*}}: f32) {
 //  CHECK-NEXT:     arith.addf
 //  CHECK-NEXT:     linalg.yield
@@ -501,10 +498,9 @@ func.func @transpose(%input: tensor<16x32x64xf32>,
   func.return %transpose : tensor<32x64x16xf32>
 }
 // CHECK-LABEL: func @transpose
-//      CHECK:    linalg.transpose
-// CHECK-NEXT:    ins
-// CHECK-NEXT:    outs
-// CHECK-NEXT:    permutation
+//      CHECK:    linalg.transpose ins
+// CHECK-SAME:    outs
+// CHECK-SAME:    permutation
 
 // -----
 
@@ -529,10 +525,9 @@ func.func @broadcast_static_sizes(%input: tensor<8x32xf32>,
   func.return %bcast : tensor<8x16x32xf32>
 }
 // CHECK-LABEL: func @broadcast_static_sizes
-//      CHECK:    linalg.broadcast
-// CHECK-NEXT:    ins
-// CHECK-NEXT:    outs
-// CHECK-NEXT:    dimensions
+//      CHECK:    linalg.broadcast ins
+// CHECK-SAME:    outs
+// CHECK-SAME:    dimensions
 
 // -----
 
@@ -546,10 +541,9 @@ func.func @broadcast_with_dynamic_sizes(
   func.return %bcast : tensor<8x16x?xf32>
 }
 // CHECK-LABEL: func @broadcast_with_dynamic_sizes
-//      CHECK:    linalg.broadcast
-// CHECK-NEXT:    ins
-// CHECK-NEXT:    outs
-// CHECK-NEXT:    dimensions
+//      CHECK:    linalg.broadcast ins
+// CHECK-SAME:    outs
+// CHECK-SAME:    dimensions
 
 // -----
 
@@ -563,7 +557,6 @@ func.func @broadcast_memref(%input: memref<8x32xf32>,
 }
 
 // CHECK-LABEL: func @broadcast_memref
-//      CHECK:    linalg.broadcast
-// CHECK-NEXT:    ins
-// CHECK-NEXT:    outs
-// CHECK-NEXT:    dimensions
+//      CHECK:    linalg.broadcast ins
+// CHECK-SAME:    outs
+// CHECK-SAME:    dimensions
