[Mlir-commits] [mlir] 087bc20 - [MLIR][TOSA] Lower tosa.transpose to linalg.generic
Rob Suderman
llvmlistbot at llvm.org
Mon Mar 1 11:10:17 PST 2021
Author: Rob Suderman
Date: 2021-03-01T11:09:49-08:00
New Revision: 087bc20fe42f2619fca76818900f73dd2c4a5b94
URL: https://github.com/llvm/llvm-project/commit/087bc20fe42f2619fca76818900f73dd2c4a5b94
DIFF: https://github.com/llvm/llvm-project/commit/087bc20fe42f2619fca76818900f73dd2c4a5b94.diff
LOG: [MLIR][TOSA] Lower tosa.transpose to linalg.generic
Lowers the tosa.transpose operation to a linalg.generic op when the
permutations operand is a constant value.
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D97508
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f6cf7ed709a3..ece9380845b6 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -438,6 +439,48 @@ class ReshapeOpConverter : public OpConversionPattern<tosa::ReshapeOp> {
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
reshape, resultTy, args[0], reassociationMap);
+
+ return success();
+ }
+};
+
+/// Lowers tosa.transpose to linalg.generic when the permutation is constant.
+///
+/// The rewrite applies only when `perms` matches a constant integer elements
+/// attribute, the result type has a static shape, and the constant holds a
+/// valid permutation of [0, rank). In every other case the pattern returns
+/// failure() without creating any IR. On success the op is replaced by a
+/// linalg.generic that reads the input through the permuted indexing map,
+/// writes into a linalg.init_tensor of the result shape through the identity
+/// map, and yields the input element unchanged.
+class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
+public:
+  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TransposeOp op,
+                                PatternRewriter &rewriter) const final {
+    DenseIntElementsAttr perms;
+    if (!matchPattern(op.perms(), m_Constant(&perms)))
+      return failure();
+
+    auto resultTy = op.getType().cast<ShapedType>();
+    if (!resultTy.hasStaticShape())
+      return failure();
+
+    const int64_t rank = resultTy.getRank();
+
+    // The constant is used to index `inputExprs` below, so it must contain
+    // exactly one entry per result dimension.
+    if (perms.getNumElements() != rank)
+      return failure();
+
+    // Loop dimension `i` iterates result dimension `i`, which reads input
+    // dimension `perms[i]`; hence inputExprs[perms[i]] = d_i.
+    SmallVector<AffineExpr, 2> inputExprs;
+    inputExprs.resize(rank);
+    for (auto permutation : llvm::enumerate(perms.getIntValues())) {
+      const int64_t inputDim = permutation.value().getSExtValue();
+      // Reject negative, out-of-range, and repeated entries: the first two
+      // would index outside `inputExprs`, and a repeat would leave another
+      // slot holding a null AffineExpr that AffineMap::get cannot accept.
+      if (inputDim < 0 || inputDim >= rank || inputExprs[inputDim])
+        return failure();
+      inputExprs[inputDim] = rewriter.getAffineDimExpr(permutation.index());
+    }
+
+    // Destination tensor for the generic op; no dynamic sizes are needed
+    // because the result shape is static.
+    auto initTensor = rewriter.create<linalg::InitTensorOp>(
+        op.getLoc(), ArrayRef<Value>({}), resultTy.getShape(),
+        resultTy.getElementType());
+
+    // Indexing maps: permuted access for the input, identity for the output.
+    SmallVector<AffineMap, 2> affineMaps = {
+        AffineMap::get(rank, /*symbolCount=*/0, inputExprs,
+                       rewriter.getContext()),
+        rewriter.getMultiDimIdentityMap(rank)};
+
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps,
+        getNParallelLoopsAttrs(rank),
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          // args[0] is the input element; forward it to the output.
+          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
+        });
     return success();
   }
 };
@@ -478,5 +521,6 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>,
IdentityNConverter<tosa::IdentityOp>,
- IdentityNConverter<tosa::IdentityNOp>, ReshapeOpConverter>(context);
+ IdentityNConverter<tosa::IdentityNOp>,
+ ReshapeOpConverter, TransposeConverter>(context);
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 829686712ca1..a39c722c177f 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -317,3 +317,21 @@ func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<1xf32
// CHECK: return %arg0, %arg1
return %2#0, %2#1 : tensor<1xf32>, tensor<1xi32>
}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK-LABEL: @test_transpose
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<1x2x3xi32>)
+// Transposes a 1x2x3 tensor with the constant permutation [1, 2, 0]. The
+// checks expect an init_tensor of shape [2, 3, 1] and a linalg.generic with
+// three parallel iterators whose body yields its first block argument.
+func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
+ %0 = constant dense<[1, 2, 0]> : tensor<3xi32>
+ // CHECK: [[INIT:%.+]] = linalg.init_tensor [2, 3, 1]
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins([[ARG0]] : tensor<1x2x3xi32>) outs([[OUT:%.+]] : tensor<2x3x1xi32>)
+ // CHECK: ^bb0([[ARG1:%.+]]: i32, [[ARG2:%.+]]: i32)
+ // CHECK: linalg.yield [[ARG1]]
+ // CHECK: }
+ %1 = "tosa.transpose"(%arg0, %0) : (tensor<1x2x3xi32>, tensor<3xi32>) -> (tensor<2x3x1xi32>)
+ return
+}
More information about the Mlir-commits
mailing list