[Mlir-commits] [mlir] 9e39a5d - [mlir][linalg] Start a named ops to generic ops pass
Lei Zhang
llvmlistbot at llvm.org
Thu Nov 19 06:21:17 PST 2020
Author: Lei Zhang
Date: 2020-11-19T09:21:06-05:00
New Revision: 9e39a5d9a68af70c58ac415e51e6b12cd85f9af2
URL: https://github.com/llvm/llvm-project/commit/9e39a5d9a68af70c58ac415e51e6b12cd85f9af2
DIFF: https://github.com/llvm/llvm-project/commit/9e39a5d9a68af70c58ac415e51e6b12cd85f9af2.diff
LOG: [mlir][linalg] Start a named ops to generic ops pass
This commit starts a new pass and patterns for converting Linalg
named ops to generic ops. This enables us to leverage the flexibility
from generic ops during transformations. Right now only linalg.conv
is supported; others will be added when useful.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D91357
Added:
mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
mlir/test/Dialect/Linalg/generalize-named-ops.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 2c200fe08b10..6ac1e5642789 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -159,6 +159,10 @@ def CopyOp : LinalgStructured_Op<"copy", [
Value getSource() { return input();}
Value getTarget() { return output(); }
+
+ static std::function<void(Block &)> getRegionBuilder() {
+ return nullptr;
+ }
}];
let verifier = [{ return ::verify(*this); }];
@@ -188,6 +192,10 @@ def FillOp : LinalgStructured_Op<"fill", [
return Builder(getContext()).getAffineMapArrayAttr({
extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)});
}
+
+ static std::function<void(Block &)> getRegionBuilder() {
+ return nullptr;
+ }
}];
let verifier = [{ return ::verify(*this); }];
@@ -261,6 +269,10 @@ class PoolingBase_Op<string mnemonic, list<OpTrait> props>
if (!padding().hasValue()) return 0;
return padding().getValue().getValue<int64_t>({i, 1});
}
+
+ static std::function<void(Block &)> getRegionBuilder() {
+ return nullptr;
+ }
}];
}
@@ -516,6 +528,10 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
return ss.hasValue() ?
llvm::Optional<unsigned>(ss.getValue()) : llvm::None;
}
+
+ static std::function<void(Block &)> getRegionBuilder() {
+ return nullptr;
+ }
}];
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parseGenericOp(parser, result); }];
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index ec7167485104..0373bf3f6adf 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -803,6 +803,17 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
&res->getRegion(ridx), map);
return res;
}]
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/[{
+ Returns the region builder for constructing the body for linalg.generic.
+ Returns a null function if this named op does not define a region
+ builder.
+ }],
+ /*retTy=*/"std::function<void(Block &)>",
+ /*methodName=*/"getRegionBuilder",
+ (ins),
+ [{ return ConcreteOp::getRegionBuilder(); }]
>
];
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 620abd96636d..d041df86d169 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -55,6 +55,10 @@ std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass();
void populateElementwiseToLinalgConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx);
+/// Create a pass to convert named Linalg operations to Linalg generic
+/// operations.
+std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
+
/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
/// producer (consumer) generic operation by expanding the dimensionality of the
/// loop in the generic op.
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 07bf93a70f5e..aabfd44299d5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -112,4 +112,10 @@ def LinalgTilingToParallelLoops
let dependentDialects = ["AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"];
}
+def LinalgGeneralization : FunctionPass<"linalg-generalize-named-ops"> {
+  let summary = "Convert named ops into generic ops";
+  let description = [{
+    Rewrites Linalg named ops into equivalent `linalg.generic` ops so that
+    later transformations can rely on the generic form. `linalg.conv` is
+    handled by a dedicated pattern; other named ops are converted only if they
+    provide a region builder. `linalg.generic` and `linalg.indexed_generic`
+    are left unchanged.
+  }];
+  let constructor = "mlir::createLinalgGeneralizationPass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
#endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 523a34e3f613..8d531a1e343a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -624,6 +624,20 @@ struct LinalgLoweringPattern : public RewritePattern {
LinalgLoweringType loweringType;
};
+// Linalg generalization patterns.
+
+/// Populates `patterns` with patterns to convert spec-generated named ops (ops
+/// whose getRegionBuilder() returns a non-null builder) to linalg.generic ops.
+/// `marker` filters which ops are rewritten and is used to update the marker
+/// on the created linalg.generic ops.
+void populateLinalgNamedOpsGeneralizationPatterns(
+ MLIRContext *context, OwningRewritePatternList &patterns,
+ LinalgMarker marker = LinalgMarker());
+
+/// Populates `patterns` with patterns to convert linalg.conv ops to
+/// linalg.generic ops. `marker` filters which ops are rewritten and is used to
+/// update the marker on the created linalg.generic ops.
+void populateLinalgConvGeneralizationPatterns(
+ MLIRContext *context, OwningRewritePatternList &patterns,
+ LinalgMarker marker = LinalgMarker());
+
//===----------------------------------------------------------------------===//
// Op-specific patterns.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 11a48894d9d8..6de4ce6ac341 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
ElementwiseToLinalg.cpp
Fusion.cpp
FusionOnTensors.cpp
+ Generalization.cpp
Hoisting.cpp
Interchange.cpp
Loops.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
new file mode 100644
index 000000000000..3496a7796988
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -0,0 +1,180 @@
+//===- Generalization.cpp - linalg named ops to generic ops --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the Linalg generalization pass. It converts named
+// Linalg ops to linalg.generic ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "linalg-generalization"
+
+using namespace mlir;
+
+// Creates a linalg.generic op from the given `namedOp`, reusing its indexing
+// maps, iterator types, operands and output tensor types, and filling the body
+// with the op's region builder. Returns a null op if `namedOp` does not
+// provide a region builder (manually-written named ops such as linalg.copy or
+// linalg.fill return nullptr from getRegionBuilder()).
+static linalg::GenericOp createGenericOpFromNamedOp(linalg::LinalgOp namedOp,
+                                                    OpBuilder &builder) {
+  auto regionBuilder = namedOp.getRegionBuilder();
+  if (!regionBuilder) {
+    LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n");
+    return nullptr;
+  }
+
+  SmallVector<AffineMap, 4> indexingMaps = namedOp.getIndexingMaps();
+  auto iterators = llvm::to_vector<4>(
+      namedOp.iterator_types().getAsValueRange<StringAttr>());
+  auto resultTypes = namedOp.getOutputTensorTypes();
+  SmallVector<Type, 4> types(resultTypes.begin(), resultTypes.end());
+
+  // The body callback captures `regionBuilder` by reference; this assumes the
+  // GenericOp builder invokes the callback before `create` returns (TODO
+  // confirm if the callback is ever stored). The region builder is run inside
+  // an EDSC ScopedContext bound to the body builder and location.
+  return builder.create<linalg::GenericOp>(
+      namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputBuffers(),
+      namedOp.getInitTensors(), indexingMaps, iterators,
+      [&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
+        edsc::ScopedContext scope(bodyBuilder, loc);
+        regionBuilder(*bodyBuilder.getBlock());
+      });
+}
+
+namespace {
+
+/// Base class for all linalg generalization patterns. A subclass must provide
+/// the following method:
+///   linalg::GenericOp createGenericOp(RootOp, OpBuilder &) const
+/// for creating the generic op. The created op replaces the matched op; a null
+/// return makes the pattern fail to match.
+// TODO: remove this pattern after migrating all manually-written named ops
+// into auto-generated ones.
+template <typename ConcretePattern, typename RootOp>
+struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
+  LinalgGeneralizationPattern(MLIRContext *context, linalg::LinalgMarker marker,
+                              PatternBenefit benefit = 1)
+      : OpRewritePattern<RootOp>(context, benefit), marker(std::move(marker)) {}
+
+  LogicalResult matchAndRewrite(RootOp rootOp,
+                                PatternRewriter &rewriter) const override {
+    auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp.getOperation());
+    if (!linalgOp)
+      return failure();
+    // Bail out if `marker` rejects this op.
+    if (failed(marker.checkAndNotify(rewriter, linalgOp)))
+      return failure();
+
+    // CRTP dispatch to the op-specific generic op construction.
+    auto *pattern = static_cast<const ConcretePattern *>(this);
+    linalg::GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter);
+    if (!genericOp)
+      return failure();
+
+    rewriter.replaceOp(rootOp, genericOp.getResults());
+    marker.replaceLinalgMarker(rewriter, genericOp.getOperation());
+    return success();
+  }
+
+private:
+  linalg::LinalgMarker marker;
+};
+
+/// Converts linalg.conv, which has no region builder, into linalg.generic.
+struct GeneralizeConvOp
+    : public LinalgGeneralizationPattern<GeneralizeConvOp, linalg::ConvOp> {
+  using LinalgGeneralizationPattern::LinalgGeneralizationPattern;
+
+  linalg::GenericOp createGenericOp(linalg::ConvOp convOp,
+                                    OpBuilder &builder) const;
+};
+
+/// Catch-all pattern for converting all named ops with a region builder into
+/// linalg.generic. Named ops without a region builder fail to match.
+struct LinalgNamedOpGeneralizationPattern : RewritePattern {
+  LinalgNamedOpGeneralizationPattern(MLIRContext *context,
+                                     linalg::LinalgMarker marker,
+                                     PatternBenefit benefit = 1)
+      : RewritePattern(benefit, MatchAnyOpTypeTag()),
+        marker(std::move(marker)) {}
+
+  LogicalResult matchAndRewrite(Operation *rootOp,
+                                PatternRewriter &rewriter) const override {
+    auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp);
+    if (!linalgOp)
+      return failure();
+
+    // Nothing to do for linalg.generic and linalg.indexed_generic. Reject them
+    // before consulting the marker so they never go through marker filtering.
+    if (isa<linalg::GenericOp, linalg::IndexedGenericOp>(rootOp))
+      return failure();
+
+    // Bail out if `marker` rejects this op.
+    if (failed(marker.checkAndNotify(rewriter, linalgOp)))
+      return failure();
+
+    linalg::GenericOp genericOp =
+        createGenericOpFromNamedOp(linalgOp, rewriter);
+    if (!genericOp)
+      return failure();
+
+    rewriter.replaceOp(rootOp, genericOp.getResults());
+    marker.replaceLinalgMarker(rewriter, genericOp.getOperation());
+    return success();
+  }
+
+private:
+  linalg::LinalgMarker marker;
+};
+
+/// Function pass applying the conv and named-op generalization patterns.
+struct LinalgGeneralizationPass
+    : public LinalgGeneralizationBase<LinalgGeneralizationPass> {
+  void runOnFunction() override;
+};
+
+} // namespace
+
+/// Collects the linalg.conv pattern and the catch-all named-op pattern, then
+/// applies them greedily to the body of the current function.
+void LinalgGeneralizationPass::runOnFunction() {
+  MLIRContext *ctx = &getContext();
+  OwningRewritePatternList patterns;
+  linalg::populateLinalgConvGeneralizationPatterns(ctx, patterns);
+  linalg::populateLinalgNamedOpsGeneralizationPatterns(ctx, patterns);
+  applyPatternsAndFoldGreedily(getFunction().getBody(), std::move(patterns));
+}
+
+/// Builds the linalg.generic equivalent of `convOp` on buffers: no result
+/// tensors and no init tensors; the conv op's input and output buffers become
+/// the generic op's operands, together with its indexing maps and iterator
+/// types. The body multiplies the first two block arguments with mulf, adds
+/// the third with addf and yields the sum.
+linalg::GenericOp GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp,
+                                                    OpBuilder &builder) const {
+  SmallVector<AffineMap, 4> maps = convOp.getIndexingMaps();
+  auto iteratorTypes =
+      llvm::to_vector<4>(convOp.iterator_types().getAsValueRange<StringAttr>());
+
+  // Scalar payload: product of args[0] and args[1], accumulated into args[2].
+  auto buildBody = [](OpBuilder &b, Location loc, ValueRange args) {
+    Value product = b.create<MulFOp>(loc, args[0], args[1]);
+    Value sum = b.create<AddFOp>(loc, product, args[2]);
+    b.create<linalg::YieldOp>(loc, sum);
+  };
+
+  return builder.create<linalg::GenericOp>(
+      convOp.getLoc(), /*resultTensorTypes=*/ArrayRef<Type>(),
+      convOp.getInputBuffers(), convOp.getOutputBuffers(),
+      /*initTensors=*/ValueRange(), maps, iteratorTypes, buildBody);
+}
+
+// Registers the hand-written linalg.conv -> linalg.generic pattern; `marker`
+// filters the ops it rewrites.
+void mlir::linalg::populateLinalgConvGeneralizationPatterns(
+ MLIRContext *context, OwningRewritePatternList &patterns,
+ linalg::LinalgMarker marker) {
+ patterns.insert<GeneralizeConvOp>(context, marker);
+}
+
+// Registers the catch-all pattern that generalizes any Linalg named op with a
+// region builder; `marker` filters the ops it rewrites.
+void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
+ MLIRContext *context, OwningRewritePatternList &patterns,
+ linalg::LinalgMarker marker) {
+ patterns.insert<LinalgNamedOpGeneralizationPattern>(context, marker);
+}
+
+// Factory for the `linalg-generalize-named-ops` pass.
+std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
+ return std::make_unique<LinalgGeneralizationPass>();
+}
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
new file mode 100644
index 000000000000..966024395ced
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s
+
+// linalg.conv is rewritten by the dedicated conv pattern. Its dilations
+// [2, 3] and strides [4, 5] are expected to be folded into the input indexing
+// map, with filter and input as the generic op inputs.
+func @generalize_conv(%input : memref<1x225x225x3xf32>, %filter: memref<3x3x3x32xf32>, %output: memref<1x112x112x32xf32>) {
+ linalg.conv(%filter, %input, %output) {dilations = [2, 3], strides = [4, 5]} : memref<3x3x3x32xf32>, memref<1x225x225x3xf32>, memref<1x112x112x32xf32>
+ return
+}
+
+// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
+// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 * 4 + d3 * 2, d2 * 5 + d4 * 3, d5)>
+// CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d6)>
+
+// CHECK: func @generalize_conv
+// CHECK-SAME: %[[INPUT:.+]]: memref<1x225x225x3xf32>
+// CHECK-SAME: %[[FILTER:.+]]: memref<3x3x3x32xf32>
+// CHECK-SAME: %[[OUTPUT:.+]]: memref<1x112x112x32xf32>
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[FILTER_MAP]], #[[INPUT_MAP]], #[[OUTPUT_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "window", "window", "reduction", "parallel"]
+// CHECK-SAME: ins(%[[FILTER]], %[[INPUT]]
+// CHECK-SAME: outs(%[[OUTPUT]]
+
+// CHECK: ^{{.*}}(%[[FILTER_ARG:.+]]: f32, %[[INPUT_ARG:.+]]: f32, %[[OUTPUT_ARG:.+]]: f32)
+// CHECK: %[[MUL:.+]] = mulf %[[FILTER_ARG]], %[[INPUT_ARG]]
+// CHECK: %[[ADD:.+]] = addf %[[MUL]], %[[OUTPUT_ARG]]
+// CHECK: linalg.yield %[[ADD]]
+
+// -----
+
+// linalg.matmul on buffers is generalized by the catch-all pattern, whose
+// body comes from the op's region builder.
+func @generalize_matmul_buffer(%A : memref<16x8xf32>, %B: memref<8x32xf32>, %C: memref<16x32xf32>) {
+ linalg.matmul ins(%A, %B: memref<16x8xf32>, memref<8x32xf32>) outs(%C: memref<16x32xf32>)
+ return
+}
+
+
+// CHECK: #[[A_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[B_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[C_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK: func @generalize_matmul_buffer
+// CHECK-SAME: %[[A:.+]]: memref<16x8xf32>
+// CHECK-SAME: %[[B:.+]]: memref<8x32xf32>
+// CHECK-SAME: %[[C:.+]]: memref<16x32xf32>
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[A_MAP]], #[[B_MAP]], #[[C_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME: ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK: %[[MUL:.+]] = mulf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+
+// -----
+
+// linalg.matmul on tensors: the generic op must keep the init tensor operand
+// and the tensor result type.
+func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+ %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>) init(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+ return %0: tensor<16x32xf32>
+}
+
+// CHECK: func @generalize_matmul_tensor
+
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<16x8xf32>, tensor<8x32xf32>)
+// CHECK-SAME: init(%{{.+}} : tensor<16x32xf32>)
+
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index e7e5ef8901b8..45dc115e6c1e 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1522,6 +1522,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
ArrayAttr iterator_types();
ArrayAttr indexing_maps();
static void regionBuilder(Block &block);
+ static std::function<void(Block &)> getRegionBuilder() {{ return regionBuilder; }
// Generic methods.
static unsigned getNumRegionArgs() {{ return {4}; }
More information about the Mlir-commits
mailing list