[Mlir-commits] [mlir] 7eef3ea - Revert "[mlir][linalg] Add nicer builders for `map` and `reduce`."

Oleg Shyshkov llvmlistbot at llvm.org
Fri Oct 28 00:57:50 PDT 2022


Author: Oleg Shyshkov
Date: 2022-10-28T09:56:59+02:00
New Revision: 7eef3ea5f4fe4f4cc461b191bac031e3962d0347

URL: https://github.com/llvm/llvm-project/commit/7eef3ea5f4fe4f4cc461b191bac031e3962d0347
DIFF: https://github.com/llvm/llvm-project/commit/7eef3ea5f4fe4f4cc461b191bac031e3962d0347.diff

LOG: Revert "[mlir][linalg] Add nicer builders for `map` and `reduce`."

This reverts commit aebde280476943e58f5bcd9993fdd7e36cdbe47e.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 510f8831f019a..1692a0f9c5492 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -267,12 +267,6 @@ def MapOp : LinalgStructuredBase_Op<"map", [
   let results = (outs Variadic<AnyTensor>:$result);
   let regions = (region SizedRegion<1>:$mapper);
 
-  let builders = [
-    OpBuilder<(ins "ValueRange":$inputs, "Value":$init,
-      "function_ref<void(OpBuilder &, Location, ValueRange)>",
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
-  ];
-
   let extraClassDeclaration = structuredOpsBaseDecls # [{
     // Implement functions necessary for LinalgStructuredInterface.
     SmallVector<StringRef> getIteratorTypesArray();
@@ -347,13 +341,6 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
   let results = (outs Variadic<AnyTensor>);
   let regions = (region SizedRegion<1>:$combiner);
 
-  let builders = [
-    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$inits,
-      "ArrayRef<int64_t>":$dimensions,
-      "function_ref<void(OpBuilder &, Location, ValueRange)>",
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
-  ];
-
   let extraClassDeclaration = structuredOpsBaseDecls # [{
     // Declare functions necessary for LinalgStructuredInterface.
     SmallVector<StringRef> getIteratorTypesArray();

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 443bc5cc7407e..03a959acfb2c7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -661,26 +661,6 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // GenericOp
 //===----------------------------------------------------------------------===//
 
-static void buildGenericRegion(
-    OpBuilder &builder, OperationState &result, ValueRange inputs,
-    ValueRange outputs,
-    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
-  SmallVector<Type, 4> blockArgTypes;
-  SmallVector<Location, 4> blockArgLocs;
-  for (ValueRange container : {inputs, outputs}) {
-    for (Value v : container) {
-      blockArgTypes.push_back(getElementTypeOrSelf(v));
-      blockArgLocs.push_back(v.getLoc());
-    }
-  }
-
-  OpBuilder::InsertionGuard guard(builder);
-  auto &region = *result.regions.front();
-  Block *bodyBlock =
-      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
-  bodyBuild(builder, result.location, bodyBlock->getArguments());
-}
-
 void GenericOp::getAsmBlockArgumentNames(Region &region,
                                          OpAsmSetValueNameFn setNameFn) {
   for (Value v : getRegionInputArgs())
@@ -698,8 +678,23 @@ void GenericOp::build(
   build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
         iteratorTypes, doc, libraryCall);
   result.addAttributes(attributes);
-  if (bodyBuild)
-    buildGenericRegion(builder, result, inputs, outputs, bodyBuild);
+  if (!bodyBuild)
+    return;
+
+  SmallVector<Type, 4> blockArgTypes;
+  SmallVector<Location, 4> blockArgLocs;
+  for (ValueRange container : {inputs, outputs}) {
+    for (Value v : container) {
+      blockArgTypes.push_back(getElementTypeOrSelf(v));
+      blockArgLocs.push_back(v.getLoc());
+    }
+  }
+
+  OpBuilder::InsertionGuard guard(builder);
+  auto &region = *result.regions.front();
+  Block *bodyBlock =
+      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
+  bodyBuild(builder, result.location, bodyBlock->getArguments());
 }
 
 void GenericOp::build(
@@ -1334,22 +1329,6 @@ void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
     setNameFn(getResults().front(), "mapped");
 }
 
-void MapOp::build(
-    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
-    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
-    ArrayRef<NamedAttribute> attributes) {
-  build(builder, result, TypeRange{}, inputs, init);
-  result.addAttributes(attributes);
-
-  // Add output types for `RankedTensorType` output arguments.
-  Type initType = init.getType();
-  if (initType.isa<RankedTensorType>())
-    result.addTypes(initType);
-
-  if (bodyBuild)
-    buildGenericRegion(builder, result, inputs, /*outputs=*/{}, bodyBuild);
-}
-
 ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
   if (parseDstStyleOp(parser, result))
     return failure();
@@ -1457,25 +1436,6 @@ void ReduceOp::getAsmResultNames(
     setNameFn(getResults().front(), "reduced");
 }
 
-void ReduceOp::build(
-    OpBuilder &builder, OperationState &result, ValueRange inputs,
-    ValueRange inits, ArrayRef<int64_t> dimensions,
-    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
-    ArrayRef<NamedAttribute> attributes) {
-  build(builder, result, TypeRange{}, inputs, inits, dimensions);
-  result.addAttributes(attributes);
-
-  // Add output types for `RankedTensorType` output arguments.
-  for (Value init : inits) {
-    Type initType = init.getType();
-    if (initType.isa<RankedTensorType>())
-      result.addTypes(initType);
-  }
-
-  if (bodyBuild)
-    buildGenericRegion(builder, result, inputs, inits, bodyBuild);
-}
-
 SmallVector<StringRef> ReduceOp::getIteratorTypesArray() {
   int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
   SmallVector<StringRef> iteratorTypes(inputRank,
@@ -1658,32 +1618,45 @@ TransposeOp::getRegionBuilder() {
   };
 }
 
-void TransposeOp::build(::mlir::OpBuilder &builder,
-                        ::mlir::OperationState &result, Value input, Value init,
-                        DenseI64ArrayAttr permutation,
-                        ArrayRef<NamedAttribute> attributes) {
-  result.addOperands(input);
-  result.addOperands(init);
-  result.addAttribute(getPermutationAttrName(result.name), permutation);
-  result.addAttributes(attributes);
+void TransposeOp::createRegion(::mlir::OpBuilder &opBuilder,
+                               ::mlir::OperationState &odsState) {
+  Region *region = odsState.addRegion();
 
-  // Add output types for `RankedTensorType` output arguments.
-  Type initType = init.getType();
-  if (initType.isa<RankedTensorType>())
-    result.addTypes(initType);
+  SmallVector<Type> argTypes;
+  SmallVector<Location> argLocs;
+  for (auto t : odsState.operands) {
+    argTypes.push_back(getElementTypeOrSelf(t));
+    argLocs.push_back(opBuilder.getUnknownLoc());
+  }
 
-  buildGenericRegion(builder, result, input, init,
-                     [&](OpBuilder &b, Location loc, ValueRange args) {
-                       b.create<linalg::YieldOp>(loc, args[0]);
-                     });
+  // RAII.
+  OpBuilder::InsertionGuard guard(opBuilder);
+  Block *body =
+      opBuilder.createBlock(region, /*insertPt=*/{}, argTypes, argLocs);
+
+  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
+  getRegionBuilder()(b, *body, odsState.attributes.getAttrs());
 }
 
-void TransposeOp::build(::mlir::OpBuilder &builder,
-                        ::mlir::OperationState &result, Value input, Value init,
-                        ArrayRef<int64_t> permutation,
+void TransposeOp::build(::mlir::OpBuilder &odsBuilder,
+                        ::mlir::OperationState &odsState, Value input,
+                        Value init, DenseI64ArrayAttr permutation,
                         ArrayRef<NamedAttribute> attributes) {
-  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
-        attributes);
+  odsState.addOperands(input);
+  odsState.addOperands(init);
+  odsState.addAttribute(getPermutationAttrName(odsState.name), permutation);
+  odsState.addAttributes(attributes);
+  odsState.addTypes(init.getType());
+
+  createRegion(odsBuilder, odsState);
+}
+
+void TransposeOp::build(::mlir::OpBuilder &odsBuilder,
+                        ::mlir::OperationState &odsState, Value input,
+                        Value init, ArrayRef<int64_t> permutation,
+                        ArrayRef<NamedAttribute> attributes) {
+  build(odsBuilder, odsState, input, init,
+        odsBuilder.getDenseI64ArrayAttr(permutation), attributes);
 }
 
 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -1693,13 +1666,8 @@ ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
           })))
     return failure();
 
-  (void)result.addRegion();
-  OpBuilder builder(parser.getContext());
-  buildGenericRegion(builder, result, /*inputs=*/result.operands,
-                     /*outputs=*/{},
-                     [&](OpBuilder &b, Location loc, ValueRange args) {
-                       b.create<linalg::YieldOp>(loc, args[0]);
-                     });
+  OpBuilder opBuilder(parser.getContext());
+  createRegion(opBuilder, result);
   return success();
 }
 


        


More information about the Mlir-commits mailing list