[Mlir-commits] [mlir] caccddc - [MLIR][TOSA] Lower tosa.reshape to linalg.reshape

Rob Suderman llvmlistbot at llvm.org
Fri Feb 26 13:03:54 PST 2021


Author: Rob Suderman
Date: 2021-02-26T12:57:57-08:00
New Revision: caccddc52a33b246d6b44143b0e8c60cc908a3ab

URL: https://github.com/llvm/llvm-project/commit/caccddc52a33b246d6b44143b0e8c60cc908a3ab
DIFF: https://github.com/llvm/llvm-project/commit/caccddc52a33b246d6b44143b0e8c60cc908a3ab.diff

LOG: [MLIR][TOSA] Lower tosa.reshape to linalg.reshape

Lowering from the tosa.reshape op to linalg.reshape. For same-rank or
non-collapsed/expanded cases two linalg.reshapes are inserted.

Differential Revision: https://reviews.llvm.org/D97439

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 8e096e48d2d3..75bbc46c4804 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -16,8 +16,11 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+#include <numeric>
+
 using namespace mlir;
 
 static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
@@ -339,6 +342,106 @@ class PointwiseConverter : public OpRewritePattern<SrcOp> {
   }
 };
 
+/// Lowers a statically shaped tosa.reshape to linalg.tensor_reshape: a single
+/// reshape when consecutive dims of the higher-rank shape fold onto the other
+/// shape, otherwise a collapse to 1-D followed by an expand to the result.
+class ReshapeOpConverter : public OpConversionPattern<tosa::ReshapeOp> {
+public:
+  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tosa::ReshapeOp reshape, ArrayRef<Value> args,
+                  ConversionPatternRewriter &rewriter) const final {
+    tosa::ReshapeOp::Adaptor operands(args);
+
+    ShapedType operandTy = operands.input1().getType().cast<ShapedType>();
+    ShapedType resultTy = reshape.getType().cast<ShapedType>();
+    if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
+      return failure();
+
+    // The higher-rank shape is "expanded"; on equal rank that is the result.
+    bool operandIsExpanded = operandTy.getRank() > resultTy.getRank();
+    ArrayRef<int64_t> expandedShape =
+        (operandIsExpanded ? operandTy : resultTy).getShape();
+    ArrayRef<int64_t> collapsedShape =
+        (operandIsExpanded ? resultTy : operandTy).getShape();
+    unsigned currSrcDim = 0, currDstDim = 0;
+    SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
+        collapsedShape.size());
+
+    // First scan all dimensions in the source shapes to see whether we have a
+    // perfect case where consecutive dimensions in source are collapsed. For
+    // such case we can just generate one single linalg.reshape.
+    bool isCollapsingSource = true;
+    while (currSrcDim < expandedShape.size() &&
+           currDstDim < collapsedShape.size()) {
+      int64_t dstSize = collapsedShape[currDstDim];
+      int64_t srcSize = expandedShape[currSrcDim];
+      // Only fold in another dimension while one remains: the bound is checked
+      // before the increment so expandedShape is never indexed past its end.
+      while (srcSize < dstSize && currSrcDim + 1 < expandedShape.size()) {
+        reassociationMap[currDstDim].push_back(
+            rewriter.getAffineDimExpr(currSrcDim++));
+        srcSize *= expandedShape[currSrcDim];
+      }
+      if (srcSize != dstSize) {
+        isCollapsingSource = false;
+        break;
+      }
+      reassociationMap[currDstDim].push_back(
+          rewriter.getAffineDimExpr(currSrcDim++));
+      // If the next dim in collapsedShape is not 1, treat subsequent dims in
+      // expandedShape which are 1 to be collapsed.
+      if (currDstDim == collapsedShape.size() - 1 ||
+          collapsedShape[currDstDim + 1] != 1) {
+        while (currSrcDim < expandedShape.size() &&
+               expandedShape[currSrcDim] == 1) {
+          reassociationMap[currDstDim].push_back(
+              rewriter.getAffineDimExpr(currSrcDim++));
+        }
+      }
+      currDstDim++;
+    }
+    if (currSrcDim != expandedShape.size() ||
+        currDstDim != collapsedShape.size())
+      isCollapsingSource = false;
+
+    // Otherwise, we need to first reduce all source dimensions into one and
+    // then expand to the destination dimensions.
+    if (!isCollapsingSource) {
+      auto getIdentityExprs = [&rewriter](int n) {
+        SmallVector<AffineExpr, 4> exprs;
+        for (int i = 0; i < n; ++i)
+          exprs.push_back(rewriter.getAffineDimExpr(i));
+        return exprs;
+      };
+      Location loc = reshape.getLoc();
+      // Seed the product with an int64_t: a plain `1` would make accumulate
+      // carry the running product in int and truncate large element counts.
+      int64_t totalElems =
+          std::accumulate(expandedShape.begin(), expandedShape.end(),
+                          int64_t{1}, std::multiplies<int64_t>());
+      // One group covering every operand dimension collapses it to 1-D.
+      SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
+          getIdentityExprs(operandTy.getShape().size())};
+      // One group covering every result dimension expands the 1-D value.
+      SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
+          getIdentityExprs(resultTy.getShape().size())};
+      auto collapsedTy =
+          RankedTensorType::get({totalElems}, operandTy.getElementType());
+      Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
+          loc, collapsedTy, operands.input1(), collapsingMap);
+      rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
+          reshape, resultTy, collapsedOp, expandingMap);
+      return success();
+    }
+
+    rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
+        reshape, resultTy, operands.input1(), reassociationMap);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -358,6 +461,6 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       PointwiseConverter<tosa::GreaterEqualOp>,
       PointwiseConverter<tosa::MaximumOp>, PointwiseConverter<tosa::MinimumOp>,
       PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
-      PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>>(
-      context);
+      PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>,
+      ReshapeOpConverter>(context);
 }

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 022421459d16..985aec3d212b 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -258,3 +258,49 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   return
 }
 
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: @test_reshape_downrank
+func @test_reshape_downrank(%arg0: tensor<2x3xf32>) -> tensor<6xf32> { // Rank-reducing case: a single tensor_reshape is expected.
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [6]} : (tensor<2x3xf32>) -> tensor<6xf32>
+  // CHECK: return [[RESHAPE]]
+  return %0 : tensor<6xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: @test_reshape_uprank
+func @test_reshape_uprank(%arg0: tensor<6xf32>) -> tensor<2x3xf32> { // Rank-increasing case: a single tensor_reshape is expected.
+  // CHECK: [[RESHAPE:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<6xf32>) -> tensor<2x3xf32>
+  // CHECK: return [[RESHAPE]]
+  return %0 : tensor<2x3xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: @test_reshape_samerank
+func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> { // Equal-rank case: two chained tensor_reshape ops are expected.
+  // CHECK: [[RESHAPE1:%.+]] = linalg.tensor_reshape %arg0 [#[[$MAP0]]]
+  // CHECK: [[RESHAPE2:%.+]] = linalg.tensor_reshape [[RESHAPE1]] [#[[$MAP0]]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<3x2xf32>) -> tensor<2x3xf32>
+  // CHECK: return [[RESHAPE2]]
+  return %0 : tensor<2x3xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+
+// CHECK-LABEL: @test_reshape_downrank_6D
+func @test_reshape_downrank_6D(%arg0: tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32> { // 6-D to 3-D case: one tensor_reshape with three reassociation groups is expected.
+  // CHECK: linalg.tensor_reshape %arg0 [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [6, 5, 77]} : (tensor<1x2x3x5x7x11xf32>) -> tensor<6x5x77xf32>
+  return %0 : tensor<6x5x77xf32>
+}


        


More information about the Mlir-commits mailing list