[Mlir-commits] [mlir] a6d6d40 - BEGIN_PUBLIC
Aliia Khasanova
llvmlistbot at llvm.org
Wed Dec 21 05:39:39 PST 2022
Author: Aliia Khasanova
Date: 2022-12-21T14:39:30+01:00
New Revision: a6d6d40d8bd062514fc379a6bf70fb1b7220be6f
URL: https://github.com/llvm/llvm-project/commit/a6d6d40d8bd062514fc379a6bf70fb1b7220be6f
DIFF: https://github.com/llvm/llvm-project/commit/a6d6d40d8bd062514fc379a6bf70fb1b7220be6f.diff
LOG: BEGIN_PUBLIC
Add a shortened printing/parsing form for linalg.map and linalg.reduce.
END_PUBLIC
Differential Revision: https://reviews.llvm.org/D140406
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8b0540e10d01b..56c7d844feda5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -174,16 +174,6 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}
-static void printCommonStructuredOpPartsWithNewLine(OpAsmPrinter &p,
- ValueRange inputs,
- ValueRange outputs) {
- if (!inputs.empty()) {
- p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
- }
- if (!outputs.empty()) {
- p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
- }
-}
//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//
@@ -1021,38 +1011,119 @@ void MapOp::build(
inputs, /*outputs=*/{}, bodyBuild);
}
+static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
+ const OperationName &payloadOpName,
+ const NamedAttrList &payloadOpAttrs,
+ ArrayRef<Value> operands) {
+ OpBuilder b(parser.getContext());
+ Region *body = result.addRegion();
+ Block &block = body->emplaceBlock();
+ b.setInsertionPointToStart(&block);
+ SmallVector<Value> bbArgs;
+ for (auto &operand : operands) {
+ block.addArgument(operand.getType().cast<ShapedType>().getElementType(),
+ b.getUnknownLoc());
+ }
+
+ Operation *payloadOp = b.create(
+ result.location, b.getStringAttr(payloadOpName.getStringRef()),
+ block.getArguments(),
+ TypeRange{
+ result.operands.back().getType().cast<ShapedType>().getElementType()},
+ payloadOpAttrs);
+
+ b.create<YieldOp>(result.location, payloadOp->getResults());
+}
+
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
+ std::optional<OperationName> payloadOpName;
+ NamedAttrList payloadOpAttrs;
+ if (succeeded(parser.parseOptionalLBrace())) {
+ FailureOr<OperationName> operationName = parser.parseCustomOperationName();
+ if (failed(operationName))
+ return failure();
+ if (parser.parseOptionalAttrDict(payloadOpAttrs))
+ return failure();
+ payloadOpName = operationName.value();
+ if (parser.parseRBrace())
+ return failure();
+ }
+
if (parseDstStyleOp(parser, result))
return failure();
- SmallVector<OpAsmParser::Argument> regionArgs;
- if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
- /*allowType=*/true, /*allowAttrs=*/true)) {
- return failure();
+ if (payloadOpName.has_value()) {
+ addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
+ makeArrayRef(result.operands).drop_back());
+ } else {
+ SmallVector<OpAsmParser::Argument> regionArgs;
+ if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
+ /*allowType=*/true, /*allowAttrs=*/true)) {
+ return failure();
+ }
+ Region *body = result.addRegion();
+ if (parser.parseRegion(*body, regionArgs))
+ return failure();
}
+ return success();
+}
- Region *body = result.addRegion();
- if (parser.parseRegion(*body, regionArgs))
- return failure();
+// Retrieve the operation from the body, if it is the only one (except
+// yield) and if it gets the same number of arguments as the body does.
+static Operation *findPayloadOp(Block *body) {
+ if (body->getOperations().size() != 2)
+ return nullptr;
+ Operation &payload = body->getOperations().front();
+ assert(isa<YieldOp>(body->getOperations().back()));
+
+ if (payload.getNumOperands() == 0 ||
+ payload.getNumOperands() != body->getNumArguments())
+ return nullptr;
+ for (const auto &[bbArg, operand] :
+ llvm::zip(payload.getOperands(), body->getArguments())) {
+ if (bbArg != operand)
+ return nullptr;
+ }
+ return &payload;
+}
- return success();
+void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
+ SmallVector<StringRef> elidedAttrs;
+ p << " { " << payloadOp->getName().getStringRef();
+ for (const auto &attr : payloadOp->getAttrs()) {
+ auto fastAttr = attr.getValue().dyn_cast<mlir::arith::FastMathFlagsAttr>();
+ if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
+ elidedAttrs.push_back(attr.getName().str());
+ break;
+ }
+ }
+ p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
+ p << " }";
}
void MapOp::print(OpAsmPrinter &p) {
- printCommonStructuredOpPartsWithNewLine(
- p, SmallVector<Value>(getDpsInputOperands()),
- SmallVector<Value>(getDpsInitOperands()));
- p.printOptionalAttrDict((*this)->getAttrs());
+ Block *mapper = getBody();
+ Operation *payloadOp = findPayloadOp(mapper);
+ if (payloadOp) {
+ printShortForm(p, payloadOp);
+ }
- p.increaseIndent();
- p.printNewline();
- p << "(";
- llvm::interleaveComma(getMapper().getArguments(), p,
- [&](auto arg) { p.printRegionArgument(arg); });
- p << ") ";
+ printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
+ SmallVector<Value>(getDpsInitOperands()));
+ p.printOptionalAttrDict((*this)->getAttrs());
- p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
- p.decreaseIndent();
+ if (!payloadOp) {
+ // Print region if the payload op was not detected.
+ p.increaseIndent();
+ p.printNewline();
+ p << "(";
+ llvm::interleaveComma(mapper->getArguments(), p,
+ [&](auto arg) { p.printRegionArgument(arg); });
+ p << ") ";
+
+ p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
+ p.decreaseIndent();
+ }
}
LogicalResult MapOp::verify() {
@@ -1065,7 +1136,7 @@ LogicalResult MapOp::verify() {
"mapper, but got: "
<< getInputs().size() << " and " << blockArgs.size();
- // The parameters of mapper should all match the element type // of inputs.
+ // The parameters of mapper should all match the element type of inputs.
for (const auto &[bbArgType, inputArg] :
llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
auto inputElemType = inputArg.getType().cast<ShapedType>().getElementType();
@@ -1187,21 +1258,39 @@ static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
}
ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
+ std::optional<OperationName> payloadOpName;
+ NamedAttrList payloadOpAttrs;
+ if (succeeded(parser.parseOptionalLBrace())) {
+ FailureOr<OperationName> operationName = parser.parseCustomOperationName();
+ if (failed(operationName))
+ return failure();
+ if (parser.parseOptionalAttrDict(payloadOpAttrs))
+ return failure();
+ payloadOpName = operationName.value();
+ if (parser.parseRBrace())
+ return failure();
+ }
+
if (parseDstStyleOp(
parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
}))
return failure();
- SmallVector<OpAsmParser::Argument> regionArgs;
- if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
- /*allowType=*/true, /*allowAttrs=*/true)) {
- return failure();
- }
+ if (payloadOpName.has_value()) {
+ addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
+ makeArrayRef(result.operands));
+ } else {
+ SmallVector<OpAsmParser::Argument> regionArgs;
+ if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
+ /*allowType=*/true, /*allowAttrs=*/true)) {
+ return failure();
+ }
- Region *body = result.addRegion();
- if (parser.parseRegion(*body, regionArgs))
- return failure();
+ Region *body = result.addRegion();
+ if (parser.parseRegion(*body, regionArgs))
+ return failure();
+ }
return success();
}
@@ -1212,22 +1301,28 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
}
void ReduceOp::print(OpAsmPrinter &p) {
- printCommonStructuredOpPartsWithNewLine(
- p, SmallVector<Value>(getDpsInputOperands()),
- SmallVector<Value>(getDpsInitOperands()));
+ Block *mapper = getBody();
+ Operation *payloadOp = findPayloadOp(mapper);
+ if (payloadOp) {
+ printShortForm(p, payloadOp);
+ }
+ printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
+ SmallVector<Value>(getDpsInitOperands()));
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
-
- p.increaseIndent();
- p.printNewline();
- p << "(";
- llvm::interleaveComma(getCombiner().getArguments(), p,
- [&](auto arg) { p.printRegionArgument(arg); });
- p << ") ";
-
- p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
- p.decreaseIndent();
+ if (!payloadOp) {
+ // Print region if the payload op was not detected.
+ p.increaseIndent();
+ p.printNewline();
+ p << "(";
+ llvm::interleaveComma(mapper->getArguments(), p,
+ [&](auto arg) { p.printRegionArgument(arg); });
+ p << ") ";
+
+ p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
+ p.decreaseIndent();
+ }
}
LogicalResult ReduceOp::verify() {
@@ -1376,9 +1471,8 @@ void TransposeOp::getAsmResultNames(
}
void TransposeOp::print(OpAsmPrinter &p) {
- printCommonStructuredOpPartsWithNewLine(
- p, SmallVector<Value>(getDpsInputOperands()),
- SmallVector<Value>(getDpsInitOperands()));
+ printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
+ SmallVector<Value>(getDpsInitOperands()));
printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
}
@@ -1491,9 +1585,8 @@ void BroadcastOp::getAsmResultNames(
}
void BroadcastOp::print(OpAsmPrinter &p) {
- printCommonStructuredOpPartsWithNewLine(
- p, SmallVector<Value>(getDpsInputOperands()),
- SmallVector<Value>(getDpsInitOperands()));
+ printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
+ SmallVector<Value>(getDpsInitOperands()));
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
}
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
index d418a92775cfd..87763c9b81014 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -340,7 +340,7 @@ func.func @op_is_reading_but_following_ops_are_not(
// CHECK-SAME: %[[RHS:[0-9a-zA-Z]*]]: memref<64xf32
func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
%init: tensor<64xf32>) -> tensor<64xf32> {
- // CHECK: linalg.map ins(%[[LHS]], %[[RHS]] : memref<64xf32
+ // CHECK: linalg.map { arith.addf } ins(%[[LHS]], %[[RHS]] : memref<64xf32
%add = linalg.map
ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
@@ -357,7 +357,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
// CHECK-SAME: %[[INPUT:.*]]: memref<16x32x64xf32
func.func @reduce(%input: tensor<16x32x64xf32>,
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
- // CHECK: linalg.reduce ins(%[[INPUT]] : memref<16x32x64xf32
+ // CHECK: linalg.reduce { arith.addf } ins(%[[INPUT]] : memref<16x32x64xf32
%reduce = linalg.reduce
ins(%input:tensor<16x32x64xf32>)
outs(%init:tensor<16x64xf32>)
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index b1a614fb768a1..611d428506faf 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -356,12 +356,8 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
func.return %add : tensor<64xf32>
}
// CHECK-LABEL: func @map_binary
-// CHECK: linalg.map ins
+// CHECK: linalg.map { arith.addf } ins
// CHECK-SAME: outs
-// CHECK-NEXT: (%{{.*}}: f32, %{{.*}}: f32) {
-// CHECK-NEXT: arith.addf
-// CHECK-NEXT: linalg.yield
-// CHECK-NEXT: }
// -----
@@ -424,13 +420,9 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
func.return %reduce : tensor<16x64xf32>
}
// CHECK-LABEL: func @reduce
-// CHECK: linalg.reduce ins
+// CHECK: linalg.reduce { arith.addf } ins
// CHECK-SAME: outs
// CHECK-SAME: dimensions = [1]
-// CHECK-NEXT: (%{{.*}}: f32, %{{.*}}: f32) {
-// CHECK-NEXT: arith.addf
-// CHECK-NEXT: linalg.yield
-// CHECK-NEXT: }
// -----
@@ -446,8 +438,10 @@ func.func @reduce_memref(%input: memref<16x32x64xf32>,
}
func.return
}
-// CHECK-LABEL: func @reduce_memref
-// CHECK: linalg.reduce
+// CHECK-LABEL: func @reduce_memref
+// CHECK: linalg.reduce { arith.addf } ins
+// CHECK-SAME: outs
+// CHECK-SAME: dimensions = [1]
// -----
@@ -467,6 +461,7 @@ func.func @variadic_reduce(%input1: tensor<16x32x64xf32>,
}
// CHECK-LABEL: func @variadic_reduce
// CHECK: linalg.reduce
+// CHECK-NOT: { arith.addf
// -----
@@ -484,8 +479,9 @@ func.func @variadic_reduce_memref(%input1: memref<16x32x64xf32>,
}
func.return
}
-// CHECK-LABEL: func @variadic_reduce_memref
+// CHECK-LABEL: func @variadic_reduce_memref
// CHECK: linalg.reduce
+// CHECK-NOT: { arith.addf
// -----
@@ -560,3 +556,46 @@ func.func @broadcast_memref(%input: memref<8x32xf32>,
// CHECK: linalg.broadcast ins
// CHECK-SAME: outs
// CHECK-SAME: dimensions
+
+// -----
+
+func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
+ %init: tensor<64xf32>) -> tensor<64xf32> {
+ %add = linalg.map
+ ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
+ outs(%init:tensor<64xf32>)
+ (%lhs_elem: f32, %rhs_elem: f32) {
+ %0 = arith.addf %lhs_elem, %rhs_elem fastmath<fast> : f32
+ linalg.yield %0: f32
+ }
+ func.return %add : tensor<64xf32>
+}
+
+// CHECK-LABEL: func @map_arith_with_attr
+// CHECK-NEXT: %[[MAPPED:.*]] = linalg.map
+// CHECK-SAME: { arith.addf {fastmath = #arith.fastmath<fast>} }
+// CHECK-SAME: ins
+// CHECK-SAME: outs
+// CHECK-NEXT: return %[[MAPPED]] : tensor<64xf32>
+
+// -----
+
+func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>,
+ %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
+ %reduce = linalg.reduce
+ ins(%input:tensor<16x32x64xf32>)
+ outs(%init:tensor<16x64xf32>)
+ dimensions = [1]
+ (%in: f32, %out: f32) {
+ %0 = arith.addf %in, %out fastmath<fast> : f32
+ linalg.yield %0: f32
+ }
+ func.return %reduce : tensor<16x64xf32>
+}
+// CHECK-LABEL: func @reduce_arith_with_attr
+// CHECK-NEXT: %[[REDUCED:.*]] = linalg.reduce
+// CHECK-SAME: { arith.addf {fastmath = #arith.fastmath<fast>} }
+// CHECK-SAME: ins
+// CHECK-SAME: outs
+// CHECK-SAME: dimensions = [1]
+// CHECK-NEXT: return %[[REDUCED]] : tensor<16x64xf32>
More information about the Mlir-commits
mailing list