[Mlir-commits] [mlir] be58d48 - [mlir] Canonicalize dynamic tensor.pad ops with constant inputs

George Petterson llvmlistbot at llvm.org
Fri Feb 3 13:43:56 PST 2023


Author: George Petterson
Date: 2023-02-03T16:43:45-05:00
New Revision: be58d484cb315028049255d9bb29ce2c7bffc983

URL: https://github.com/llvm/llvm-project/commit/be58d484cb315028049255d9bb29ce2c7bffc983
DIFF: https://github.com/llvm/llvm-project/commit/be58d484cb315028049255d9bb29ce2c7bffc983.diff

LOG: [mlir] Canonicalize dynamic tensor.pad ops with constant inputs

This commit adds a canonicalization pattern for tensor.pad that makes the result type static in every dimension where the input dimension is static and the corresponding low and high padding operands are constants. This fixes an issue in Torch-MLIR where pad ops would sometimes introduce dynamic shapes unnecessarily.

Reviewed By: raikonenfnu

Differential Revision: https://reviews.llvm.org/D143135

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 7960b64fd7151..d35895a167558 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2846,12 +2846,105 @@ struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
   }
 };
 
+struct FoldStaticPadding : public OpRewritePattern<PadOp> {
+  using OpRewritePattern<PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadOp padTensorOp,
+                                PatternRewriter &rewriter) const override {
+    Value input = padTensorOp.getSource();
+    if (!input.getType().isa<RankedTensorType>())
+      return failure();
+    auto inputDims = input.getType().cast<RankedTensorType>().getShape();
+    auto inputRank = inputDims.size();
+
+    if (!padTensorOp.getResult().getType().isa<RankedTensorType>())
+      return failure();
+    auto outputDims =
+        padTensorOp.getResult().getType().cast<RankedTensorType>().getShape();
+
+    // Map each dynamic low/high operand to its constant value, or to kDynamic if it is not a constant integer.
+    SmallVector<int64_t> constOperandsLow;
+    for (auto operand : padTensorOp.getLow()) {
+      APSInt intOp;
+      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
+        constOperandsLow.push_back(ShapedType::kDynamic);
+        continue;
+      }
+      constOperandsLow.push_back(intOp.getExtValue());
+    }
+    SmallVector<int64_t> constOperandsHigh;
+    for (auto operand : padTensorOp.getHigh()) {
+      APSInt intOp;
+      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
+        constOperandsHigh.push_back(ShapedType::kDynamic);
+        continue;
+      }
+      constOperandsHigh.push_back(intOp.getExtValue());
+    }
+
+    SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
+    SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
+
+    // Bail out unless the input, the output and both static padding lists have the same rank.
+    if (inputDims.size() != outputDims.size() ||
+        inputDims.size() != constLow.size() ||
+        inputDims.size() != constHigh.size())
+      return failure();
+
+    auto lowCount = 0;
+    auto highCount = 0;
+    for (size_t i = 0; i < inputRank; i++) { // Fill dynamic static-list slots from the operand values above.
+      if (constLow[i] == ShapedType::kDynamic)
+        constLow[i] = constOperandsLow[lowCount++];
+      if (constHigh[i] == ShapedType::kDynamic)
+        constHigh[i] = constOperandsHigh[highCount++];
+    }
+
+    auto staticLow = ArrayRef<int64_t>(constLow);
+    auto staticHigh = ArrayRef<int64_t>(constHigh);
+
+    // Refine each dynamic output dim to input + low + high when all three are static; static output dims are kept.
+    SmallVector<int64_t> newOutDims;
+    for (size_t i = 0; i < inputRank; i++) {
+      if (outputDims[i] == ShapedType::kDynamic) {
+        newOutDims.push_back(
+            (staticLow[i] == ShapedType::kDynamic ||
+                     staticHigh[i] == ShapedType::kDynamic ||
+                     inputDims[i] == ShapedType::kDynamic
+                 ? ShapedType::kDynamic
+                 : inputDims[i] + staticLow[i] + staticHigh[i]));
+      } else {
+        newOutDims.push_back(outputDims[i]);
+      }
+    }
+
+    if (SmallVector<int64_t>(outputDims) == newOutDims || // No dim was refined, or (below) every dim is dynamic.
+        llvm::all_of(newOutDims,
+                     [&](int64_t x) { return x == ShapedType::kDynamic; }))
+      return failure();
+
+    // Create a pad with the refined result type. NOTE(review): the original low/high operands are still passed although staticLow/staticHigh had dynamic slots replaced by constants - confirm the operand counts stay consistent.
+    auto newResultType = RankedTensorType::get(
+        newOutDims, padTensorOp.getType().getElementType());
+    auto newOp = rewriter.create<PadOp>(
+        padTensorOp->getLoc(), newResultType, input, padTensorOp.getLow(),
+        padTensorOp.getHigh(), staticLow, staticHigh, padTensorOp.getNofold());
+
+    IRMapping mapper; // Used to clone the original padding region into the new op.
+    padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, newResultType, // NOTE(review): cast type is newResultType, not the op's original result type - confirm intended.
+                                                newOp);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
   results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
-              FoldOrthogonalPaddings>(context);
+              FoldOrthogonalPaddings, FoldStaticPadding>(context);
 }
 
 /// Return the padding value of the PadOp if it constant. In this context,

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index f4706fc439b9e..6b7280453f975 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1111,6 +1111,29 @@ func.func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
 
 // -----
 
+// CHECK-LABEL:   func @pad_fold_static(
+// CHECK-SAME:      %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?xf32> {
+// CHECK:           %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[PADDING:.*]] = arith.constant 4 : index
+// CHECK:           %[[PADDED:.*]] = tensor.pad %[[INPUT]]
+// CHECK-SAME:        low[0, 4, 1, 1] high[0, 4, 1, 1]  {
+// CHECK:           ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
+// CHECK:             tensor.yield %[[CST]] : f32
+// CHECK:           } : tensor<?x64x?x?xf32> to tensor<?x72x?x?xf32>
+func.func @pad_fold_static(%arg0: tensor<?x64x?x?xf32>)
+    -> tensor<?xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %padding = arith.constant 4 : index
+  %padded = tensor.pad %arg0 low[0, %padding, 1, 1] high[0, %padding, 1, 1]  {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+    tensor.yield %cst: f32
+  } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
+  %result = tensor.collapse_shape %padded [[0, 1, 2, 3]] : tensor<?x?x?x?xf32> into tensor<?xf32>
+  return %result : tensor<?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @pad_nofold_same_static_shape(
 //  CHECK-SAME:   %[[ARG0:.*]]: tensor<5x6xf32>
 //       CHECK:   %[[PAD:.*]] = tensor.pad


        


More information about the Mlir-commits mailing list