[Mlir-commits] [mlir] c0b775a - Revert "BEGIN_PUBLIC"
Mitch Phillips
llvmlistbot at llvm.org
Wed Dec 21 09:34:42 PST 2022
Author: Mitch Phillips
Date: 2022-12-21T09:32:54-08:00
New Revision: c0b775a5b506408bcdd9ffe31a51400a99734f2c
URL: https://github.com/llvm/llvm-project/commit/c0b775a5b506408bcdd9ffe31a51400a99734f2c
DIFF: https://github.com/llvm/llvm-project/commit/c0b775a5b506408bcdd9ffe31a51400a99734f2c.diff
LOG: Revert "BEGIN_PUBLIC"
This reverts commit a6d6d40d8bd062514fc379a6bf70fb1b7220be6f.
Reason: Broke the ASan/MSan bots. More information in Phabricator:
https://reviews.llvm.org/D140406
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 56c7d844feda5..8b0540e10d01b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -174,6 +174,16 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}
+static void printCommonStructuredOpPartsWithNewLine(OpAsmPrinter &p,
+ ValueRange inputs,
+ ValueRange outputs) {
+ if (!inputs.empty()) {
+ p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
+ }
+ if (!outputs.empty()) {
+ p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
+ }
+}
//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//
@@ -1011,119 +1021,38 @@ void MapOp::build(
inputs, /*outputs=*/{}, bodyBuild);
}
-static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
- const OperationName &payloadOpName,
- const NamedAttrList &payloadOpAttrs,
- ArrayRef<Value> operands) {
- OpBuilder b(parser.getContext());
- Region *body = result.addRegion();
- Block &block = body->emplaceBlock();
- b.setInsertionPointToStart(&block);
- SmallVector<Value> bbArgs;
- for (auto &operand : operands) {
- block.addArgument(operand.getType().cast<ShapedType>().getElementType(),
- b.getUnknownLoc());
- }
-
- Operation *payloadOp = b.create(
- result.location, b.getStringAttr(payloadOpName.getStringRef()),
- block.getArguments(),
- TypeRange{
- result.operands.back().getType().cast<ShapedType>().getElementType()},
- payloadOpAttrs);
-
- b.create<YieldOp>(result.location, payloadOp->getResults());
-}
-
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
- std::optional<OperationName> payloadOpName;
- NamedAttrList payloadOpAttrs;
- if (succeeded(parser.parseOptionalLBrace())) {
- FailureOr<OperationName> operationName = parser.parseCustomOperationName();
- if (failed(operationName))
- return failure();
- if (parser.parseOptionalAttrDict(payloadOpAttrs))
- return failure();
- payloadOpName = operationName.value();
- if (parser.parseRBrace())
- return failure();
- }
-
if (parseDstStyleOp(parser, result))
return failure();
- if (payloadOpName.has_value()) {
- addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
- makeArrayRef(result.operands).drop_back());
- } else {
- SmallVector<OpAsmParser::Argument> regionArgs;
- if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
- /*allowType=*/true, /*allowAttrs=*/true)) {
- return failure();
- }
- Region *body = result.addRegion();
- if (parser.parseRegion(*body, regionArgs))
- return failure();
+ SmallVector<OpAsmParser::Argument> regionArgs;
+ if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
+ /*allowType=*/true, /*allowAttrs=*/true)) {
+ return failure();
}
- return success();
-}
-// Retrieve the operation from the body, if it is the only one (except
-// yield) and if it gets the same amount of arguments as the body does.
-static Operation *findPayloadOp(Block *body) {
- if (body->getOperations().size() != 2)
- return nullptr;
- Operation &payload = body->getOperations().front();
- assert(isa<YieldOp>(body->getOperations().back()));
-
- if (payload.getNumOperands() == 0 ||
- payload.getNumOperands() != body->getNumArguments())
- return nullptr;
- for (const auto &[bbArg, operand] :
- llvm::zip(payload.getOperands(), body->getArguments())) {
- if (bbArg != operand)
- return nullptr;
- }
- return &payload;
-}
+ Region *body = result.addRegion();
+ if (parser.parseRegion(*body, regionArgs))
+ return failure();
-void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
- SmallVector<StringRef> elidedAttrs;
- p << " { " << payloadOp->getName().getStringRef();
- for (const auto &attr : payloadOp->getAttrs()) {
- auto fastAttr = attr.getValue().dyn_cast<mlir::arith::FastMathFlagsAttr>();
- if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
- elidedAttrs.push_back(attr.getName().str());
- break;
- }
- }
- p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
- p << " }";
+ return success();
}
void MapOp::print(OpAsmPrinter &p) {
- Block *mapper = getBody();
- Operation *payloadOp = findPayloadOp(mapper);
- if (payloadOp) {
- printShortForm(p, payloadOp);
- }
-
- printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
- SmallVector<Value>(getDpsInitOperands()));
+ printCommonStructuredOpPartsWithNewLine(
+ p, SmallVector<Value>(getDpsInputOperands()),
+ SmallVector<Value>(getDpsInitOperands()));
p.printOptionalAttrDict((*this)->getAttrs());
- if (!payloadOp) {
- // Print region if the payload op was not detected.
- p.increaseIndent();
- p.printNewline();
- p << "(";
- llvm::interleaveComma(mapper->getArguments(), p,
- [&](auto arg) { p.printRegionArgument(arg); });
- p << ") ";
-
- p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
- p.decreaseIndent();
- }
+ p.increaseIndent();
+ p.printNewline();
+ p << "(";
+ llvm::interleaveComma(getMapper().getArguments(), p,
+ [&](auto arg) { p.printRegionArgument(arg); });
+ p << ") ";
+
+ p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
+ p.decreaseIndent();
}
LogicalResult MapOp::verify() {
@@ -1136,7 +1065,7 @@ LogicalResult MapOp::verify() {
"mapper, but got: "
<< getInputs().size() << " and " << blockArgs.size();
- // The parameters of mapper should all match the element type of inputs.
+ // The parameters of mapper should all match the element type // of inputs.
for (const auto &[bbArgType, inputArg] :
llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
auto inputElemType = inputArg.getType().cast<ShapedType>().getElementType();
@@ -1258,40 +1187,22 @@ static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
}
ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
- std::optional<OperationName> payloadOpName;
- NamedAttrList payloadOpAttrs;
- if (succeeded(parser.parseOptionalLBrace())) {
- FailureOr<OperationName> operationName = parser.parseCustomOperationName();
- if (failed(operationName))
- return failure();
- if (parser.parseOptionalAttrDict(payloadOpAttrs))
- return failure();
- payloadOpName = operationName.value();
- if (parser.parseRBrace())
- return failure();
- }
-
if (parseDstStyleOp(
parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
}))
return failure();
- if (payloadOpName.has_value()) {
- addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
- makeArrayRef(result.operands));
- } else {
- SmallVector<OpAsmParser::Argument> regionArgs;
- if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
- /*allowType=*/true, /*allowAttrs=*/true)) {
- return failure();
- }
-
- Region *body = result.addRegion();
- if (parser.parseRegion(*body, regionArgs))
- return failure();
+ SmallVector<OpAsmParser::Argument> regionArgs;
+ if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
+ /*allowType=*/true, /*allowAttrs=*/true)) {
+ return failure();
}
+ Region *body = result.addRegion();
+ if (parser.parseRegion(*body, regionArgs))
+ return failure();
+
return success();
}
@@ -1301,28 +1212,22 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
}
void ReduceOp::print(OpAsmPrinter &p) {
- Block *mapper = getBody();
- Operation *payloadOp = findPayloadOp(mapper);
- if (payloadOp) {
- printShortForm(p, payloadOp);
- }
+ printCommonStructuredOpPartsWithNewLine(
+ p, SmallVector<Value>(getDpsInputOperands()),
+ SmallVector<Value>(getDpsInitOperands()));
- printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
- SmallVector<Value>(getDpsInitOperands()));
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
- if (!payloadOp) {
- // Print region if the payload op was not detected.
- p.increaseIndent();
- p.printNewline();
- p << "(";
- llvm::interleaveComma(mapper->getArguments(), p,
- [&](auto arg) { p.printRegionArgument(arg); });
- p << ") ";
-
- p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
- p.decreaseIndent();
- }
+
+ p.increaseIndent();
+ p.printNewline();
+ p << "(";
+ llvm::interleaveComma(getCombiner().getArguments(), p,
+ [&](auto arg) { p.printRegionArgument(arg); });
+ p << ") ";
+
+ p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
+ p.decreaseIndent();
}
LogicalResult ReduceOp::verify() {
@@ -1471,8 +1376,9 @@ void TransposeOp::getAsmResultNames(
}
void TransposeOp::print(OpAsmPrinter &p) {
- printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
- SmallVector<Value>(getDpsInitOperands()));
+ printCommonStructuredOpPartsWithNewLine(
+ p, SmallVector<Value>(getDpsInputOperands()),
+ SmallVector<Value>(getDpsInitOperands()));
printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
}
@@ -1585,8 +1491,9 @@ void BroadcastOp::getAsmResultNames(
}
void BroadcastOp::print(OpAsmPrinter &p) {
- printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
- SmallVector<Value>(getDpsInitOperands()));
+ printCommonStructuredOpPartsWithNewLine(
+ p, SmallVector<Value>(getDpsInputOperands()),
+ SmallVector<Value>(getDpsInitOperands()));
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
}
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
index 87763c9b81014..d418a92775cfd 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -340,7 +340,7 @@ func.func @op_is_reading_but_following_ops_are_not(
// CHECK-SAME: %[[RHS:[0-9a-zA-Z]*]]: memref<64xf32
func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
%init: tensor<64xf32>) -> tensor<64xf32> {
- // CHECK: linalg.map { arith.addf } ins(%[[LHS]], %[[RHS]] : memref<64xf32
+ // CHECK: linalg.map ins(%[[LHS]], %[[RHS]] : memref<64xf32
%add = linalg.map
ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
@@ -357,7 +357,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
// CHECK-SAME: %[[INPUT:.*]]: memref<16x32x64xf32
func.func @reduce(%input: tensor<16x32x64xf32>,
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
- // CHECK: linalg.reduce { arith.addf } ins(%[[INPUT]] : memref<16x32x64xf32
+ // CHECK: linalg.reduce ins(%[[INPUT]] : memref<16x32x64xf32
%reduce = linalg.reduce
ins(%input:tensor<16x32x64xf32>)
outs(%init:tensor<16x64xf32>)
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 611d428506faf..b1a614fb768a1 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -356,8 +356,12 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
func.return %add : tensor<64xf32>
}
// CHECK-LABEL: func @map_binary
-// CHECK: linalg.map { arith.addf } ins
+// CHECK: linalg.map ins
// CHECK-SAME: outs
+// CHECK-NEXT: (%{{.*}}: f32, %{{.*}}: f32) {
+// CHECK-NEXT: arith.addf
+// CHECK-NEXT: linalg.yield
+// CHECK-NEXT: }
// -----
@@ -420,9 +424,13 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
func.return %reduce : tensor<16x64xf32>
}
// CHECK-LABEL: func @reduce
-// CHECK: linalg.reduce { arith.addf } ins
+// CHECK: linalg.reduce ins
// CHECK-SAME: outs
// CHECK-SAME: dimensions = [1]
+// CHECK-NEXT: (%{{.*}}: f32, %{{.*}}: f32) {
+// CHECK-NEXT: arith.addf
+// CHECK-NEXT: linalg.yield
+// CHECK-NEXT: }
// -----
@@ -438,10 +446,8 @@ func.func @reduce_memref(%input: memref<16x32x64xf32>,
}
func.return
}
-// CHECK-LABEL: func @reduce
-// CHECK: linalg.reduce { arith.addf } ins
-// CHECK-SAME: outs
-// CHECK-SAME: dimensions = [1]
+// CHECK-LABEL: func @reduce_memref
+// CHECK: linalg.reduce
// -----
@@ -461,7 +467,6 @@ func.func @variadic_reduce(%input1: tensor<16x32x64xf32>,
}
// CHECK-LABEL: func @variadic_reduce
// CHECK: linalg.reduce
-// CHECK-NOT: { arith.addf
// -----
@@ -479,9 +484,8 @@ func.func @variadic_reduce_memref(%input1: memref<16x32x64xf32>,
}
func.return
}
-// CHECK-LABEL: func @variadic_reduce_memref
+// CHECK-LABEL: func @variadic_reduce_memref
// CHECK: linalg.reduce
-// CHECK-NOT: { arith.addf
// -----
@@ -556,46 +560,3 @@ func.func @broadcast_memref(%input: memref<8x32xf32>,
// CHECK: linalg.broadcast ins
// CHECK-SAME: outs
// CHECK-SAME: dimensions
-
-// -----
-
-func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
- %init: tensor<64xf32>) -> tensor<64xf32> {
- %add = linalg.map
- ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
- outs(%init:tensor<64xf32>)
- (%lhs_elem: f32, %rhs_elem: f32) {
- %0 = arith.addf %lhs_elem, %rhs_elem fastmath<fast> : f32
- linalg.yield %0: f32
- }
- func.return %add : tensor<64xf32>
-}
-
-// CHECK-LABEL: func @map_arith_with_attr
-// CHECK-NEXT: %[[MAPPED:.*]] = linalg.map
-// CHECK-SAME: { arith.addf {fastmath = #arith.fastmath<fast>} }
-// CHECK-SAME: ins
-// CHECK-SAME: outs
-// CHECK-NEXT: return %[[MAPPED]] : tensor<64xf32>
-
-// -----
-
-func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>,
- %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
- %reduce = linalg.reduce
- ins(%input:tensor<16x32x64xf32>)
- outs(%init:tensor<16x64xf32>)
- dimensions = [1]
- (%in: f32, %out: f32) {
- %0 = arith.addf %in, %out fastmath<fast> : f32
- linalg.yield %0: f32
- }
- func.return %reduce : tensor<16x64xf32>
-}
-// CHECK-LABEL: func @reduce_arith_with_attr
-// CHECK-NEXT: %[[REDUCED:.*]] = linalg.reduce
-// CHECK-SAME: { arith.addf {fastmath = #arith.fastmath<fast>} }
-// CHECK-SAME: ins
-// CHECK-SAME: outs
-// CHECK-SAME: dimensions = [1]
-// CHECK-NEXT: return %[[REDUCED]] : tensor<16x64xf32>
More information about the Mlir-commits mailing list