[Mlir-commits] [mlir] 5a451e4 - [mlir][linalg] adapt named op generalization to work with captures.

Tobias Gysi llvmlistbot at llvm.org
Wed Apr 21 00:07:43 PDT 2021


Author: Tobias Gysi
Date: 2021-04-21T06:37:53Z
New Revision: 5a451e486f31de36913d6fc22a1b92b39caa3b0e

URL: https://github.com/llvm/llvm-project/commit/5a451e486f31de36913d6fc22a1b92b39caa3b0e
DIFF: https://github.com/llvm/llvm-project/commit/5a451e486f31de36913d6fc22a1b92b39caa3b0e.diff

LOG: [mlir][linalg] adapt named op generalization to work with captures.

Instead of always running the region builder, check if the generalized op has a region attached. If so, inline the existing region instead of calling the region builder. This change circumvents a problem with named operations whose region builder takes captures: the generalization pass does not know about these captures.

Differential Revision: https://reviews.llvm.org/D100880

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
    mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index af3f393997e7c..8176888f0beb0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -27,24 +27,36 @@
 #define DEBUG_TYPE "linalg-generalization"
 
 using namespace mlir;
+using namespace mlir::linalg;
 
 // Creates a linalg.generic op from the given `namedOp`. Returns a null op if
 // the given `namedOp` does not have a region builder.
-static linalg::GenericOp createGenericOpFromNamedOp(linalg::LinalgOp namedOp,
-                                                    OpBuilder &builder) {
+static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp,
+                                            PatternRewriter &rewriter) {
+  SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps();
+  SmallVector<StringRef> iterators = llvm::to_vector<4>(
+      namedOp.iterator_types().getAsValueRange<StringAttr>());
+  SmallVector<RankedTensorType> resultTypes = namedOp.getOutputTensorTypes();
+  SmallVector<Type> types(resultTypes.begin(), resultTypes.end());
+
+  // Inline the existing region if the named operation has a region attached.
+  if (namedOp->getNumRegions() == 1) {
+    GenericOp genericOp = rewriter.create<GenericOp>(
+        namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputs(),
+        indexingMaps, iterators);
+    rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(),
+                                genericOp.region().begin());
+    return genericOp;
+  }
+
+  // Otherwise use the region builder to generate a new region.
+  // TODO: Remove this path once all linalg operations have a region attached.
   auto regionBuilder = namedOp.getRegionBuilder();
   if (!regionBuilder) {
     LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n");
     return nullptr;
   }
-
-  SmallVector<AffineMap, 4> indexingMaps = namedOp.getIndexingMaps();
-  auto iterators = llvm::to_vector<4>(
-      namedOp.iterator_types().getAsValueRange<StringAttr>());
-  auto resultTypes = namedOp.getOutputTensorTypes();
-  SmallVector<Type, 4> types(resultTypes.begin(), resultTypes.end());
-
-  return builder.create<linalg::GenericOp>(
+  return rewriter.create<GenericOp>(
       namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputs(),
       indexingMaps, iterators,
       [&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
@@ -57,27 +69,27 @@ namespace {
 
 /// Base class for all linalg generalization patterns. A subclass must provide
 /// the following method:
-///   linalg::GenericOp createGenericOp(RootOp, PatternRewriter &)
+///   GenericOp createGenericOp(RootOp, PatternRewriter &)
 /// for creating the generic op.
 // TODO: remove this pattern after migrating all manually-written named ops
 // into auto-generated ones.
 template <typename ConcretePattern, typename RootOp>
 struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
   LinalgGeneralizationPattern(MLIRContext *context,
-                              linalg::LinalgTransformationFilter marker,
+                              LinalgTransformationFilter marker,
                               PatternBenefit benefit = 1)
       : OpRewritePattern<RootOp>(context, benefit), marker(std::move(marker)) {}
 
   LogicalResult matchAndRewrite(RootOp rootOp,
                                 PatternRewriter &rewriter) const override {
-    auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp.getOperation());
+    auto linalgOp = dyn_cast<LinalgOp>(rootOp.getOperation());
     if (!linalgOp)
       return failure();
     if (failed(marker.checkAndNotify(rewriter, linalgOp)))
       return failure();
 
     auto *pattern = static_cast<const ConcretePattern *>(this);
-    linalg::GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter);
+    GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter);
     if (!genericOp)
       return failure();
 
@@ -88,39 +100,38 @@ struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
   }
 
 private:
-  linalg::LinalgTransformationFilter marker;
+  LinalgTransformationFilter marker;
 };
 
 struct GeneralizeConvOp
-    : public LinalgGeneralizationPattern<GeneralizeConvOp, linalg::ConvOp> {
+    : public LinalgGeneralizationPattern<GeneralizeConvOp, ConvOp> {
   using LinalgGeneralizationPattern::LinalgGeneralizationPattern;
 
-  linalg::GenericOp createGenericOp(linalg::ConvOp, OpBuilder &rewriter) const;
+  GenericOp createGenericOp(ConvOp convOp, OpBuilder &builder) const;
 };
 
 /// Catch-all pattern for converting all named ops with a region builder into
 /// linalg.generic.
 struct LinalgNamedOpGeneralizationPattern : RewritePattern {
   LinalgNamedOpGeneralizationPattern(MLIRContext *context,
-                                     linalg::LinalgTransformationFilter marker,
+                                     LinalgTransformationFilter marker,
                                      PatternBenefit benefit = 1)
       : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
         marker(std::move(marker)) {}
 
   LogicalResult matchAndRewrite(Operation *rootOp,
                                 PatternRewriter &rewriter) const override {
-    auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp);
+    auto linalgOp = dyn_cast<LinalgOp>(rootOp);
     if (!linalgOp)
       return failure();
     if (failed(marker.checkAndNotify(rewriter, linalgOp)))
       return failure();
 
     // No nothing to do for linalg.generic and linalg.indexed_generic.
-    if (isa<linalg::GenericOp, linalg::IndexedGenericOp>(rootOp))
+    if (isa<GenericOp, IndexedGenericOp>(rootOp))
       return failure();
 
-    linalg::GenericOp genericOp =
-        createGenericOpFromNamedOp(linalgOp, rewriter);
+    GenericOp genericOp = createGenericOpFromNamedOp(linalgOp, rewriter);
     if (!genericOp)
       return failure();
 
@@ -131,7 +142,7 @@ struct LinalgNamedOpGeneralizationPattern : RewritePattern {
   }
 
 private:
-  linalg::LinalgTransformationFilter marker;
+  LinalgTransformationFilter marker;
 };
 
 struct LinalgGeneralizationPass
@@ -144,17 +155,17 @@ struct LinalgGeneralizationPass
 void LinalgGeneralizationPass::runOnFunction() {
   FuncOp func = getFunction();
   RewritePatternSet patterns(&getContext());
-  linalg::populateLinalgConvGeneralizationPatterns(patterns);
-  linalg::populateLinalgNamedOpsGeneralizationPatterns(patterns);
+  populateLinalgConvGeneralizationPatterns(patterns);
+  populateLinalgNamedOpsGeneralizationPatterns(patterns);
   (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
 }
 
-linalg::GenericOp GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp,
-                                                    OpBuilder &builder) const {
+GenericOp GeneralizeConvOp::createGenericOp(ConvOp convOp,
+                                            OpBuilder &builder) const {
   SmallVector<AffineMap, 4> indexingMaps = convOp.getIndexingMaps();
   auto iterators =
       llvm::to_vector<4>(convOp.iterator_types().getAsValueRange<StringAttr>());
-  return builder.create<linalg::GenericOp>(
+  return builder.create<GenericOp>(
       convOp.getLoc(), /*resultTensorTypes=*/ArrayRef<Type>(),
       convOp.getInputBuffers(), convOp.getOutputBuffers(), indexingMaps,
       iterators,
@@ -162,17 +173,17 @@ linalg::GenericOp GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp,
         Value mul =
             bodyBuilder.create<MulFOp>(bodyLoc, bodyArgs[0], bodyArgs[1]);
         Value add = bodyBuilder.create<AddFOp>(bodyLoc, mul, bodyArgs[2]);
-        bodyBuilder.create<linalg::YieldOp>(bodyLoc, add);
+        bodyBuilder.create<YieldOp>(bodyLoc, add);
       });
 }
 
 void mlir::linalg::populateLinalgConvGeneralizationPatterns(
-    RewritePatternSet &patterns, linalg::LinalgTransformationFilter marker) {
+    RewritePatternSet &patterns, LinalgTransformationFilter marker) {
   patterns.add<GeneralizeConvOp>(patterns.getContext(), marker);
 }
 
 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
-    RewritePatternSet &patterns, linalg::LinalgTransformationFilter marker) {
+    RewritePatternSet &patterns, LinalgTransformationFilter marker) {
   patterns.add<LinalgNamedOpGeneralizationPattern>(patterns.getContext(),
                                                    marker);
 }

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index e7b8e3aad9d95..b6231927df967 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -441,3 +441,23 @@ func @pooling_nhwc_min(%input: memref<?x?x?x?xf32>, %fake: memref<2x3xf32>, %ini
 // CHECK-NEXT:      %[[CMP:.+]] = cmpf olt, %[[BBARG0]], %[[BBARG2]] : f32
 // CHECK-NEXT:      %[[RES:.+]] = select %[[CMP]], %[[BBARG0]], %[[BBARG2]] : f32
 // CHECK-NEXT:      linalg.yield %[[RES]] : f32
+
+// -----
+
+func @generalize_fill(%output: memref<?x?xf32>, %value : f32) {
+  linalg.fill(%output, %value) : memref<?x?xf32>, f32
+  return
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK: func @generalize_fill
+// CHECK-SAME: (%[[ARG0:.+]]: memref<?x?xf32>, %[[VAL:.+]]: f32)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32)
+// CHECK-NEXT:      linalg.yield %[[VAL]] : f32


        


More information about the Mlir-commits mailing list