[Mlir-commits] [mlir] 087bc20 - [MLIR][TOSA] Lower tosa.transpose to linalg.generic

Rob Suderman llvmlistbot at llvm.org
Mon Mar 1 11:10:17 PST 2021


Author: Rob Suderman
Date: 2021-03-01T11:09:49-08:00
New Revision: 087bc20fe42f2619fca76818900f73dd2c4a5b94

URL: https://github.com/llvm/llvm-project/commit/087bc20fe42f2619fca76818900f73dd2c4a5b94
DIFF: https://github.com/llvm/llvm-project/commit/087bc20fe42f2619fca76818900f73dd2c4a5b94.diff

LOG: [MLIR][TOSA] Lower tosa.transpose to linalg.generic

Lowers the tosa.transpose operation to a linalg.generic op when the
permutations operand is a constant value.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D97508

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f6cf7ed709a3..ece9380845b6 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -438,6 +439,65 @@ class ReshapeOpConverter : public OpConversionPattern<tosa::ReshapeOp> {
 
     rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
         reshape, resultTy, args[0], reassociationMap);
+
+    return success();
+  }
+};
+
+/// Lowers tosa.transpose to a linalg.generic whose input indexing map encodes
+/// the permutation and whose body simply yields the input element.
+///
+/// The pattern only matches when `perms` is a compile-time constant that forms
+/// a valid permutation of the result dimensions and the result type has a
+/// static shape; in every other case it returns failure without rewriting.
+class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
+public:
+  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TransposeOp op,
+                                PatternRewriter &rewriter) const final {
+    // The indexing map is built statically, so the permutation must be known.
+    DenseIntElementsAttr perms;
+    if (!matchPattern(op.perms(), m_Constant(&perms)))
+      return failure();
+
+    // The init tensor below is created from static sizes only.
+    auto resultTy = op.getType().cast<ShapedType>();
+    if (!resultTy.hasStaticShape())
+      return failure();
+
+    const int64_t rank = resultTy.getRank();
+    if (perms.getNumElements() != rank)
+      return failure();
+
+    // Output dimension `i` reads input dimension `perms[i]`, so the input map
+    // has `d_i` at position `perms[i]`. Reject out-of-range or repeated
+    // entries: they would index past `inputExprs` or leave a null expression.
+    SmallVector<AffineExpr, 2> inputExprs;
+    inputExprs.resize(rank);
+    for (auto permutation : llvm::enumerate(perms.getIntValues())) {
+      const int64_t inputDim = permutation.value().getSExtValue();
+      if (inputDim < 0 || inputDim >= rank || inputExprs[inputDim])
+        return failure();
+      inputExprs[inputDim] = rewriter.getAffineDimExpr(permutation.index());
+    }
+
+    auto initTensor = rewriter.create<linalg::InitTensorOp>(
+        op.getLoc(), ArrayRef<Value>({}), resultTy.getShape(),
+        resultTy.getElementType());
+
+    // Input uses the permuted map; the output is indexed by the identity map.
+    SmallVector<AffineMap, 2> affineMaps = {
+        AffineMap::get(rank, /*symbolCount=*/0, inputExprs,
+                       rewriter.getContext()),
+        rewriter.getMultiDimIdentityMap(rank)};
+
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps,
+        getNParallelLoopsAttrs(rank),
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
+        });
     return success();
   }
 };
@@ -478,5 +521,6 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
       PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>,
       IdentityNConverter<tosa::IdentityOp>,
-      IdentityNConverter<tosa::IdentityNOp>, ReshapeOpConverter>(context);
+      IdentityNConverter<tosa::IdentityNOp>,
+      ReshapeOpConverter, TransposeConverter>(context);
 }

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 829686712ca1..a39c722c177f 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -317,3 +317,21 @@ func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<1xf32
   // CHECK: return %arg0, %arg1
   return %2#0, %2#1 : tensor<1xf32>, tensor<1xi32>
 }
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK-LABEL: @test_transpose
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<1x2x3xi32>)
+func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
+  %0 = constant dense<[1, 2, 0]> : tensor<3xi32>
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3, 1]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins([[ARG0]] : tensor<1x2x3xi32>) outs([[INIT]] : tensor<2x3x1xi32>)
+  // CHECK: ^bb0([[ARG1:%.+]]: i32, [[ARG2:%.+]]: i32)
+  // CHECK:   linalg.yield [[ARG1]]
+  // CHECK: }
+  %1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3xi32>, tensor<3xi32>) -> (tensor<2x3x1xi32>)
+  return
+}


        


More information about the Mlir-commits mailing list