[Mlir-commits] [mlir] 53a0d45 - [mlir] Add pass to convert elementwise ops to linalg.

Sean Silva llvmlistbot at llvm.org
Tue Nov 10 13:46:34 PST 2020


Author: Sean Silva
Date: 2020-11-10T13:44:44-08:00
New Revision: 53a0d45db6d0f33dfbb724c99ce2560ae25473c2

URL: https://github.com/llvm/llvm-project/commit/53a0d45db6d0f33dfbb724c99ce2560ae25473c2
DIFF: https://github.com/llvm/llvm-project/commit/53a0d45db6d0f33dfbb724c99ce2560ae25473c2.diff

LOG: [mlir] Add pass to convert elementwise ops to linalg.

This patch converts elementwise ops on tensors to linalg.generic ops
with the same elementwise op in the payload (except rewritten to
operate on scalars, obviously). This is a great form for later fusion to
clean up.

E.g.

```
// Compute: %arg0 + %arg1 - %arg2
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
  %0 = addf %arg0, %arg1 : tensor<?xf32>
  %1 = subf %0, %arg2 : tensor<?xf32>
  return %1 : tensor<?xf32>
}
```

Running this through
`mlir-opt -convert-elementwise-to-linalg -linalg-fusion-for-tensor-ops` we get:

```
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
  %0 = linalg.generic {indexing_maps = [#map0, #map0, #map0, #map0], iterator_types = ["parallel"]} ins(%arg0, %arg1, %arg2 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
    %1 = addf %arg3, %arg4 : f32
    %2 = subf %1, %arg5 : f32
    linalg.yield %2 : f32
  } -> tensor<?xf32>
  return %0 : tensor<?xf32>
}
```

So the elementwise ops on tensors have nicely collapsed into a single
linalg.generic, which is the form we want for further transformations.

Differential Revision: https://reviews.llvm.org/D90354

Added: 
    mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
    mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 24570d3c4ec6..50aec73366a8 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -16,6 +16,8 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+std::unique_ptr<OperationPass<FuncOp>> createConvertElementwiseToLinalgPass();
+
 std::unique_ptr<OperationPass<FuncOp>> createLinalgFoldUnitExtentDimsPass();
 
 std::unique_ptr<Pass> createLinalgFusionOfTensorOpsPass();
@@ -48,6 +50,11 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToAffineLoopsPass();
 /// buffers instead.
 std::unique_ptr<OperationPass<ModuleOp>> createLinalgBufferizePass();
 
+/// Populate patterns that convert `ElementwiseMappable` ops to linalg
+/// parallel loops.
+void populateElementwiseToLinalgConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
 /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
 /// producer (consumer) generic operation by expanding the dimensionality of the
 /// loop in the generic op.

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 7446ca8f6636..9162543a310c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -11,6 +11,17 @@
 
 include "mlir/Pass/PassBase.td"
 
+def ConvertElementwiseToLinalg : FunctionPass<"convert-elementwise-to-linalg"> {
+  let summary = "Convert ElementwiseMappable ops to linalg";
+  let description = [{
+    Convert ops with the `ElementwiseMappable` trait to linalg parallel loops.
+
+    This pass only converts ops that operate on ranked tensors.
+  }];
+  let constructor = "mlir::createConvertElementwiseToLinalgPass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
   let summary = "Remove unit-extent dimension in Linalg ops on tensors";
   let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
new file mode 100644
index 000000000000..d26b8f75cf28
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @main() {
+  %a = constant dense<[1.0, 2.0, 3.0]> : tensor<3xf32>
+  %b = constant dense<[10.0, 20.0, 30.0]> : tensor<3xf32>
+
+  %addf = addf %a, %b : tensor<3xf32>
+  %addf_unranked = tensor_cast %addf : tensor<3xf32> to tensor<*xf32>
+  call @print_memref_f32(%addf_unranked) : (tensor<*xf32>) -> ()
+  // CHECK: Unranked Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [3] strides = [1] data =
+  // CHECK-NEXT: [11,  22,  33]
+
+  return
+}
+
+func @print_memref_f32(%ptr : tensor<*xf32>)

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 88242c1d6f28..73df73e83d82 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Bufferize.cpp
   CodegenStrategy.cpp
   DropUnitDims.cpp
+  ElementwiseToLinalg.cpp
   Fusion.cpp
   FusionOnTensors.cpp
   Hoisting.cpp

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
new file mode 100644
index 000000000000..a0e5d74e1767
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -0,0 +1,98 @@
+//===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
+  if (!op->hasTrait<OpTrait::ElementwiseMappable>())
+    return false;
+
+  // TODO: The conversion pattern can be made to work for `any_of` here, but
+  // it's more complex as it requires tracking which operands are scalars.
+  return llvm::all_of(op->getOperandTypes(),
+                      [](Type type) { return type.isa<RankedTensorType>(); });
+}
+
+namespace {
+struct ConvertStdElementwiseOpOnRankedTensors : public RewritePattern {
+  ConvertStdElementwiseOpOnRankedTensors()
+      : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const final {
+    if (!isElementwiseMappableOpOnRankedTensors(op))
+      return rewriter.notifyMatchFailure(
+          op, "requires elementwise op on ranked tensors");
+
+    auto rank = op->getResult(0).getType().cast<RankedTensorType>().getRank();
+    SmallVector<AffineMap, 3> indexingMaps(
+        op->getNumResults() + op->getNumOperands(),
+        rewriter.getMultiDimIdentityMap(rank));
+    SmallVector<StringRef, 6> iteratorTypes(rank,
+                                            getParallelIteratorTypeName());
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        op, /*resultTensorTypes=*/op->getResultTypes(),
+        /*inputs=*/op->getOperands(),
+        /*outputBuffers=*/ValueRange(),
+        /*initTensors=*/ValueRange(),
+        /*indexingMaps=*/indexingMaps,
+        /*iteratorTypes=*/iteratorTypes,
+        /*bodyBuilder=*/
+        [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
+          OperationState state(loc, op->getName());
+          state.addAttributes(op->getAttrs());
+          state.addOperands(regionArgs);
+          auto resultTypes = llvm::to_vector<6>(
+              llvm::map_range(op->getResultTypes(), [](Type type) {
+                return type.cast<TensorType>().getElementType();
+              }));
+          state.addTypes(resultTypes);
+          auto *scalarOp = builder.createOperation(state);
+          builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
+        });
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateElementwiseToLinalgConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *) {
+  patterns.insert<ConvertStdElementwiseOpOnRankedTensors>();
+}
+
+namespace {
+class ConvertElementwiseToLinalgPass
+    : public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> {
+
+  void runOnFunction() final {
+    auto func = getOperation();
+    auto *context = &getContext();
+    ConversionTarget target(*context);
+    OwningRewritePatternList patterns;
+
+    populateElementwiseToLinalgConversionPatterns(patterns, context);
+    target.markUnknownOpDynamicallyLegal([](Operation *op) {
+      return !isElementwiseMappableOpOnRankedTensors(op);
+    });
+
+    if (failed(applyPartialConversion(func, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createConvertElementwiseToLinalgPass() {
+  return std::make_unique<ConvertElementwiseToLinalgPass>();
+}

diff  --git a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
new file mode 100644
index 000000000000..7ea78fef7add
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt -convert-elementwise-to-linalg -split-input-file %s | FileCheck %s
+
+// In-depth checking of the linalg.generic op for a very trivial case.
+// CHECK: #map = affine_map<() -> ()>
+// CHECK-LABEL:   func @addf_rank0
+func @addf_rank0(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+  // CHECK: %{{.*}} = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%{{.*}}, %{{.*}} : tensor<f32>, tensor<f32>) {
+  // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
+  // CHECK:   %[[YIELD:.*]] = addf %[[LHS]], %[[RHS]] : f32
+  // CHECK:   linalg.yield %[[YIELD]] : f32
+  // CHECK: } -> tensor<f32>
+  %0 = addf %arg0, %arg1 : tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// Check indexing maps and iterator types for the rank > 0 case.
+// CHECK: #map = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @addf_rank1
+func @addf_rank1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK: linalg.generic{{.*}}indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]
+  %0 = addf %arg0, %arg1 : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Check a unary op.
+// CHECK-LABEL: func @exp
+func @exp(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[SCALAR:.*]]: f32):
+  // CHECK:   %[[YIELD:.*]] = exp %[[SCALAR]] : f32
+  // CHECK:   linalg.yield %[[YIELD]] : f32
+  %0 = exp %arg0 : tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// Check a case with varying operand types.
+// CHECK-LABEL: func @select
+func @select(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i32> {
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[PRED:.*]]: i1, %[[TRUE_VAL:.*]]: i32, %[[FALSE_VAL:.*]]: i32):
+  // CHECK:   select %[[PRED]], %[[TRUE_VAL]], %[[FALSE_VAL]] : i32
+  %0 = select %arg0, %arg1, %arg2 : tensor<i1>, tensor<i32>
+  return %0 : tensor<i32>
+}
+
+// -----
+
+// Spot-check an op that requires copying attributes properly to the created scalar op.
+// CHECK-LABEL: func @cmpf(
+func @cmpf(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<i1> {
+  // CHECK: cmpf "olt", %{{.*}}, %{{.*}} : f32
+  %0 = cmpf "olt", %arg0, %arg1 : tensor<f32>
+  return %0 : tensor<i1>
+}


        


More information about the Mlir-commits mailing list