[Mlir-commits] [mlir] 4157a07 - [mlir][tosa] Add tosa.pad to linalg.pad operation

Rob Suderman llvmlistbot at llvm.org
Tue Mar 23 14:18:09 PDT 2021


Author: Rob Suderman
Date: 2021-03-23T14:15:48-07:00
New Revision: 4157a079afbf7fa5c3ce3ac0e9f4541f89188ae2

URL: https://github.com/llvm/llvm-project/commit/4157a079afbf7fa5c3ce3ac0e9f4541f89188ae2
DIFF: https://github.com/llvm/llvm-project/commit/4157a079afbf7fa5c3ce3ac0e9f4541f89188ae2.diff

LOG: [mlir][tosa] Add tosa.pad to linalg.pad operation

Lowers tosa's pad op to the linalg equivalent for floating-point, integer,
and quantized values.
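
For example, a rough sketch of the lowering (mirroring the @pad_float test case
added below; the low/high index SSA values shown here are illustrative, and in
the generated IR they are produced by tensor.extract + index_cast on the
padding tensor):

  %1 = "tosa.pad"(%arg0, %0) : (tensor<1x2xf32>, tensor<2x2xi32>) -> (tensor<4x9xf32>)

becomes roughly

  // %low0/%low1/%high0/%high1: index values extracted from the padding tensor %0
  %zero = constant 0.000000e+00 : f32
  %padded = linalg.pad_tensor %arg0 low[%low0, %low1] high[%high0, %high1] {
  ^bb0(%i: index, %j: index):
    linalg.yield %zero : f32
  } : tensor<1x2xf32> to tensor<4x9xf32>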

Differential Revision: https://reviews.llvm.org/D98990

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
index a44621ec6033..7cc1721bb0fb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
+++ b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_conversion_library(MLIRTosaToLinalg
   MLIRMath
   MLIRMemRef
   MLIRPass
+  MLIRTensor
   MLIRTosa
   MLIRTosaTransforms
   MLIRSupport

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 12e9e694760c..a4b6f826feb6 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
@@ -1155,7 +1156,79 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
     rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
         op, resultTy, genericOp.getResult(0),
         rewriter.getI64ArrayAttr(resultTy.getShape()));
+    return success();
+  }
+};
+
+class PadConverter : public OpRewritePattern<tosa::PadOp> {
+public:
+  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::PadOp padOp,
+                                PatternRewriter &rewriter) const final {
+    auto loc = padOp.getLoc();
+    auto input = padOp.input1();
+    auto padding = padOp.padding();
+
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    ShapedType paddingTy = padding.getType().cast<ShapedType>();
+    Type elementTy = inputTy.getElementType();
+    int64_t rank = inputTy.getRank();
+
+    if (!inputTy.hasStaticShape() || !paddingTy.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(
+          padOp,
+          "Pad converter requires static shaped input / padding values.");
+    }
+
+    Value lowIndex = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
+    Value highIndex =
+        rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
+
+    SmallVector<OpFoldResult, 3> lowValues;
+    SmallVector<OpFoldResult, 3> highValues;
+
+    lowValues.reserve(rank);
+    highValues.reserve(rank);
+
+    for (int i = 0; i < rank; i++) {
+      Value inputIndex = rewriter.createOrFold<ConstantIndexOp>(loc, i);
+      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
+          loc, padding, ValueRange({inputIndex, lowIndex}));
+      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
+          loc, padding, ValueRange({inputIndex, highIndex}));
+
+      lowVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(),
+                                                  lowVal);
+      highVal = rewriter.createOrFold<IndexCastOp>(loc, rewriter.getIndexType(),
+                                                   highVal);
+
+      lowValues.push_back(lowVal);
+      highValues.push_back(highVal);
+    }
+
+    Attribute constantAttr;
+    if (elementTy.isa<FloatType>())
+      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
+    else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
+      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
+    else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
+      auto value = padOp.quantization_info().getValue().input_zp().getValue();
+      constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
+    }
+
+    if (!constantAttr) {
+      return rewriter.notifyMatchFailure(
+          padOp,
+          "tosa.pad to linalg lowering encountered an unknown element type");
+    }
+
+    Value constant = rewriter.create<ConstantOp>(loc, constantAttr);
+
+    auto newPadOp = linalg::PadTensorOp::createPadScalarOp(
+        padOp.getType(), input, constant, lowValues, highValues, loc, rewriter);
 
+    rewriter.replaceOp(padOp, newPadOp.getResult());
     return success();
   }
 };
@@ -1187,7 +1260,8 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       IdentityNConverter<tosa::IdentityOp>,
       IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
       ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
-      ReduceConverter<tosa::ReduceProdOp>, ConcatConverter, ReshapeConverter,
-      RescaleConverter, ReverseConverter, TileConverter, TransposeConverter,
-      MatMulConverter, FullyConnectedConverter>(patterns->getContext());
+      ReduceConverter<tosa::ReduceProdOp>, ConcatConverter, PadConverter,
+      ReshapeConverter, RescaleConverter, ReverseConverter, TileConverter,
+      TransposeConverter, MatMulConverter, FullyConnectedConverter>(
+        patterns->getContext());
 }

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 5c0dbc50c2d7..baf9e575a473 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
@@ -33,14 +34,15 @@ struct TosaToLinalgOnTensors
 public:
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<linalg::LinalgDialect, math::MathDialect,
-                    memref::MemRefDialect, StandardOpsDialect>();
+                    memref::MemRefDialect, StandardOpsDialect,
+                    tensor::TensorDialect>();
   }
 
   void runOnFunction() override {
     RewritePatternSet patterns(&getContext());
     ConversionTarget target(getContext());
     target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
-                           StandardOpsDialect>();
+                           StandardOpsDialect, tensor::TensorDialect>();
     target.addIllegalDialect<tosa::TosaDialect>();
 
     // Not every TOSA op can be legalized to linalg.

diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 018e9e4d7e54..39a4f4122924 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -702,3 +702,46 @@ func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<3x6xf32>, %arg2: ten
   %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<5x3xf32>, tensor<3x6xf32>, tensor<6xf32>)  -> (tensor<5x6xf32>)
   return %0 : tensor<5x6xf32>
 }
+
+// -----
+
+func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
+  %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // CHECK: [[INDEX0:%.+]] = constant 0 : index
+  // CHECK: [[INDEX1:%.+]] = constant 1 : index
+  // CHECK: [[ROW0:%.+]] = constant 0 : index
+  // CHECK: [[LOW0:%.+]] = tensor.extract %cst{{\[}}[[ROW0]], [[INDEX0]]]
+  // CHECK: [[HIGH0:%.+]] = tensor.extract %cst{{\[}}[[ROW0]], [[INDEX1]]]
+  // CHECK: [[LOW0_IDX:%.+]] = index_cast %0
+  // CHECK: [[HIGH0_IDX:%.+]] = index_cast %1
+  // CHECK: [[ROW1:%.+]] = constant 1 : index
+  // CHECK: [[LOW1:%.+]] = tensor.extract %cst{{\[}}%c1_1, %c0]
+  // CHECK: [[HIGH1:%.+]] = tensor.extract %cst{{\[}}%c1_1, %c1]
+  // CHECK: [[LOW1_IDX:%.+]] = index_cast [[LOW1]]
+  // CHECK: [[HIGH1_IDX:%.+]] = index_cast [[HIGH1]]
+  // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: %8 = linalg.pad_tensor %arg0 low{{\[}}[[LOW0_IDX]], [[LOW1_IDX]]] high{{\[}}[[HIGH0_IDX]], [[HIGH1_IDX]]]  {
+  // CHECK: ^bb0(%arg1: index, %arg2: index):  // no predecessors
+  // CHECK:   linalg.yield [[CST]]
+  // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
+  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xf32>, tensor<2x2xi32>)  -> (tensor<4x9xf32>)
+  return %1 : tensor<4x9xf32>
+}
+
+func @pad_int(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
+  %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // CHECK: [[CST:%.+]] = constant 0 : i32
+  // CHECK: linalg.pad_tensor
+  // CHECK:   linalg.yield [[CST]]
+  %1 = "tosa.pad"(%arg0, %0)  : (tensor<1x2xi32>, tensor<2x2xi32>)  -> (tensor<4x9xi32>)
+  return %1 : tensor<4x9xi32>
+}
+
+func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
+  %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // CHECK: [[CST:%.+]] = constant 42 : i32
+  // CHECK: linalg.pad_tensor
+  // CHECK:   linalg.yield [[CST]]
+  %1 = "tosa.pad"(%arg0, %0) { quantization_info = { input_zp = 42 : i32}} : (tensor<1x2xi32>, tensor<2x2xi32>)  -> (tensor<4x9xi32>)
+  return %1 : tensor<4x9xi32>
+}

More information about the Mlir-commits mailing list