[Mlir-commits] [mlir] 67e0d58 - [MLIR][LinAlg] Start detensoring implementation.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 22 23:28:10 PST 2021


Author: KareemErgawy-TomTom
Date: 2021-02-23T08:27:58+01:00
New Revision: 67e0d58de4d338add132838810db70218f1064d8

URL: https://github.com/llvm/llvm-project/commit/67e0d58de4d338add132838810db70218f1064d8
DIFF: https://github.com/llvm/llvm-project/commit/67e0d58de4d338add132838810db70218f1064d8.diff

LOG: [MLIR][LinAlg] Start detensoring implementation.

This commit is the first baby step towards detensoring in
linalg-on-tensors.

Detensoring is the process through which a tensor value is converted to one
or potentially more primitive value(s). During this process, operations with
such detensored operands are also converted to an equivalent form that works
on primitives.

The detensoring process is driven by linalg-on-tensor ops. In particular, a
linalg-on-tensor op is checked to see whether *all* its operands can be
detensored. If so, those operands are converted to their primitive
counterparts and the linalg op is replaced by an equivalent op that takes
those new primitive values as operands.

This works towards handling github/google/iree#1159.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D96271

Added: 
    mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
    mlir/test/Dialect/Linalg/detensorized_0d.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 5d68328acc7e..7d93dd00d86a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -59,6 +59,10 @@ void populateElementwiseToLinalgConversionPatterns(
 /// operations.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
 
+/// Create a pass to convert Linalg operations to equivalent operations that
+/// work on primitive types, if possible. Currently, only `linalg.generic` ops
+/// whose operands are all 0-D tensors are converted: their body is inlined so
+/// that it operates directly on the tensors' scalar element values.
+std::unique_ptr<Pass> createLinalgDetensorizePass();
+
 /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
 /// producer (consumer) generic operation by expanding the dimensionality of the
 /// loop in the generic op.

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index a20289af3054..e51d08d3770d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -136,4 +136,28 @@ def LinalgGeneralization : FunctionPass<"linalg-generalize-named-ops"> {
   let dependentDialects = ["linalg::LinalgDialect"];
 }
 
+def LinalgDetensorize : FunctionPass<"linalg-detensorize"> {
+  let summary = "Detensorize linalg ops";
+  let constructor = "mlir::createLinalgDetensorizePass()";
+  let dependentDialects = [];
+
+  let description = [{
+    Detensoring is the process through which a tensor value is converted to one
+    or potentially more primitive value(s). During this process, operations with
+    such detensored operands are also converted to an equivalent form that works
+    on primitives.
+
+    The detensoring process is driven by linalg-on-tensor ops. In particular, a
+    linalg-on-tensor op is checked to see whether *all* its operands can be
+    detensored. If so, those operands are converted to their primitive
+    counterparts and the linalg op is replaced by an equivalent op that takes
+    those new primitive values as operands. Therefore, the detensoring process
+    can be divided into 2 main logical phases:
+
+    1. Detect/match an op that can be detensored.
+    2. Detensor the operands of the op and replace it with a primitive
+       equivalent.
+
+    NOTE: Currently only 0-D tensors can be detensored, and only
+    `linalg.generic` ops are rewritten (by inlining their body).
+  }];
+}
+
 #endif // MLIR_DIALECT_LINALG_PASSES

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index d988e245c9f7..1469371e1466 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRLinalgTransforms
   Bufferize.cpp
   CodegenStrategy.cpp
+  Detensorize.cpp
   DropUnitDims.cpp
   ElementwiseToLinalg.cpp
   Fusion.cpp

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
new file mode 100644
index 000000000000..2e2e3b94a34a
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -0,0 +1,173 @@
+//===- Detensorize.cpp - Linalg detensorization pass ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <iterator>
+#include <memory>
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Defines the criteria a TensorType must follow in order to be considered
+/// "detensorable".
+///
+/// NOTE: For now, only 0-D tensors are supported.
+///
+/// Returns true if tensorType can be detensored.
+bool canBeDetensored(TensorType tensorType) {
+  return tensorType.hasRank() && tensorType.getRank() == 0;
+}
+
+/// A conversion pattern for detensoring `linalg.generic` ops.
+///
+/// The op's region is inlined at the op's position: the region's block
+/// arguments are replaced by the converted (detensored) operands, the op's
+/// results are replaced by the operands of the region's `linalg.yield`, and
+/// the yield itself is then erased.
+class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(GenericOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Block *originalBlock = op->getBlock();
+
+    // Gather some information about the op before inlining its region.
+    Block *opEntryBlock = &*op.region().begin();
+    // The region of a linalg.generic is terminated by a linalg.yield; use the
+    // asserting `cast` instead of an unchecked `dyn_cast`.
+    YieldOp yieldOp = cast<YieldOp>(op.region().back().getTerminator());
+
+    // Split the op's block right before the op. This way, we have a clear
+    // insertion point in which the op's region can be inlined.
+    Block *newBlock = originalBlock->splitBlock(op);
+    rewriter.inlineRegionBefore(op.region(), newBlock);
+    // Now that op's region is inlined, the operands of its YieldOp are mapped
+    // to the materialized target values. Therefore, we can replace the op's
+    // uses with those of its YieldOp's operands.
+    rewriter.replaceOp(op, yieldOp->getOperands());
+
+    // No need for these intermediate blocks, merge them into 1. The region's
+    // block arguments are replaced by the (converted) operands of the op.
+    rewriter.mergeBlocks(opEntryBlock, originalBlock, operands);
+    rewriter.mergeBlocks(newBlock, originalBlock, {});
+
+    // The yield's values have already been forwarded to the op's users.
+    rewriter.eraseOp(yieldOp);
+
+    return success();
+  }
+};
+
+/// Type converter mapping detensorable tensor types to their element type;
+/// every other type is left unchanged.
+class DetensorizeTypeConverter : public TypeConverter {
+public:
+  DetensorizeTypeConverter() {
+    addConversion([](Type type) { return type; });
+
+    // A TensorType that can be detensored, is converted to the underlying
+    // element type.
+    addConversion([](TensorType tensorType) -> Type {
+      if (canBeDetensored(tensorType))
+        return tensorType.getElementType();
+
+      return tensorType;
+    });
+
+    // A tensor value is detensored by extracting its element(s).
+    addTargetMaterialization([](OpBuilder &builder, Type type,
+                                ValueRange inputs, Location loc) -> Value {
+      assert(inputs.size() == 1 && "expected a single tensor to detensor");
+      return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
+    });
+
+    // A detensored value is converted back by creating a new tensor from its
+    // element(s).
+    addSourceMaterialization([](OpBuilder &builder, Type type,
+                                ValueRange inputs, Location loc) -> Value {
+      assert(inputs.size() == 1 && "expected a single detensored value");
+      auto createNewTensorOp = builder.create<tensor::FromElementsOp>(
+          loc, inputs[0].getType(), inputs[0]);
+
+      // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to
+      // a tensor<dtype> instead.
+      return builder.create<linalg::TensorReshapeOp>(
+          loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{});
+    });
+  }
+};
+
+/// Canonicalizes the pattern of the form
+///
+/// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
+/// %reshaped_tensor = linalg.tensor_reshape %tensor [] : tensor<1xi32> into
+///   tensor<i32>
+/// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32>
+///
+/// to just %element.
+struct ExtractFromReshapeFromElements
+    : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
+                                PatternRewriter &rewriter) const final {
+    // Only index-free (0-D) extracts are handled.
+    if (!extract.indices().empty())
+      return failure();
+
+    auto tensorReshape = extract.tensor().getDefiningOp<TensorReshapeOp>();
+    if (!tensorReshape)
+      return failure();
+
+    auto tensorFromElements =
+        tensorReshape.getOperand()
+            .getDefiningOp<mlir::tensor::FromElementsOp>();
+    // Only a single-element tensor can be folded to its element.
+    if (!tensorFromElements || tensorFromElements.getNumOperands() != 1)
+      return failure();
+
+    rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
+    return success();
+  }
+};
+
+/// @see LinalgDetensorize in Linalg/Passes.td for more details.
+struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
+  void runOnFunction() override {
+    auto *context = &getContext();
+    DetensorizeTypeConverter typeConverter;
+    OwningRewritePatternList patterns;
+    ConversionTarget target(*context);
+
+    target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
+    target.addLegalDialect<linalg::LinalgDialect>();
+    target.addDynamicallyLegalOp<GenericOp>([](GenericOp op) {
+      // If any of the operands or results cannot be detensored, the op is
+      // considered legal and won't be detensored. Operands that are not
+      // tensors at all (e.g. memrefs) can never be detensored.
+      return llvm::any_of(
+          op.getShapedOperandTypes(), [](ShapedType shapedType) {
+            auto tensorType = shapedType.dyn_cast<TensorType>();
+            return !tensorType || !canBeDetensored(tensorType);
+          });
+    });
+
+    patterns.insert<DetensorizeGenericOp>(typeConverter, context);
+
+    // Bail out on conversion failure: the cleanup below must not run on
+    // partially converted IR.
+    if (failed(applyPartialConversion(getFunction(), target,
+                                      std::move(patterns)))) {
+      signalPassFailure();
+      return;
+    }
+
+    // Fold the extract(reshape(from_elements(x))) chains left behind by the
+    // materializations back to `x`.
+    OwningRewritePatternList canonPatterns;
+    canonPatterns.insert<ExtractFromReshapeFromElements>(context);
+    if (failed(applyPatternsAndFoldGreedily(getFunction(),
+                                            std::move(canonPatterns))))
+      signalPassFailure();
+
+    // TODO Properly handle control flow within function boundaries.
+  }
+};
+} // namespace
+
+/// Returns a new instance of the `linalg-detensorize` function pass
+/// (LinalgDetensorize).
+std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
+  return std::make_unique<LinalgDetensorize>();
+}

diff  --git a/mlir/test/Dialect/Linalg/detensorized_0d.mlir b/mlir/test/Dialect/Linalg/detensorized_0d.mlir
new file mode 100644
index 000000000000..e35a34fd157d
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorized_0d.mlir
@@ -0,0 +1,107 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
+
+#map = affine_map<() -> ()>
+
+// A single 0-D generic whose body is one op: the body is inlined and computes
+// on the scalars extracted from the operands; the result is turned back into a
+// tensor only for the return.
+func @detensor_simple(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
+  %0 = linalg.init_tensor [] : tensor<f32>
+  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+    ins(%arg1, %arg2 : tensor<f32>, tensor<f32>)
+    outs(%0 : tensor<f32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %2 = addf %arg3, %arg4 : f32
+    linalg.yield %2 : f32
+  } -> tensor<f32>
+  return %1: tensor<f32>
+}
+// CHECK-LABEL: func @detensor_simple
+// CHECK-SAME:    (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
+// CHECK-DAG:     %[[arg1_val:.*]] = tensor.extract %[[arg1]]
+// CHECK-DAG:     %[[arg2_val:.*]] = tensor.extract %[[arg2]]
+// CHECK:         %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
+// CHECK:         %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
+// CHECK:         %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
+// CHECK:         return %[[reshaped_tensor_res]]
+
+// Three chained 0-D generics, where later ops consume earlier results: the
+// intermediate values stay scalars and only the final result is re-wrapped
+// into a tensor.
+func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
+  %0 = linalg.init_tensor [] : tensor<f32>
+  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+    ins(%arg1, %arg2 : tensor<f32>, tensor<f32>)
+    outs(%0 : tensor<f32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %2 = addf %arg3, %arg4 : f32
+    linalg.yield %2 : f32
+  } -> tensor<f32>
+
+  %3 = linalg.init_tensor [] : tensor<f32>
+  %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+    ins(%arg1, %1 : tensor<f32>, tensor<f32>)
+    outs(%3 : tensor<f32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %5 = mulf %arg3, %arg4 : f32
+    linalg.yield %5 : f32
+  } -> tensor<f32>
+
+  %6 = linalg.init_tensor [] : tensor<f32>
+  %7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+    ins(%1, %4 : tensor<f32>, tensor<f32>)
+    outs(%6 : tensor<f32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %5 = divf %arg3, %arg4 : f32
+    linalg.yield %5 : f32
+  } -> tensor<f32>
+
+  return %7: tensor<f32>
+}
+// CHECK-LABEL: func @detensor_op_sequence
+// CHECK-SAME:    (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
+// CHECK-DAG:     %[[arg1_val:.*]] = tensor.extract %[[arg1]]
+// CHECK-DAG:     %[[arg2_val:.*]] = tensor.extract %[[arg2]]
+// CHECK:         %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
+// CHECK-DAG:     %[[arg1_val2:.*]] = tensor.extract %[[arg1]]
+// CHECK:         %[[detensored_res2:.*]] = mulf %[[arg1_val2]], %[[detensored_res]]
+// CHECK:         %[[detensored_res3:.*]] = divf %[[detensored_res]], %[[detensored_res2]]
+// CHECK:         %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]]
+// CHECK:         %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
+// CHECK:         return %[[reshaped_tensor_res]]
+
+// A single generic whose body contains more than one op: every op of the body
+// is inlined, in order, on the extracted scalars.
+func @detensor_multiple_ops(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
+  %0 = linalg.init_tensor [] : tensor<f32>
+  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+    ins(%arg1, %arg2 : tensor<f32>, tensor<f32>)
+    outs(%0 : tensor<f32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %2 = addf %arg3, %arg4 : f32
+    %3 = mulf %2, %arg4 : f32
+    linalg.yield %3 : f32
+  } -> tensor<f32>
+  return %1: tensor<f32>
+}
+// CHECK-LABEL: func @detensor_multiple_ops
+// CHECK-SAME:    (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
+// CHECK-DAG:     %[[arg1_val:.*]] = tensor.extract %[[arg1]]
+// CHECK-DAG:     %[[arg2_val:.*]] = tensor.extract %[[arg2]]
+// CHECK:         %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
+// CHECK:         %[[detensored_res2:.*]] = mulf %[[detensored_res]], %[[arg2_val]]
+// CHECK:         %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res2]]
+// CHECK:         %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
+// CHECK:         return %[[reshaped_tensor_res]]
+
+// The body uses an op from an unregistered dialect (hence the
+// -allow-unregistered-dialect flag in the RUN line); it is inlined like any
+// other op.
+func @detensor_foreign_op(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
+  %0 = linalg.init_tensor [] : tensor<f32>
+  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+    ins(%arg1, %arg2 : tensor<f32>, tensor<f32>)
+    outs(%0 : tensor<f32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %2 = "foreign.do_something"(%arg3, %arg4) {} : (f32, f32) -> f32
+    linalg.yield %2 : f32
+  } -> tensor<f32>
+  return %1: tensor<f32>
+}
+// CHECK-LABEL: func @detensor_foreign_op
+// CHECK-SAME:    (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
+// CHECK-DAG:     %[[arg1_val:.*]] = tensor.extract %[[arg1]]
+// CHECK-DAG:     %[[arg2_val:.*]] = tensor.extract %[[arg2]]
+// CHECK:         %[[detensored_res:.*]] = "foreign.do_something"(%[[arg1_val]], %[[arg2_val]])
+// CHECK:         %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
+// CHECK:         %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
+// CHECK:         return %[[reshaped_tensor_res]]


        


More information about the Mlir-commits mailing list