[Mlir-commits] [mlir] f3f25ff - [mlir][linalg] Fix result type in FoldSourceTensorCast

Matthias Springer llvmlistbot at llvm.org
Fri Sep 24 00:47:31 PDT 2021


Author: Matthias Springer
Date: 2021-09-24T16:47:18+09:00
New Revision: f3f25ffc04c0cbcc9a9bfc1b32b61750e8934ea8

URL: https://github.com/llvm/llvm-project/commit/f3f25ffc04c0cbcc9a9bfc1b32b61750e8934ea8
DIFF: https://github.com/llvm/llvm-project/commit/f3f25ffc04c0cbcc9a9bfc1b32b61750e8934ea8.diff

LOG: [mlir][linalg] Fix result type in FoldSourceTensorCast

* Do not discard static result type information that cannot be inferred from lower/upper padding.
* Add optional argument to `PadTensorOp::inferResultType` for specifying known result dimensions.

Differential Revision: https://reviews.llvm.org/D110380

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 4c82eafc9c973..dd568ba367067 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -226,10 +226,14 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     }
 
     // Infer the shape of the result tensor given the type of the source tensor
-    // and paddings.
-    static RankedTensorType inferResultType(RankedTensorType sourceType,
+    // and paddings. Known result dimensions that cannot necessarily be inferred
+    // from low/high padding sizes can be optionally specified. Those will be
+    // considered when computing the result type.
+    static RankedTensorType inferResultType(
+                                RankedTensorType sourceType,
                                 ArrayRef<int64_t> staticLow,
-                                ArrayRef<int64_t> staticHigh);
+                                ArrayRef<int64_t> staticHigh,
+                                ArrayRef<int64_t> resultShape = {});
 
     // Return a PadTensorOp that pads `source` to `type` size where the static
     // sizes are assumed to be greater than the dynamic sizes. The op performs

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b3eeaabc780ed..75e4a1c91bcda 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1055,24 +1055,31 @@ static LogicalResult verify(PadTensorOp op) {
 
 RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType,
                                               ArrayRef<int64_t> staticLow,
-                                              ArrayRef<int64_t> staticHigh) {
+                                              ArrayRef<int64_t> staticHigh,
+                                              ArrayRef<int64_t> resultShape) {
   unsigned rank = sourceType.getRank();
   assert(staticLow.size() == rank && "unexpected staticLow size mismatch");
   assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch");
+  assert((resultShape.empty() || resultShape.size() == rank) &&
+         "unexpected resultShape size mismatch");
 
-  SmallVector<int64_t, 4> resultShape;
+  SmallVector<int64_t, 4> inferredShape;
   for (auto i : llvm::seq<unsigned>(0, rank)) {
     if (sourceType.isDynamicDim(i) ||
         staticLow[i] == ShapedType::kDynamicSize ||
         staticHigh[i] == ShapedType::kDynamicSize) {
-      resultShape.push_back(ShapedType::kDynamicSize);
+      inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamicSize
+                                                  : resultShape[i]);
     } else {
       int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
-      resultShape.push_back(size);
+      assert((resultShape.empty() || size == resultShape[i] ||
+              resultShape[i] == ShapedType::kDynamicSize) &&
+             "mismatch between inferred shape and result shape");
+      inferredShape.push_back(size);
     }
   }
 
-  return RankedTensorType::get(resultShape, sourceType.getElementType());
+  return RankedTensorType::get(inferredShape, sourceType.getElementType());
 }
 
 void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
@@ -1454,7 +1461,8 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
     auto newResultType = PadTensorOp::inferResultType(
         castOp.source().getType().cast<RankedTensorType>(),
         extractFromI64ArrayAttr(padTensorOp.static_low()),
-        extractFromI64ArrayAttr(padTensorOp.static_high()));
+        extractFromI64ArrayAttr(padTensorOp.static_high()),
+        padTensorOp.getResultType().getShape());
 
     if (newResultType == padTensorOp.getResultType()) {
       rewriter.updateRootInPlace(padTensorOp, [&]() {

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 3d434c2d6ebc0..fce08a1e04dca 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -629,7 +629,8 @@ func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
 }
 
 // -----
-// CHECK-LABEL:   func @pad_tensor_after_cast_differnt_shape(
+
+// CHECK-LABEL:   func @pad_tensor_after_cast_different_shape(
 // CHECK-SAME:      %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
 // CHECK:           %[[CST:.*]] = constant 0.000000e+00 : f32
 // CHECK:           %[[PADDED:.*]] = linalg.pad_tensor %[[INPUT]]
@@ -641,7 +642,7 @@ func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
 // CHECK-SAME:         tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
 // CHECK:           return %[[DYNAMIC]] : tensor<?x?x?x?xf32>
 // CHECK:         }
-func @pad_tensor_after_cast_differnt_shape(%arg0: tensor<?x64x?x?xf32>)
+func @pad_tensor_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
     -> tensor<?x?x?x?xf32> {
   %cst = constant 0.000000e+00 : f32
   %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
@@ -653,6 +654,7 @@ func @pad_tensor_after_cast_differnt_shape(%arg0: tensor<?x64x?x?xf32>)
 }
 
 // -----
+
 // CHECK-LABEL:   func @pad_tensor_after_cast_same_shape(
 // CHECK-SAME:      %[[INPUT:.*]]: tensor<?x64x?x?xf32>,
 // CHECK-SAME:      %[[PADDING:.*]]: index) -> tensor<?x?x?x?xf32> {
@@ -676,6 +678,24 @@ func @pad_tensor_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : i
 }
 
 // -----
+
+// CHECK-LABEL: func @pad_tensor_of_cast(
+// CHECK-NOT:     tensor.cast
+// CHECK:         linalg.pad_tensor
+// CHECK:         tensor<8x?xf32> to tensor<8x32xf32>
+func @pad_tensor_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
+  %c0 = constant 0 : index
+  %cst = constant 0.000000e+00 : f32
+  %0 = tensor.cast %t : tensor<8x?xf32> to tensor<?x?xf32>
+  %1 = linalg.pad_tensor %0 low[%c0, %c0] high[%c0, %s]  {
+  ^bb0(%arg9: index, %arg10: index):  // no predecessors
+    linalg.yield %cst : f32
+  } : tensor<?x?xf32> to tensor<8x32xf32>
+  return %1 : tensor<8x32xf32>
+}
+
+// -----
+
 func @propogate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
     %arg3 : index) -> tensor<?x?xf32> {
   %c0 = constant 0 : index


        


More information about the Mlir-commits mailing list