[Mlir-commits] [mlir] ef545ef - [mlir][linalg] Reuploading: Apply shortened printing/parsing form to linalg.reduce.
Aliia Khasanova
llvmlistbot at llvm.org
Mon Jan 9 04:33:27 PST 2023
Author: Aliia Khasanova
Date: 2023-01-09T13:32:29+01:00
New Revision: ef545ef62a833152d8975ff16333b57cc41befcc
URL: https://github.com/llvm/llvm-project/commit/ef545ef62a833152d8975ff16333b57cc41befcc
DIFF: https://github.com/llvm/llvm-project/commit/ef545ef62a833152d8975ff16333b57cc41befcc.diff
LOG: [mlir][linalg] Reuploading: Apply shortened printing/parsing form to linalg.reduce.
Differential Revision: https://reviews.llvm.org/D141259
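For readers skimming the patch: the following sketch (adapted from the examples added to the op documentation below; exact printed whitespace may differ) shows the full region form of linalg.reduce next to the new shortened form.

```
// Full form: the reduction body is spelled out explicitly. Note that the
// payload op takes %out as its first operand.
%reduce = linalg.reduce
    ins(%input : tensor<16x32x64xf32>)
    outs(%init : tensor<16x64xf32>)
    dimensions = [1]
    (%in: f32, %out: f32) {
      %0 = arith.addf %out, %in : f32
      linalg.yield %0 : f32
    }

// Shortened form: the single payload op is printed inline and the body is
// reconstructed by the parser.
%reduce = linalg.reduce { arith.addf }
    ins(%input : tensor<16x32x64xf32>)
    outs(%init : tensor<16x64xf32>)
    dimensions = [1]
```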
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 5456ca1301c6..dd2a943184a3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -255,6 +255,16 @@ def MapOp : LinalgStructuredBase_Op<"map", [
linalg.yield %0: f32
}
```
+
+ A shortened print form is available for simple maps, i.e. maps with a
+ single non-yield operation inside the body.
+
+ The example above will be printed as:
+ ```
+ %add = linalg.map { arith.addf }
+ ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+ outs(%init: tensor<64xf32>)
+ ```
}];
let arguments = (ins
@@ -329,10 +339,22 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
outs(%init:tensor<16x64xf32>)
dimensions = [1]
(%in: f32, %out: f32) {
- %0 = arith.addf %in, %out: f32
+ %0 = arith.addf %out, %in: f32
linalg.yield %0: f32
}
```
+
+ A shortened print form is available for simple (non-variadic) reduces with a
+ single non-yield operation inside the body, and only when that operation
+ takes `%out` as its first argument.
+
+ The example above will be printed as:
+ ```
+ %reduce = linalg.reduce { arith.addf }
+ ins(%input:tensor<16x32x64xf32>)
+ outs(%init:tensor<16x64xf32>)
+ dimensions = [1]
+ ```
}];
let arguments = (ins
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 48e7cfd64319..33f49c9febd8 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1046,7 +1046,8 @@ void MapOp::build(
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
const OperationName &payloadOpName,
const NamedAttrList &payloadOpAttrs,
- ArrayRef<Value> operands) {
+ ArrayRef<Value> operands,
+ bool initFirst = false) {
OpBuilder b(parser.getContext());
Region *body = result.addRegion();
Block &block = body->emplaceBlock();
@@ -1056,14 +1057,24 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
block.addArgument(operand.getType().cast<ShapedType>().getElementType(),
b.getUnknownLoc());
}
+ SmallVector<Value> payloadOpOperands;
+ // If the initFirst flag is set, the init block argument becomes the first
+ // payload operand.
+ if (initFirst) {
+ payloadOpOperands.push_back(block.getArguments().back());
+ for (const auto &arg : block.getArguments().drop_back())
+ payloadOpOperands.push_back(arg);
+ } else {
+ payloadOpOperands = {block.getArguments().begin(),
+ block.getArguments().end()};
+ }
Operation *payloadOp = b.create(
result.location, b.getStringAttr(payloadOpName.getStringRef()),
- block.getArguments(),
+ payloadOpOperands,
TypeRange{
result.operands.back().getType().cast<ShapedType>().getElementType()},
payloadOpAttrs);
-
b.create<YieldOp>(result.location, payloadOp->getResults());
}
@@ -1102,7 +1113,9 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
// Retrieve the operation from the body, if it is the only one (except
// yield) and if it gets the same amount of arguments as the body does.
-static Operation *findPayloadOp(Block *body) {
+// If the initFirst flag is set, additionally require that init is the first
+// operand of the payload op.
+static Operation *findPayloadOp(Block *body, bool initFirst = false) {
if (body->getOperations().size() != 2)
return nullptr;
Operation &payload = body->getOperations().front();
@@ -1111,10 +1124,22 @@ static Operation *findPayloadOp(Block *body) {
if (payload.getNumOperands() == 0 ||
payload.getNumOperands() != body->getNumArguments())
return nullptr;
- for (const auto &[bbArg, operand] :
- llvm::zip(payload.getOperands(), body->getArguments())) {
- if (bbArg != operand)
+ if (initFirst) {
+ // check init
+ if (payload.getOperands().back() != body->getArgument(0))
return nullptr;
+ // check rest
+ for (const auto &[operand, bbArg] :
+ llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
+ if (bbArg != operand)
+ return nullptr;
+ }
+ } else {
+ for (const auto &[operand, bbArg] :
+ llvm::zip(payload.getOperands(), body->getArguments())) {
+ if (bbArg != operand)
+ return nullptr;
+ }
}
return &payload;
}
@@ -1313,7 +1338,7 @@ ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
if (payloadOpName.has_value()) {
addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
- makeArrayRef(result.operands));
+ makeArrayRef(result.operands), /*initFirst=*/true);
} else {
SmallVector<OpAsmParser::Argument> regionArgs;
if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
@@ -1336,7 +1361,7 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
void ReduceOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
- Operation *payloadOp = findPayloadOp(mapper);
+ Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
if (payloadOp) {
printShortForm(p, payloadOp);
}
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
index 87763c9b8101..7795b633c4b4 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -363,7 +363,7 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
outs(%init:tensor<16x64xf32>)
dimensions = [1]
(%in: f32, %out: f32) {
- %0 = arith.addf %in, %out: f32
+ %0 = arith.addf %out, %in: f32
linalg.yield %0: f32
}
func.return %reduce : tensor<16x64xf32>
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 611d428506fa..c665366277da 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -414,7 +414,7 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
outs(%init:tensor<16x64xf32>)
dimensions = [1]
(%in: f32, %out: f32) {
- %0 = arith.addf %in, %out: f32
+ %0 = arith.addf %out, %in: f32
linalg.yield %0: f32
}
func.return %reduce : tensor<16x64xf32>
@@ -433,7 +433,7 @@ func.func @reduce_memref(%input: memref<16x32x64xf32>,
outs(%init:memref<16x64xf32>)
dimensions = [1]
(%in: f32, %out: f32) {
- %0 = arith.addf %in, %out: f32
+ %0 = arith.addf %out, %in: f32
linalg.yield %0: f32
}
func.return
@@ -587,7 +587,7 @@ func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>,
outs(%init:tensor<16x64xf32>)
dimensions = [1]
(%in: f32, %out: f32) {
- %0 = arith.addf %in, %out fastmath<fast> : f32
+ %0 = arith.addf %out, %in fastmath<fast> : f32
linalg.yield %0: f32
}
func.return %reduce : tensor<16x64xf32>
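A note on the test updates above: the short form only applies when the payload op takes `%out` as its first operand, which is why the reduce bodies were changed from `arith.addf %in, %out` to `arith.addf %out, %in`. A hypothetical body written in the opposite order should still parse and verify, but would presumably keep round-tripping through the full region syntax:

```
// Hypothetical example (not from the patch): %in comes first, so this reduce
// does not match the shortened form and is expected to print with the full
// region syntax.
%reduce = linalg.reduce
    ins(%input : tensor<16x32x64xf32>)
    outs(%init : tensor<16x64xf32>)
    dimensions = [1]
    (%in: f32, %out: f32) {
      %0 = arith.addf %in, %out : f32
      linalg.yield %0 : f32
    }
```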