[Mlir-commits] [mlir] aee2c23 - [mlir][linalg] Reuploading: add a shortened printing/parsing form for linalg.map and linalg.reduce.

Aliia Khasanova llvmlistbot at llvm.org
Thu Dec 22 07:04:18 PST 2022


Author: Aliia Khasanova
Date: 2022-12-22T16:03:35+01:00
New Revision: aee2c23066647095ce41c9dae70ad00c8527bef6

URL: https://github.com/llvm/llvm-project/commit/aee2c23066647095ce41c9dae70ad00c8527bef6
DIFF: https://github.com/llvm/llvm-project/commit/aee2c23066647095ce41c9dae70ad00c8527bef6.diff

LOG: [mlir][linalg] Reuploading: add a shortened printing/parsing form for linalg.map and linalg.reduce.

Differential Revision: https://reviews.llvm.org/D140535

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9587a5a32b0e4..b74537e95de2d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -174,16 +174,6 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
     p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
 }
 
-static void printCommonStructuredOpPartsWithNewLine(OpAsmPrinter &p,
-                                                    ValueRange inputs,
-                                                    ValueRange outputs) {
-  if (!inputs.empty()) {
-    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
-  }
-  if (!outputs.empty()) {
-    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
-  }
-}
 //===----------------------------------------------------------------------===//
 // Specific parsing and printing for named structured ops created by ods-gen.
 //===----------------------------------------------------------------------===//
@@ -1023,38 +1013,143 @@ void MapOp::build(
                        inputs, /*outputs=*/{}, bodyBuild);
 }
 
-ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
-  if (parseDstStyleOp(parser, result))
-    return failure();
+// Builds the body region for the short form `{ payload.op {attrs} }` used by
+// linalg.map and linalg.reduce: a block with one argument per value in
+// `operands` (typed with that value's element type), containing `payload.op`
+// applied to all block arguments in order, followed by a `linalg.yield` of
+// its results.
+static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
+                                 const OperationName &payloadOpName,
+                                 const NamedAttrList &payloadOpAttrs,
+                                 ArrayRef<Value> operands) {
+  OpBuilder b(parser.getContext());
+  Region *body = result.addRegion();
+  Block &block = body->emplaceBlock();
+  b.setInsertionPointToStart(&block);
+  for (Value operand : operands) {
+    block.addArgument(operand.getType().cast<ShapedType>().getElementType(),
+                      b.getUnknownLoc());
+  }
 
-  SmallVector<OpAsmParser::Argument> regionArgs;
-  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
-                               /*allowType=*/true, /*allowAttrs=*/true)) {
-    return failure();
+  // The payload has a single result whose type is the element type of the
+  // last operand parsed so far (presumably the `outs` init; confirm against
+  // parseDstStyleOp).
+  Operation *payloadOp = b.create(
+      result.location, b.getStringAttr(payloadOpName.getStringRef()),
+      block.getArguments(),
+      TypeRange{
+          result.operands.back().getType().cast<ShapedType>().getElementType()},
+      payloadOpAttrs);
+
+  b.create<YieldOp>(result.location, payloadOp->getResults());
+}
+
+ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
+  std::optional<OperationName> payloadOpName;
+  NamedAttrList payloadOpAttrs;
+  if (succeeded(parser.parseOptionalLBrace())) {
+    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
+    if (failed(operationName))
+      return failure();
+    if (parser.parseOptionalAttrDict(payloadOpAttrs))
+      return failure();
+    payloadOpName = operationName.value();
+    if (parser.parseRBrace())
+      return failure();
   }
 
-  Region *body = result.addRegion();
-  if (parser.parseRegion(*body, regionArgs))
+  if (parseDstStyleOp(parser, result))
     return failure();
 
+  if (payloadOpName.has_value()) {
+    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
+                         makeArrayRef(result.operands).drop_back());
+  } else {
+    SmallVector<OpAsmParser::Argument> regionArgs;
+    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
+                                 /*allowType=*/true, /*allowAttrs=*/true)) {
+      return failure();
+    }
+    Region *body = result.addRegion();
+    if (parser.parseRegion(*body, regionArgs))
+      return failure();
+  }
   return success();
 }
 
+// Returns the payload operation of `body` if the body can be printed in the
+// short form `{ payload.op {attrs} }`, i.e. if it has exactly the shape that
+// addBodyWithPayloadOp rebuilds when parsing:
+//   - the block holds exactly two ops: the payload and the linalg.yield;
+//   - the payload has no regions and exactly one result;
+//   - the yield returns exactly that result and nothing else;
+//   - the payload's operands are exactly the block arguments, in order.
+// Otherwise returns nullptr and the caller must print the full region, since
+// the short form would not roundtrip.
+static Operation *findPayloadOp(Block *body) {
+  if (body->getOperations().size() != 2)
+    return nullptr;
+  Operation &payload = body->getOperations().front();
+  auto yieldOp = cast<YieldOp>(body->getTerminator());
+
+  if (payload.getNumRegions() != 0 || payload.getNumResults() != 1)
+    return nullptr;
+  if (yieldOp->getNumOperands() != 1 ||
+      yieldOp->getOperand(0) != payload.getResult(0))
+    return nullptr;
+
+  if (payload.getNumOperands() == 0 ||
+      payload.getNumOperands() != body->getNumArguments())
+    return nullptr;
+  for (const auto &[operand, bbArg] :
+       llvm::zip(payload.getOperands(), body->getArguments())) {
+    if (operand != bbArg)
+      return nullptr;
+  }
+  return &payload;
+}
+
+// Prints ` { payload.op {attrs} }` for `payloadOp`, leaving out the first
+// arith::FastMathFlagsAttr attribute whose value is `none`.
+static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
+  SmallVector<StringRef> elidedAttrs;
+  p << " { " << payloadOp->getName().getStringRef();
+  for (const auto &attr : payloadOp->getAttrs()) {
+    auto fastAttr = attr.getValue().dyn_cast<mlir::arith::FastMathFlagsAttr>();
+    if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
+      // Attribute names are context-uniqued StringAttrs, so this StringRef
+      // stays valid without copying the name into a std::string.
+      elidedAttrs.push_back(attr.getName().getValue());
+      break;
+    }
+  }
+  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
+  p << " }";
+}
+
 void MapOp::print(OpAsmPrinter &p) {
-  printCommonStructuredOpPartsWithNewLine(
-      p, SmallVector<Value>(getDpsInputOperands()),
-      SmallVector<Value>(getDpsInitOperands()));
-  p.printOptionalAttrDict((*this)->getAttrs());
+  Block *mapper = getBody();
+  Operation *payloadOp = findPayloadOp(mapper);
+  if (payloadOp) {
+    printShortForm(p, payloadOp);
+  }
 
-  p.increaseIndent();
-  p.printNewline();
-  p << "(";
-  llvm::interleaveComma(getMapper().getArguments(), p,
-                        [&](auto arg) { p.printRegionArgument(arg); });
-  p << ") ";
+  printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
+                               SmallVector<Value>(getDpsInitOperands()));
+  p.printOptionalAttrDict((*this)->getAttrs());
 
-  p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
-  p.decreaseIndent();
+  if (!payloadOp) {
+    // Print region if the payload op was not detected.
+    p.increaseIndent();
+    p.printNewline();
+    p << "(";
+    llvm::interleaveComma(mapper->getArguments(), p,
+                          [&](auto arg) { p.printRegionArgument(arg); });
+    p << ") ";
+
+    p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
+    p.decreaseIndent();
+  }
 }
 
 LogicalResult MapOp::verify() {
@@ -1067,7 +1140,7 @@ LogicalResult MapOp::verify() {
                             "mapper, but got: "
                          << getInputs().size() << " and " << blockArgs.size();
 
-  // The parameters of mapper should all match the element type // of inputs.
+  // The parameters of mapper should all match the element type of inputs.
   for (const auto &[bbArgType, inputArg] :
        llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
     auto inputElemType = inputArg.getType().cast<ShapedType>().getElementType();
@@ -1189,21 +1262,39 @@ static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
 }
 
 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
+  std::optional<OperationName> payloadOpName;
+  NamedAttrList payloadOpAttrs;
+  if (succeeded(parser.parseOptionalLBrace())) {
+    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
+    if (failed(operationName))
+      return failure();
+    if (parser.parseOptionalAttrDict(payloadOpAttrs))
+      return failure();
+    payloadOpName = operationName.value();
+    if (parser.parseRBrace())
+      return failure();
+  }
+
   if (parseDstStyleOp(
           parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
             return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
           }))
     return failure();
 
-  SmallVector<OpAsmParser::Argument> regionArgs;
-  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
-                               /*allowType=*/true, /*allowAttrs=*/true)) {
-    return failure();
-  }
+  if (payloadOpName.has_value()) {
+    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
+                         makeArrayRef(result.operands));
+  } else {
+    SmallVector<OpAsmParser::Argument> regionArgs;
+    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
+                                 /*allowType=*/true, /*allowAttrs=*/true)) {
+      return failure();
+    }
 
-  Region *body = result.addRegion();
-  if (parser.parseRegion(*body, regionArgs))
-    return failure();
+    Region *body = result.addRegion();
+    if (parser.parseRegion(*body, regionArgs))
+      return failure();
+  }
 
   return success();
 }
@@ -1214,22 +1305,31 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
 }
 
 void ReduceOp::print(OpAsmPrinter &p) {
-  printCommonStructuredOpPartsWithNewLine(
-      p, SmallVector<Value>(getDpsInputOperands()),
-      SmallVector<Value>(getDpsInitOperands()));
+  // Use the short form ` { payload.op {attrs} }` when findPayloadOp accepts
+  // the combiner body; otherwise the full region is printed after the
+  // operands and attributes.
+  Block *combiner = getBody();
+  Operation *payloadOp = findPayloadOp(combiner);
+  if (payloadOp) {
+    printShortForm(p, payloadOp);
+  }
 
+  printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
+                               SmallVector<Value>(getDpsInitOperands()));
   printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
   p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
-
-  p.increaseIndent();
-  p.printNewline();
-  p << "(";
-  llvm::interleaveComma(getCombiner().getArguments(), p,
-                        [&](auto arg) { p.printRegionArgument(arg); });
-  p << ") ";
-
-  p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
-  p.decreaseIndent();
+  if (!payloadOp) {
+    // Print region if the payload op was not detected.
+    p.increaseIndent();
+    p.printNewline();
+    p << "(";
+    llvm::interleaveComma(combiner->getArguments(), p,
+                          [&](auto arg) { p.printRegionArgument(arg); });
+    p << ") ";
+
+    p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
+    p.decreaseIndent();
+  }
 }
 
 LogicalResult ReduceOp::verify() {
@@ -1378,9 +1475,8 @@ void TransposeOp::getAsmResultNames(
 }
 
 void TransposeOp::print(OpAsmPrinter &p) {
-  printCommonStructuredOpPartsWithNewLine(
-      p, SmallVector<Value>(getDpsInputOperands()),
-      SmallVector<Value>(getDpsInitOperands()));
+  printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
+                               SmallVector<Value>(getDpsInitOperands()));
   printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
   p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
 }
@@ -1493,9 +1589,8 @@ void BroadcastOp::getAsmResultNames(
 }
 
 void BroadcastOp::print(OpAsmPrinter &p) {
-  printCommonStructuredOpPartsWithNewLine(
-      p, SmallVector<Value>(getDpsInputOperands()),
-      SmallVector<Value>(getDpsInitOperands()));
+  printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
+                               SmallVector<Value>(getDpsInitOperands()));
   printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
   p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
 }

diff  --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
index d418a92775cfd..87763c9b81014 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -340,7 +340,7 @@ func.func @op_is_reading_but_following_ops_are_not(
 // CHECK-SAME:  %[[RHS:[0-9a-zA-Z]*]]: memref<64xf32
 func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
                       %init: tensor<64xf32>) -> tensor<64xf32> {
-   // CHECK:      linalg.map ins(%[[LHS]], %[[RHS]] : memref<64xf32
+   // CHECK:      linalg.map { arith.addf } ins(%[[LHS]], %[[RHS]] : memref<64xf32
    %add = linalg.map
           ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
           outs(%init:tensor<64xf32>)
@@ -357,7 +357,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
 // CHECK-SAME:  %[[INPUT:.*]]: memref<16x32x64xf32
 func.func @reduce(%input: tensor<16x32x64xf32>,
                   %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
-  // CHECK:     linalg.reduce ins(%[[INPUT]] : memref<16x32x64xf32
+  // CHECK:     linalg.reduce { arith.addf } ins(%[[INPUT]] : memref<16x32x64xf32
   %reduce = linalg.reduce
       ins(%input:tensor<16x32x64xf32>)
       outs(%init:tensor<16x64xf32>)

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index b1a614fb768a1..611d428506faf 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -356,12 +356,8 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
   func.return %add : tensor<64xf32>
 }
 // CHECK-LABEL: func @map_binary
-//       CHECK:   linalg.map ins
+//       CHECK:   linalg.map { arith.addf } ins
 //  CHECK-SAME:   outs
-//  CHECK-NEXT:   (%{{.*}}: f32, %{{.*}}: f32) {
-//  CHECK-NEXT:     arith.addf
-//  CHECK-NEXT:     linalg.yield
-//  CHECK-NEXT:   }
 
 // -----
 
@@ -424,13 +420,9 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
   func.return %reduce : tensor<16x64xf32>
 }
 // CHECK-LABEL: func @reduce
-//       CHECK:   linalg.reduce ins
+//       CHECK:   linalg.reduce { arith.addf } ins
 //  CHECK-SAME:   outs
 //  CHECK-SAME:   dimensions = [1]
-//  CHECK-NEXT:   (%{{.*}}: f32, %{{.*}}: f32) {
-//  CHECK-NEXT:     arith.addf
-//  CHECK-NEXT:     linalg.yield
-//  CHECK-NEXT:   }
 
 // -----
 
@@ -446,8 +438,10 @@ func.func @reduce_memref(%input: memref<16x32x64xf32>,
       }
   func.return
 }
 // CHECK-LABEL: func @reduce_memref
-//       CHECK:     linalg.reduce
+//       CHECK:   linalg.reduce { arith.addf } ins
+//  CHECK-SAME:   outs
+//  CHECK-SAME:   dimensions = [1]
 
 // -----
 
@@ -467,6 +461,7 @@ func.func @variadic_reduce(%input1: tensor<16x32x64xf32>,
 }
 // CHECK-LABEL: func @variadic_reduce
 //       CHECK:     linalg.reduce
+//   CHECK-NOT:     { arith.addf
 
 // -----
 
@@ -484,8 +479,9 @@ func.func @variadic_reduce_memref(%input1: memref<16x32x64xf32>,
       }
   func.return
 }
 // CHECK-LABEL: func @variadic_reduce_memref
 //       CHECK:     linalg.reduce
+//   CHECK-NOT:     { arith.addf
 
 // -----
 
@@ -560,3 +556,46 @@ func.func @broadcast_memref(%input: memref<8x32xf32>,
 //      CHECK:    linalg.broadcast ins
 // CHECK-SAME:    outs
 // CHECK-SAME:    dimensions
+
+// -----
+
+func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
+                      %init: tensor<64xf32>) -> tensor<64xf32> {
+  %add = linalg.map
+          ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
+          outs(%init:tensor<64xf32>)
+          (%lhs_elem: f32, %rhs_elem: f32) {
+            %0 = arith.addf %lhs_elem, %rhs_elem fastmath<fast> : f32
+            linalg.yield %0: f32
+          }
+  func.return %add : tensor<64xf32>
+}
+
+// CHECK-LABEL: func @map_arith_with_attr
+// CHECK-NEXT:    %[[MAPPED:.*]] = linalg.map
+// CHECK-SAME:    { arith.addf {fastmath = #arith.fastmath<fast>} }
+// CHECK-SAME:    ins
+// CHECK-SAME:    outs
+// CHECK-NEXT:    return %[[MAPPED]] : tensor<64xf32>
+
+// -----
+
+func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>,
+                  %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
+  %reduce = linalg.reduce
+      ins(%input:tensor<16x32x64xf32>)
+      outs(%init:tensor<16x64xf32>)
+      dimensions = [1]
+      (%in: f32, %out: f32) {
+        %0 = arith.addf %in, %out fastmath<fast> : f32
+        linalg.yield %0: f32
+      }
+  func.return %reduce : tensor<16x64xf32>
+}
+// CHECK-LABEL: func @reduce_arith_with_attr
+// CHECK-NEXT:    %[[REDUCED:.*]] = linalg.reduce
+// CHECK-SAME:    { arith.addf {fastmath = #arith.fastmath<fast>} }
+// CHECK-SAME:    ins
+// CHECK-SAME:    outs
+// CHECK-SAME:    dimensions = [1]
+// CHECK-NEXT:    return %[[REDUCED]] : tensor<16x64xf32>


        


More information about the Mlir-commits mailing list