[Mlir-commits] [mlir] 860d381 - [mlir][tosa] Add lowering for tosa.pad with explicit value

Rob Suderman llvmlistbot at llvm.org
Wed Nov 10 14:22:54 PST 2021


Author: Rob Suderman
Date: 2021-11-10T14:15:20-08:00
New Revision: 860d3811a9b2f3df0ac093d87832056fd7a19b87

URL: https://github.com/llvm/llvm-project/commit/860d3811a9b2f3df0ac093d87832056fd7a19b87
DIFF: https://github.com/llvm/llvm-project/commit/860d3811a9b2f3df0ac093d87832056fd7a19b87.diff

LOG: [mlir][tosa] Add lowering for tosa.pad with explicit value

The new TOSA pad operation supports explicitly specifying the pad value. Added a
lowering to linalg that uses the explicit value when it is provided.

Differential Revision: https://reviews.llvm.org/D113515

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f66e92414b796..54165266538c2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2381,20 +2381,30 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
           "Pad converter requires static shaped input / padding values.");
     }
 
-    Attribute constantAttr;
-    if (elementTy.isa<FloatType>())
-      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-    else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
-      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-    else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
-      auto value = padOp.quantization_info().getValue().input_zp().getValue();
-      constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
+    // Setup the default constantAttr.
+
+    Value padConstant;
+
+    if (padOp.pad_const()) {
+      padConstant = rewriter.createOrFold<tensor::ExtractOp>(
+          loc, padOp.pad_const(), ValueRange({}));
+    } else {
+      Attribute constantAttr;
+      if (elementTy.isa<FloatType>())
+        constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
+      else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
+        constantAttr = rewriter.getIntegerAttr(elementTy, 0);
+      else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
+        auto value = padOp.quantization_info().getValue().input_zp().getValue();
+        constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
+      }
+      if (constantAttr)
+        padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
     }
 
-    if (!constantAttr) {
+    if (!padConstant) {
       return rewriter.notifyMatchFailure(
-          padOp,
-          "tosa.pad to linalg lowering encountered an unknown element type");
+          padOp, "tosa.pad was unable to determine the pad constant value.");
     }
 
     Value lowIndex =
@@ -2424,10 +2434,8 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
       highValues.push_back(highVal);
     }
 
-    Value constant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
-
     auto newPadOp = linalg::PadTensorOp::createPadScalarOp(
-        padOp.getType(), input, constant, lowValues, highValues,
+        padOp.getType(), input, padConstant, lowValues, highValues,
         /*nofold=*/false, loc, rewriter);
 
     rewriter.replaceOp(padOp, newPadOp.getResult());

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 51d355744e128..c7ddb6bdae5c7 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1187,11 +1187,11 @@ func @fully_connected_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<6x3xf32>, %arg2:
 func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
   %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
   // TODO: Output contains multiple "arith.constant 1 : index".
-  // CHECK: [[INDEX1:%.+]] = arith.constant 1 : index
-  // CHECK: [[INDEX2:%.+]] = arith.constant 2 : index
-  // CHECK: [[INDEX3:%.+]] = arith.constant 3 : index
-  // CHECK: [[INDEX4:%.+]] = arith.constant 4 : index
-  // CHECK: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+  // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000000e+00 : f32
   // CHECK: linalg.pad_tensor %arg0 low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]]  {
   // CHECK: ^bb0(%arg1: index, %arg2: index):  // no predecessors
   // CHECK:   linalg.yield [[CST]]
@@ -1220,6 +1220,25 @@ func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
 
 // -----
 
+func @pad_float_explicit(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
+  %0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
+  // TODO: Output contains multiple "arith.constant 1 : index".
+  // CHECK-DAG: [[INDEX1:%.+]] = arith.constant 1 : index
+  // CHECK-DAG: [[INDEX2:%.+]] = arith.constant 2 : index
+  // CHECK-DAG: [[INDEX3:%.+]] = arith.constant 3 : index
+  // CHECK-DAG: [[INDEX4:%.+]] = arith.constant 4 : index
+  // CHECK-DAG: [[CST:%.+]] = arith.constant 4.200000e+01 : f32
+  // CHECK: linalg.pad_tensor %arg0 low{{\[}}%{{.*}}, [[INDEX3]]] high{{\[}}[[INDEX2]], [[INDEX4]]]  {
+  // CHECK: ^bb0(%arg1: index, %arg2: index):  // no predecessors
+  // CHECK:   linalg.yield [[CST]]
+  // CHECK: } : tensor<1x2xf32> to tensor<4x9xf32>
+  %1 = arith.constant dense<42.0> : tensor<f32>
+  %2 = "tosa.pad"(%arg0, %0, %1)  : (tensor<1x2xf32>, tensor<2x2xi32>, tensor<f32>)  -> (tensor<4x9xf32>)
+  return %2 : tensor<4x9xf32>
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
 // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>


        


More information about the Mlir-commits mailing list