[Mlir-commits] [mlir] 350d686 - [mlir] Print bbArgs of linalg.map/reduce/transpose on the next line.
Alexander Belyaev
llvmlistbot at llvm.org
Thu Oct 27 01:19:26 PDT 2022
Author: Alexander Belyaev
Date: 2022-10-27T10:19:04+02:00
New Revision: 350d68644445f53551df1a4ddd69bd4f54f09fff
URL: https://github.com/llvm/llvm-project/commit/350d68644445f53551df1a4ddd69bd4f54f09fff
DIFF: https://github.com/llvm/llvm-project/commit/350d68644445f53551df1a4ddd69bd4f54f09fff.diff
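LOG: [mlir] Print bbArgs of linalg.map/reduce/transpose on the next line.
Each of `ins(...)`, `outs(...)`, the `dimensions`/`permutation` attribute, and the region's block arguments (bbArgs) is now printed on its own indented line, for example: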
```
%mapped = linalg.map
  ins(%arg0 : tensor<64xf32>)
  outs(%arg1 : tensor<64xf32>)
  (%in: f32) {
    %0 = math.absf %in : f32
    linalg.yield %0 : f32
  }

%reduced = linalg.reduce
  ins(%arg0 : tensor<16x32x64xf32>)
  outs(%arg1 : tensor<16x64xf32>)
  dimensions = [1]
  (%in: f32, %init: f32) {
    %0 = arith.addf %in, %init : f32
    linalg.yield %0 : f32
  }

%transposed = linalg.transpose
  ins(%arg0 : tensor<16x32x64xf32>)
  outs(%arg1 : tensor<32x64x16xf32>)
  permutation = [1, 2, 0]
```
Differential Revision: https://reviews.llvm.org/D136818
Added:
Modified:
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 474e9955bdcce..524c72d239cf6 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -334,6 +334,12 @@ class OpAsmPrinter : public AsmPrinter {
/// operation.
virtual void printNewline() = 0;
+ /// Increase indentation.
+ virtual void increaseIndent() = 0;
+
+ /// Decrease indentation.
+ virtual void decreaseIndent() = 0;
+
/// Print a block argument in the usual format of:
/// %ssaName : type {attr1=42} loc("here")
/// where location printing is controlled by the standard internal option.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5d6dd379b2e40..a7e1938e2418a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -173,6 +173,16 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}
+static void printCommonStructuredOpPartsWithNewLine(OpAsmPrinter &p,
+ ValueRange inputs,
+ ValueRange outputs) {
+ p.printNewline();
+ if (!inputs.empty())
+ p << "ins(" << inputs << " : " << inputs.getTypes() << ")";
+ p.printNewline();
+ if (!outputs.empty())
+ p << "outs(" << outputs << " : " << outputs.getTypes() << ")";
+}
//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//
@@ -1335,16 +1345,20 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
}
void MapOp::print(OpAsmPrinter &p) {
- printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
- SmallVector<Value>(getOutputOperands()));
+ p.increaseIndent();
+ printCommonStructuredOpPartsWithNewLine(
+ p, SmallVector<Value>(getInputOperands()),
+ SmallVector<Value>(getOutputOperands()));
p.printOptionalAttrDict((*this)->getAttrs());
+ p.printNewline();
p << "(";
llvm::interleaveComma(getMapper().getArguments(), p,
[&](auto arg) { p.printRegionArgument(arg); });
p << ") ";
p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
+ p.decreaseIndent();
}
LogicalResult MapOp::verify() {
@@ -1481,21 +1495,26 @@ ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
ArrayRef<int64_t> attributeValue) {
- p << " " << attributeName << " = [" << attributeValue << "] ";
+ p << attributeName << " = [" << attributeValue << "] ";
}
void ReduceOp::print(OpAsmPrinter &p) {
- printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
- SmallVector<Value>(getOutputOperands()));
+ p.increaseIndent();
+ printCommonStructuredOpPartsWithNewLine(
+ p, SmallVector<Value>(getInputOperands()),
+ SmallVector<Value>(getOutputOperands()));
+ p.printNewline();
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
+ p.printNewline();
p << "(";
llvm::interleaveComma(getCombiner().getArguments(), p,
[&](auto arg) { p.printRegionArgument(arg); });
p << ") ";
p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
+ p.decreaseIndent();
}
LogicalResult ReduceOp::verify() {
@@ -1657,10 +1676,14 @@ void TransposeOp::getAsmResultNames(
}
void TransposeOp::print(OpAsmPrinter &p) {
- printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
- SmallVector<Value>(getOutputOperands()));
+ p.increaseIndent();
+ printCommonStructuredOpPartsWithNewLine(
+ p, SmallVector<Value>(getInputOperands()),
+ SmallVector<Value>(getOutputOperands()));
+ p.printNewline();
printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
+ p.decreaseIndent();
}
LogicalResult TransposeOp::verify() {
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 676f8133a8771..9a3d3e031dc31 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -716,6 +716,8 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
void printNewline() override {}
+ void increaseIndent() override {}
+ void decreaseIndent() override {}
void printOperand(Value) override {}
void printOperand(Value, raw_ostream &os) override {
// Users expect the output string to have at least the prefixed % to signal
@@ -2768,6 +2770,12 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
os.indent(currentIndent);
}
+ /// Increase indentation.
+ void increaseIndent() override { currentIndent += indentWidth; }
+
+ /// Decrease indentation.
+ void decreaseIndent() override { currentIndent -= indentWidth; }
+
/// Print a block argument in the usual format of:
/// %ssaName : type {attr1=42} loc("here")
/// where location printing is controlled by the standard internal option.
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
index e71f566c307c8..58dec2be2373a 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -341,7 +341,7 @@ func.func @op_is_reading_but_following_ops_are_not(
func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
%init: tensor<64xf32>) -> tensor<64xf32> {
// CHECK: linalg.map
- // CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<64xf32
+ // CHECK-NEXT: ins(%[[LHS]], %[[RHS]] : memref<64xf32
%add = linalg.map
ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
@@ -359,7 +359,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
func.func @reduce(%input: tensor<16x32x64xf32>,
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
// CHECK: linalg.reduce
- // CHECK-SAME: ins(%[[INPUT]] : memref<16x32x64xf32
+ // CHECK-NEXT: ins(%[[INPUT]] : memref<16x32x64xf32
%reduce = linalg.reduce
ins(%input:tensor<16x32x64xf32>)
outs(%init:tensor<16x64xf32>)
@@ -378,7 +378,7 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
func.func @transpose(%input: tensor<16x32x64xf32>,
%init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
// CHECK: linalg.transpose
- // CHECK-SAME: ins(%[[ARG0]] : memref<16x32x64xf32
+ // CHECK-NEXT: ins(%[[ARG0]] : memref<16x32x64xf32
%transpose = linalg.transpose
ins(%input:tensor<16x32x64xf32>)
outs(%init:tensor<32x64x16xf32>)
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 4bea3f6d38376..6e1c26634c417 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -338,7 +338,13 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
func.return %add : tensor<64xf32>
}
// CHECK-LABEL: func @map_binary
-// CHECK: linalg.map
+// CHECK: linalg.map
+// CHECK-NEXT: ins
+// CHECK-NEXT: outs
+// CHECK-NEXT: (%{{.*}}: f32, %{{.*}}: f32) {
+// CHECK-NEXT: arith.addf
+// CHECK-NEXT: linalg.yield
+// CHECK-NEXT: }
// -----
@@ -401,7 +407,14 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
func.return %reduce : tensor<16x64xf32>
}
// CHECK-LABEL: func @reduce
-// CHECK: linalg.reduce
+// CHECK: linalg.reduce
+// CHECK-NEXT: ins
+// CHECK-NEXT: outs
+// CHECK-NEXT: dimensions = [1]
+// CHECK-NEXT: (%{{.*}}: f32, %{{.*}}: f32) {
+// CHECK-NEXT: arith.addf
+// CHECK-NEXT: linalg.yield
+// CHECK-NEXT: }
// -----
@@ -469,6 +482,10 @@ func.func @transpose(%input: tensor<16x32x64xf32>,
func.return %transpose : tensor<32x64x16xf32>
}
// CHECK-LABEL: func @transpose
+// CHECK: linalg.transpose
+// CHECK-NEXT: ins
+// CHECK-NEXT: outs
+// CHECK-NEXT: permutation
// -----
More information about the Mlir-commits mailing list