[Mlir-commits] [mlir] 4157a07 - [mlir][tosa] Add tosa.pad to linalg.pad operation
Rob Suderman
llvmlistbot at llvm.org
Tue Mar 23 14:18:09 PDT 2021
Author: Rob Suderman
Date: 2021-03-23T14:15:48-07:00
New Revision: 4157a079afbf7fa5c3ce3ac0e9f4541f89188ae2
URL: https://github.com/llvm/llvm-project/commit/4157a079afbf7fa5c3ce3ac0e9f4541f89188ae2
DIFF: https://github.com/llvm/llvm-project/commit/4157a079afbf7fa5c3ce3ac0e9f4541f89188ae2.diff
LOG: [mlir][tosa] Add tosa.pad to linalg.pad operation
Lowers tosa's pad op to the linalg equivalent for floating-point,
integer, and quantized values.
Differential Revision: https://reviews.llvm.org/D98990
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
index a44621ec6033..7cc1721bb0fb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
+++ b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_conversion_library(MLIRTosaToLinalg
MLIRMath
MLIRMemRef
MLIRPass
+ MLIRTensor
MLIRTosa
MLIRTosaTransforms
MLIRSupport
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 12e9e694760c..a4b6f826feb6 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
@@ -1155,7 +1156,79 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, resultTy, genericOp.getResult(0),
rewriter.getI64ArrayAttr(resultTy.getShape()));
+ return success();
+ }
+};
+
+class PadConverter : public OpRewritePattern<tosa::PadOp> {
+public:
+  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;
+
+  /// Lowers `tosa.pad` to `linalg.pad_tensor`.
+  ///
+  /// The padding operand must be a statically shaped [rank(input), 2] tensor
+  /// whose row i holds the (low, high) padding amounts for input dimension i.
+  /// Each amount is read with tensor.extract and cast to index. Padded
+  /// elements are filled with 0 for float and integer inputs, or with the
+  /// input zero point when quantization_info is present on an integer input.
+  ///
+  /// All preconditions are checked before any op is created: a rewrite
+  /// pattern that returns failure must leave the IR unmodified.
+  LogicalResult matchAndRewrite(tosa::PadOp padOp,
+                                PatternRewriter &rewriter) const final {
+    auto loc = padOp.getLoc();
+    auto input = padOp.input1();
+    auto padding = padOp.padding();
+
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    ShapedType paddingTy = padding.getType().cast<ShapedType>();
+    Type elementTy = inputTy.getElementType();
+
+    // hasStaticShape() is false for unranked types, so this guard also makes
+    // the getRank() call below safe.
+    if (!inputTy.hasStaticShape() || !paddingTy.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(
+          padOp,
+          "Pad converter requires static shaped input / padding values.");
+    }
+
+    int64_t rank = inputTy.getRank();
+
+    // The loop below reads padding[i][0] and padding[i][1] for every input
+    // dimension i; reject padding shapes that would make those reads go out
+    // of bounds.
+    if (paddingTy.getRank() != 2 || paddingTy.getDimSize(0) != rank ||
+        paddingTy.getDimSize(1) != 2) {
+      return rewriter.notifyMatchFailure(
+          padOp, "tosa.pad padding must have shape [rank(input), 2].");
+    }
+
+    // Pick the fill value before creating any op, so an unsupported element
+    // type fails the match without leaving stray ops behind.
+    Attribute constantAttr;
+    if (elementTy.isa<FloatType>()) {
+      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
+    } else if (elementTy.isa<IntegerType>()) {
+      int64_t fillValue = 0;
+      if (auto quantInfo = padOp.quantization_info()) {
+        // Zero points are signed: sign-extend so a negative zero point stays
+        // negative for element types wider than the zero-point attribute.
+        fillValue = quantInfo.getValue().input_zp().getValue().getSExtValue();
+      }
+      constantAttr = rewriter.getIntegerAttr(elementTy, fillValue);
+    }
+
+    if (!constantAttr) {
+      return rewriter.notifyMatchFailure(
+          padOp,
+          "tosa.pad to linalg lowering encountered an unknown element type");
+    }
+
+    // Column indices into the [rank, 2] padding tensor.
+    Value lowIndex = rewriter.create<ConstantIndexOp>(loc, 0);
+    Value highIndex = rewriter.create<ConstantIndexOp>(loc, 1);
+
+    SmallVector<OpFoldResult, 3> lowValues;
+    SmallVector<OpFoldResult, 3> highValues;
+    lowValues.reserve(rank);
+    highValues.reserve(rank);
+
+    for (int64_t i = 0; i < rank; i++) {
+      Value inputIndex = rewriter.createOrFold<ConstantIndexOp>(loc, i);
+      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
+          loc, padding, ValueRange({inputIndex, lowIndex}));
+      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
+          loc, padding, ValueRange({inputIndex, highIndex}));
+
+      // linalg.pad_tensor takes index-typed low/high amounts.
+      lowVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(),
+                                                  lowVal);
+      highVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(),
+                                                   highVal);
+
+      lowValues.push_back(lowVal);
+      highValues.push_back(highVal);
+    }
+
+    Value constant = rewriter.create<ConstantOp>(loc, constantAttr);
+
+    auto newPadOp = linalg::PadTensorOp::createPadScalarOp(
+        padOp.getType(), input, constant, lowValues, highValues, loc, rewriter);
+    rewriter.replaceOp(padOp, newPadOp.getResult());
     return success();
   }
 };
@@ -1187,7 +1260,8 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
IdentityNConverter<tosa::IdentityOp>,
IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
- ReduceConverter<tosa::ReduceProdOp>, ConcatConverter, ReshapeConverter,
- RescaleConverter, ReverseConverter, TileConverter, TransposeConverter,
- MatMulConverter, FullyConnectedConverter>(patterns->getContext());
+ ReduceConverter<tosa::ReduceProdOp>, ConcatConverter, PadConverter,
+ ReshapeConverter, RescaleConverter, ReverseConverter, TileConverter,
+ TransposeConverter, MatMulConverter, FullyConnectedConverter>(
+ patterns->getContext());
}
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 5c0dbc50c2d7..baf9e575a473 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
@@ -33,14 +34,15 @@ struct TosaToLinalgOnTensors
public:
+  // Dialects registered as dependencies of this pass. TensorDialect is
+  // needed because the tosa.pad lowering creates tensor.extract ops.
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<linalg::LinalgDialect, math::MathDialect,
-                    memref::MemRefDialect, StandardOpsDialect>();
+                    memref::MemRefDialect, StandardOpsDialect,
+                    tensor::TensorDialect>();
   }
void runOnFunction() override {
RewritePatternSet patterns(&getContext());
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
- StandardOpsDialect>();
+ StandardOpsDialect, tensor::TensorDialect>();
target.addIllegalDialect<tosa::TosaDialect>();
// Not every TOSA op can be legalized to linalg.
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 018e9e4d7e54..39a4f4122924 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -702,3 +702,46 @@ func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<3x6xf32>, %arg2: ten
%0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<5x3xf32>, tensor<3x6xf32>, tensor<6xf32>) -> (tensor<5x6xf32>)
return %0 : tensor<5x6xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @pad_float(
+func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
+  %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // CHECK: [[PAD:%.+]] = constant dense{{.*}} : tensor<2x2xi32>
+  // CHECK: [[INDEX0:%.+]] = constant 0 : index
+  // CHECK: [[INDEX1:%.+]] = constant 1 : index
+  // CHECK: [[ROW0:%.+]] = constant 0 : index
+  // CHECK: [[LOW0:%.+]] = tensor.extract [[PAD]]{{\[}}[[ROW0]], [[INDEX0]]]
+  // CHECK: [[HIGH0:%.+]] = tensor.extract [[PAD]]{{\[}}[[ROW0]], [[INDEX1]]]
+  // CHECK: [[LOW0_IDX:%.+]] = index_cast [[LOW0]]
+  // CHECK: [[HIGH0_IDX:%.+]] = index_cast [[HIGH0]]
+  // CHECK: [[ROW1:%.+]] = constant 1 : index
+  // CHECK: [[LOW1:%.+]] = tensor.extract [[PAD]]{{\[}}[[ROW1]], [[INDEX0]]]
+  // CHECK: [[HIGH1:%.+]] = tensor.extract [[PAD]]{{\[}}[[ROW1]], [[INDEX1]]]
+  // CHECK: [[LOW1_IDX:%.+]] = index_cast [[LOW1]]
+  // CHECK: [[HIGH1_IDX:%.+]] = index_cast [[HIGH1]]
+  // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: {{%.+}} = linalg.pad_tensor %arg0 low{{\[}}[[LOW0_IDX]], [[LOW1_IDX]]] high{{\[}}[[HIGH0_IDX]], [[HIGH1_IDX]]] {
+  // CHECK: ^bb0(%arg1: index, %arg2: index):  // no predecessors
+  // CHECK:   linalg.yield [[CST]]
+  // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
+  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xf32>, tensor<2x2xi32>)  -> (tensor<4x9xf32>)
+  return %1 : tensor<4x9xf32>
+}
+
+// CHECK-LABEL: func @pad_int(
+func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
+  %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // Non-quantized integer input: the fill value is 0.
+  // CHECK: [[CST:%.+]] = constant 0 : i32
+  // CHECK: linalg.pad_tensor
+  // CHECK:   linalg.yield [[CST]]
+  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xi32>, tensor<2x2xi32>)  -> (tensor<4x9xi32>)
+  return %1 : tensor<4x9xi32>
+}
+
+// CHECK-LABEL: func @pad_quant(
+func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
+  %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // Quantized integer input: the fill value is the input zero point.
+  // CHECK: [[CST:%.+]] = constant 42 : i32
+  // CHECK: linalg.pad_tensor
+  // CHECK:   linalg.yield [[CST]]
+  %1 = "tosa.pad"(%arg0, %0) { quantization_info = { input_zp = 42 : i32}} : (tensor<1x2xi32>, tensor<2x2xi32>)  -> (tensor<4x9xi32>)
+  return %1 : tensor<4x9xi32>
+}
+
+// CHECK-LABEL: func @pad_quant_negative_zp(
+func @pad_quant_negative_zp(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
+  %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // Zero points are signed; a negative zero point must survive lowering.
+  // CHECK: [[CST:%.+]] = constant -42 : i32
+  // CHECK: linalg.pad_tensor
+  // CHECK:   linalg.yield [[CST]]
+  %1 = "tosa.pad"(%arg0, %0) { quantization_info = { input_zp = -42 : i32}} : (tensor<1x2xi32>, tensor<2x2xi32>)  -> (tensor<4x9xi32>)
+  return %1 : tensor<4x9xi32>
+}
More information about the Mlir-commits
mailing list