[Mlir-commits] [mlir] 53a0d45 - [mlir] Add pass to convert elementwise ops to linalg.
Sean Silva
llvmlistbot at llvm.org
Tue Nov 10 13:46:34 PST 2020
Author: Sean Silva
Date: 2020-11-10T13:44:44-08:00
New Revision: 53a0d45db6d0f33dfbb724c99ce2560ae25473c2
URL: https://github.com/llvm/llvm-project/commit/53a0d45db6d0f33dfbb724c99ce2560ae25473c2
DIFF: https://github.com/llvm/llvm-project/commit/53a0d45db6d0f33dfbb724c99ce2560ae25473c2.diff
LOG: [mlir] Add pass to convert elementwise ops to linalg.
This patch converts elementwise ops on tensors to linalg.generic ops
whose payload is the same elementwise op, rewritten to operate on
scalars. This form is well suited to later fusion, which can then
clean up chains of such ops.
E.g.
```
// Compute: %arg0 + %arg1 - %arg2
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
%0 = addf %arg0, %arg1 : tensor<?xf32>
%1 = subf %0, %arg2 : tensor<?xf32>
return %1 : tensor<?xf32>
}
```
Running this through
`mlir-opt -convert-elementwise-to-linalg -linalg-fusion-for-tensor-ops`, we get:
```
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
%0 = linalg.generic {indexing_maps = [#map0, #map0, #map0, #map0], iterator_types = ["parallel"]} ins(%arg0, %arg1, %arg2 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
%1 = addf %arg3, %arg4 : f32
%2 = subf %1, %arg5 : f32
linalg.yield %2 : f32
} -> tensor<?xf32>
return %0 : tensor<?xf32>
}
```
So the elementwise ops on tensors have nicely collapsed into a single
linalg.generic, which is the form we want for further transformations.
Differential Revision: https://reviews.llvm.org/D90354
Added:
mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 24570d3c4ec6..50aec73366a8 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -16,6 +16,10 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
+/// Creates a pass that rewrites ops with the `ElementwiseMappable` trait on
+/// ranked tensors into linalg.generic ops.
+std::unique_ptr<OperationPass<FuncOp>> createConvertElementwiseToLinalgPass();
+
std::unique_ptr<OperationPass<FuncOp>> createLinalgFoldUnitExtentDimsPass();
std::unique_ptr<Pass> createLinalgFusionOfTensorOpsPass();
@@ -48,6 +52,11 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToAffineLoopsPass();
/// buffers instead.
std::unique_ptr<OperationPass<ModuleOp>> createLinalgBufferizePass();
+/// Populate `patterns` with a pattern that rewrites `ElementwiseMappable` ops
+/// on ranked tensors into linalg.generic ops with all-parallel iterators.
+void populateElementwiseToLinalgConversionPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *ctx);
+
/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
/// producer (consumer) generic operation by expanding the dimensionality of the
/// loop in the generic op.
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 7446ca8f6636..9162543a310c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -11,6 +11,17 @@
include "mlir/Pass/PassBase.td"
+def ConvertElementwiseToLinalg : FunctionPass<"convert-elementwise-to-linalg"> {
+  let summary = "Convert ElementwiseMappable ops to linalg";
+  let description = [{
+    Convert `ElementwiseMappable` ops to all-parallel `linalg.generic` ops.
+
+    This pass only converts ops whose operands are all ranked tensors.
+  }];
+  let constructor = "mlir::createConvertElementwiseToLinalgPass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
let summary = "Remove unit-extent dimension in Linalg ops on tensors";
let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
new file mode 100644
index 000000000000..d26b8f75cf28
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @main() {
+  %lhs = constant dense<[1.0, 2.0, 3.0]> : tensor<3xf32>
+  %rhs = constant dense<[10.0, 20.0, 30.0]> : tensor<3xf32>
+  // Tensor addf that -convert-elementwise-to-linalg turns into linalg.generic.
+  %sum = addf %lhs, %rhs : tensor<3xf32>
+  %sum_unranked = tensor_cast %sum : tensor<3xf32> to tensor<*xf32>
+  call @print_memref_f32(%sum_unranked) : (tensor<*xf32>) -> ()
+  // CHECK: Unranked Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [3] strides = [1] data =
+  // CHECK-NEXT: [11, 22, 33]
+
+  return
+}
+
+func @print_memref_f32(%ptr : tensor<*xf32>)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 88242c1d6f28..73df73e83d82 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
Bufferize.cpp
CodegenStrategy.cpp
DropUnitDims.cpp
+ ElementwiseToLinalg.cpp
Fusion.cpp
FusionOnTensors.cpp
Hoisting.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
new file mode 100644
index 000000000000..a0e5d74e1767
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -0,0 +1,134 @@
+//===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a rewrite pattern and a function pass that turn ops
+// with the `ElementwiseMappable` trait on ranked tensors into `linalg.generic`
+// ops whose payload is the same op applied to scalars.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+/// Returns true if `op` has the `ElementwiseMappable` trait, has at least one
+/// operand and one result, and all of its operands and results are ranked
+/// tensors. These are exactly the ops rewritten by the pattern below.
+static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
+  if (!op->hasTrait<OpTrait::ElementwiseMappable>())
+    return false;
+
+  // `llvm::all_of` is vacuously true on an empty range, so reject ops without
+  // operands or results explicitly: the pattern maps over the operands and
+  // reads the iteration rank from result #0.
+  if (op->getNumOperands() == 0 || op->getNumResults() == 0)
+    return false;
+
+  auto isRankedTensor = [](Type type) { return type.isa<RankedTensorType>(); };
+
+  // TODO: The conversion pattern can be made to work for `any_of` here, but
+  // it's more complex as it requires tracking which operands are scalars.
+  // Results are checked too because the pattern casts each of them to a
+  // tensor type.
+  return llvm::all_of(op->getOperandTypes(), isRankedTensor) &&
+         llvm::all_of(op->getResultTypes(), isRankedTensor);
+}
+
+namespace {
+/// Rewrites an elementwise-mappable op on ranked tensors into a
+/// `linalg.generic` with identity indexing maps and all-parallel iterators.
+/// The generic's payload re-creates the original op (same name and
+/// attributes) on the scalar block arguments and yields its results.
+struct ConvertStdElementwiseOpOnRankedTensors : public RewritePattern {
+  ConvertStdElementwiseOpOnRankedTensors()
+      : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const final {
+    if (!isElementwiseMappableOpOnRankedTensors(op))
+      return rewriter.notifyMatchFailure(
+          op, "requires elementwise op on ranked tensors");
+
+    // Every operand and result is accessed through the same identity map
+    // over `rank` parallel dimensions, with `rank` taken from result #0.
+    auto rank = op->getResult(0).getType().cast<RankedTensorType>().getRank();
+    SmallVector<AffineMap, 3> indexingMaps(
+        op->getNumResults() + op->getNumOperands(),
+        rewriter.getMultiDimIdentityMap(rank));
+    SmallVector<StringRef, 6> iteratorTypes(rank,
+                                            getParallelIteratorTypeName());
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        op, /*resultTensorTypes=*/op->getResultTypes(),
+        /*inputs=*/op->getOperands(),
+        /*outputBuffers=*/ValueRange(),
+        /*initTensors=*/ValueRange(),
+        /*indexingMaps=*/indexingMaps,
+        /*iteratorTypes=*/iteratorTypes,
+        /*bodyBuilder=*/
+        [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
+          // Re-create `op` on the scalar block arguments, keeping its name
+          // and attributes (e.g. the predicate of a comparison) and using
+          // the element types of its tensor results as the result types.
+          OperationState state(loc, op->getName());
+          state.addAttributes(op->getAttrs());
+          state.addOperands(regionArgs);
+          auto resultTypes = llvm::to_vector<6>(
+              llvm::map_range(op->getResultTypes(), [](Type type) {
+                return type.cast<TensorType>().getElementType();
+              }));
+          state.addTypes(resultTypes);
+          auto *scalarOp = builder.createOperation(state);
+          builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
+        });
+    return success();
+  }
+};
+} // namespace
+
+/// Adds the pattern that rewrites `ElementwiseMappable` ops on ranked tensors
+/// into `linalg.generic` ops to `patterns`. The context is currently unused.
+void mlir::populateElementwiseToLinalgConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *) {
+  patterns.insert<ConvertStdElementwiseOpOnRankedTensors>();
+}
+
+namespace {
+/// Function pass that applies the elementwise-to-linalg pattern as a partial
+/// conversion over the function.
+class ConvertElementwiseToLinalgPass
+    : public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> {
+
+  void runOnFunction() final {
+    auto func = getOperation();
+    auto *context = &getContext();
+    ConversionTarget target(*context);
+    OwningRewritePatternList patterns;
+
+    populateElementwiseToLinalgConversionPatterns(patterns, context);
+    // Ops the pattern would rewrite are illegal; every other op is legal and
+    // is left untouched by the partial conversion.
+    target.markUnknownOpDynamicallyLegal([](Operation *op) {
+      return !isElementwiseMappableOpOnRankedTensors(op);
+    });
+
+    if (failed(applyPartialConversion(func, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Creates the `convert-elementwise-to-linalg` function pass.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createConvertElementwiseToLinalgPass() {
+  return std::make_unique<ConvertElementwiseToLinalgPass>();
+}
diff --git a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
new file mode 100644
index 000000000000..7ea78fef7add
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt -convert-elementwise-to-linalg -split-input-file %s | FileCheck %s
+
+// In-depth checking of the linalg.generic op for a very trivial case.
+// CHECK: #map = affine_map<() -> ()>
+// CHECK-LABEL: func @addf_rank0
+func @addf_rank0(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+ // CHECK: %{{.*}} = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%{{.*}}, %{{.*}} : tensor<f32>, tensor<f32>) {
+ // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+ // CHECK: %[[YIELD:.*]] = addf %[[LHS]], %[[RHS]] : f32
+ // CHECK: linalg.yield %[[YIELD]] : f32
+ // CHECK: } -> tensor<f32>
+ %0 = addf %arg0, %arg1 : tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+// Check indexing maps and iterator types for the rank > 0 case.
+// CHECK: #map = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @addf_rank1
+func @addf_rank1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+ // CHECK: linalg.generic{{.*}}indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]
+ %0 = addf %arg0, %arg1 : tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Check a unary op.
+// CHECK-LABEL: func @exp
+func @exp(%arg0: tensor<f32>) -> tensor<f32> {
+ // CHECK: linalg.generic
+ // CHECK: ^bb0(%[[SCALAR:.*]]: f32):
+ // CHECK: %[[YIELD:.*]] = exp %[[SCALAR]] : f32
+ // CHECK: linalg.yield %[[YIELD]] : f32
+ %0 = exp %arg0 : tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+// Check a case with varying operand types.
+// CHECK-LABEL: func @select
+func @select(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i32> {
+ // CHECK: linalg.generic
+ // CHECK: ^bb0(%[[PRED:.*]]: i1, %[[TRUE_VAL:.*]]: i32, %[[FALSE_VAL:.*]]: i32):
+ // CHECK: select %[[PRED]], %[[TRUE_VAL]], %[[FALSE_VAL]] : i32
+ %0 = select %arg0, %arg1, %arg2 : tensor<i1>, tensor<i32>
+ return %0 : tensor<i32>
+}
+
+// -----
+
+// Spot-check an op that requires copying attributes properly to the created scalar op.
+// CHECK-LABEL: func @cmpf(
+func @cmpf(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<i1> {
+ // CHECK: cmpf "olt", %{{.*}}, %{{.*}} : f32
+ %0 = cmpf "olt", %arg0, %arg1 : tensor<f32>
+ return %0 : tensor<i1>
+}
+
+// -----
+
+// Ops on unranked tensors are not converted: the pattern only applies when all
+// operands are ranked tensors.
+// CHECK-LABEL: func @addf_unranked
+func @addf_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
+ // CHECK-NOT: linalg.generic
+ // CHECK: addf %{{.*}}, %{{.*}} : tensor<*xf32>
+ %0 = addf %arg0, %arg1 : tensor<*xf32>
+ return %0 : tensor<*xf32>
+}
More information about the Mlir-commits
mailing list