[Mlir-commits] [mlir] 5a451e4 - [mlir][linalg] adapt named op generalization to work with captures.
Tobias Gysi
llvmlistbot at llvm.org
Wed Apr 21 00:07:43 PDT 2021
Author: Tobias Gysi
Date: 2021-04-21T06:37:53Z
New Revision: 5a451e486f31de36913d6fc22a1b92b39caa3b0e
URL: https://github.com/llvm/llvm-project/commit/5a451e486f31de36913d6fc22a1b92b39caa3b0e
DIFF: https://github.com/llvm/llvm-project/commit/5a451e486f31de36913d6fc22a1b92b39caa3b0e.diff
LOG: [mlir][linalg] adapt named op generalization to work with captures.
Instead of always running the region builder, check whether the named op being generalized has a region attached. If so, inline the existing region instead of calling the region builder. This change circumvents a problem with named operations whose region builder takes captures that the generalization pass does not know about.
Differential Revision: https://reviews.llvm.org/D100880
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
mlir/test/Dialect/Linalg/generalize-named-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index af3f393997e7c..8176888f0beb0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -27,24 +27,36 @@
#define DEBUG_TYPE "linalg-generalization"
using namespace mlir;
+using namespace mlir::linalg;
// Creates a linalg.generic op from the given `namedOp`. Returns a null op if
// the given `namedOp` does not have a region builder.
-static linalg::GenericOp createGenericOpFromNamedOp(linalg::LinalgOp namedOp,
- OpBuilder &builder) {
+static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp,
+ PatternRewriter &rewriter) {
+ SmallVector<AffineMap> indexingMaps = namedOp.getIndexingMaps();
+ SmallVector<StringRef> iterators = llvm::to_vector<4>(
+ namedOp.iterator_types().getAsValueRange<StringAttr>());
+ SmallVector<RankedTensorType> resultTypes = namedOp.getOutputTensorTypes();
+ SmallVector<Type> types(resultTypes.begin(), resultTypes.end());
+
+ // Inline the existing region if the named operation has a region attached.
+ if (namedOp->getNumRegions() == 1) {
+ GenericOp genericOp = rewriter.create<GenericOp>(
+ namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputs(),
+ indexingMaps, iterators);
+ rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(),
+ genericOp.region().begin());
+ return genericOp;
+ }
+
+ // Otherwise use the region builder to generate a new region.
+ // TODO: Remove this path once all linalg operations have a region attached.
auto regionBuilder = namedOp.getRegionBuilder();
if (!regionBuilder) {
LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n");
return nullptr;
}
-
- SmallVector<AffineMap, 4> indexingMaps = namedOp.getIndexingMaps();
- auto iterators = llvm::to_vector<4>(
- namedOp.iterator_types().getAsValueRange<StringAttr>());
- auto resultTypes = namedOp.getOutputTensorTypes();
- SmallVector<Type, 4> types(resultTypes.begin(), resultTypes.end());
-
- return builder.create<linalg::GenericOp>(
+ return rewriter.create<GenericOp>(
namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputs(),
indexingMaps, iterators,
[&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
@@ -57,27 +69,27 @@ namespace {
/// Base class for all linalg generalization patterns. A subclass must provide
/// the following method:
-/// linalg::GenericOp createGenericOp(RootOp, PatternRewriter &)
+/// GenericOp createGenericOp(RootOp, PatternRewriter &)
/// for creating the generic op.
// TODO: remove this pattern after migrating all manually-written named ops
// into auto-generated ones.
template <typename ConcretePattern, typename RootOp>
struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
LinalgGeneralizationPattern(MLIRContext *context,
- linalg::LinalgTransformationFilter marker,
+ LinalgTransformationFilter marker,
PatternBenefit benefit = 1)
: OpRewritePattern<RootOp>(context, benefit), marker(std::move(marker)) {}
LogicalResult matchAndRewrite(RootOp rootOp,
PatternRewriter &rewriter) const override {
- auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp.getOperation());
+ auto linalgOp = dyn_cast<LinalgOp>(rootOp.getOperation());
if (!linalgOp)
return failure();
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
return failure();
auto *pattern = static_cast<const ConcretePattern *>(this);
- linalg::GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter);
+ GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter);
if (!genericOp)
return failure();
@@ -88,39 +100,38 @@ struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
}
private:
- linalg::LinalgTransformationFilter marker;
+ LinalgTransformationFilter marker;
};
struct GeneralizeConvOp
- : public LinalgGeneralizationPattern<GeneralizeConvOp, linalg::ConvOp> {
+ : public LinalgGeneralizationPattern<GeneralizeConvOp, ConvOp> {
using LinalgGeneralizationPattern::LinalgGeneralizationPattern;
- linalg::GenericOp createGenericOp(linalg::ConvOp, OpBuilder &rewriter) const;
+ GenericOp createGenericOp(ConvOp convOp, OpBuilder &builder) const;
};
/// Catch-all pattern for converting all named ops with a region builder into
/// linalg.generic.
struct LinalgNamedOpGeneralizationPattern : RewritePattern {
LinalgNamedOpGeneralizationPattern(MLIRContext *context,
- linalg::LinalgTransformationFilter marker,
+ LinalgTransformationFilter marker,
PatternBenefit benefit = 1)
: RewritePattern(MatchAnyOpTypeTag(), benefit, context),
marker(std::move(marker)) {}
LogicalResult matchAndRewrite(Operation *rootOp,
PatternRewriter &rewriter) const override {
- auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp);
+ auto linalgOp = dyn_cast<LinalgOp>(rootOp);
if (!linalgOp)
return failure();
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
return failure();
// Do nothing for linalg.generic and linalg.indexed_generic.
- if (isa<linalg::GenericOp, linalg::IndexedGenericOp>(rootOp))
+ if (isa<GenericOp, IndexedGenericOp>(rootOp))
return failure();
- linalg::GenericOp genericOp =
- createGenericOpFromNamedOp(linalgOp, rewriter);
+ GenericOp genericOp = createGenericOpFromNamedOp(linalgOp, rewriter);
if (!genericOp)
return failure();
@@ -131,7 +142,7 @@ struct LinalgNamedOpGeneralizationPattern : RewritePattern {
}
private:
- linalg::LinalgTransformationFilter marker;
+ LinalgTransformationFilter marker;
};
struct LinalgGeneralizationPass
@@ -144,17 +155,17 @@ struct LinalgGeneralizationPass
void LinalgGeneralizationPass::runOnFunction() {
FuncOp func = getFunction();
RewritePatternSet patterns(&getContext());
- linalg::populateLinalgConvGeneralizationPatterns(patterns);
- linalg::populateLinalgNamedOpsGeneralizationPatterns(patterns);
+ populateLinalgConvGeneralizationPatterns(patterns);
+ populateLinalgNamedOpsGeneralizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
}
-linalg::GenericOp GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp,
- OpBuilder &builder) const {
+GenericOp GeneralizeConvOp::createGenericOp(ConvOp convOp,
+ OpBuilder &builder) const {
SmallVector<AffineMap, 4> indexingMaps = convOp.getIndexingMaps();
auto iterators =
llvm::to_vector<4>(convOp.iterator_types().getAsValueRange<StringAttr>());
- return builder.create<linalg::GenericOp>(
+ return builder.create<GenericOp>(
convOp.getLoc(), /*resultTensorTypes=*/ArrayRef<Type>(),
convOp.getInputBuffers(), convOp.getOutputBuffers(), indexingMaps,
iterators,
@@ -162,17 +173,17 @@ linalg::GenericOp GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp,
Value mul =
bodyBuilder.create<MulFOp>(bodyLoc, bodyArgs[0], bodyArgs[1]);
Value add = bodyBuilder.create<AddFOp>(bodyLoc, mul, bodyArgs[2]);
- bodyBuilder.create<linalg::YieldOp>(bodyLoc, add);
+ bodyBuilder.create<YieldOp>(bodyLoc, add);
});
}
void mlir::linalg::populateLinalgConvGeneralizationPatterns(
- RewritePatternSet &patterns, linalg::LinalgTransformationFilter marker) {
+ RewritePatternSet &patterns, LinalgTransformationFilter marker) {
patterns.add<GeneralizeConvOp>(patterns.getContext(), marker);
}
void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
- RewritePatternSet &patterns, linalg::LinalgTransformationFilter marker) {
+ RewritePatternSet &patterns, LinalgTransformationFilter marker) {
patterns.add<LinalgNamedOpGeneralizationPattern>(patterns.getContext(),
marker);
}
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index e7b8e3aad9d95..b6231927df967 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -441,3 +441,23 @@ func @pooling_nhwc_min(%input: memref<?x?x?x?xf32>, %fake: memref<2x3xf32>, %ini
// CHECK-NEXT: %[[CMP:.+]] = cmpf olt, %[[BBARG0]], %[[BBARG2]] : f32
// CHECK-NEXT: %[[RES:.+]] = select %[[CMP]], %[[BBARG0]], %[[BBARG2]] : f32
// CHECK-NEXT: linalg.yield %[[RES]] : f32
+
+// -----
+
+func @generalize_fill(%output: memref<?x?xf32>, %value : f32) {
+ linalg.fill(%output, %value) : memref<?x?xf32>, f32
+ return
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK: func @generalize_fill
+// CHECK-SAME: (%[[ARG0:.+]]: memref<?x?xf32>, %[[VAL:.+]]: f32)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32)
+// CHECK-NEXT: linalg.yield %[[VAL]] : f32
More information about the Mlir-commits
mailing list