[Mlir-commits] [mlir] be58d48 - [mlir] Canonicalize dynamic tensor.pad ops with constant inputs
George Petterson
llvmlistbot at llvm.org
Fri Feb 3 13:43:56 PST 2023
Author: George Petterson
Date: 2023-02-03T16:43:45-05:00
New Revision: be58d484cb315028049255d9bb29ce2c7bffc983
URL: https://github.com/llvm/llvm-project/commit/be58d484cb315028049255d9bb29ce2c7bffc983
DIFF: https://github.com/llvm/llvm-project/commit/be58d484cb315028049255d9bb29ce2c7bffc983.diff
LOG: [mlir] Canonicalize dynamic tensor.pad ops with constant inputs
This commit adds a canonicalization pattern for tensor.pad which changes the output type to static at each dimension where the input shape is static and the high and low operands are constants. This corrects an issue arising in Torch-MLIR where pad ops would sometimes introduce dynamic shapes unnecessarily.
Reviewed By: raikonenfnu
Differential Revision: https://reviews.llvm.org/D143135
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 7960b64fd7151..d35895a167558 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2846,12 +2846,105 @@ struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
}
};
+// Canonicalize a tensor.pad whose dynamic low/high padding operands are
+// produced by integer constants: fold those constants into the static
+// padding arrays, compute a more static result type, and replace the op
+// with the refined pad wrapped in a tensor.cast.
+struct FoldStaticPadding : public OpRewritePattern<PadOp> {
+ using OpRewritePattern<PadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(PadOp padTensorOp,
+ PatternRewriter &rewriter) const override {
+ Value input = padTensorOp.getSource();
+ // Only ranked tensors carry per-dimension static/dynamic shape info.
+ if (!input.getType().isa<RankedTensorType>())
+ return failure();
+ auto inputDims = input.getType().cast<RankedTensorType>().getShape();
+ auto inputRank = inputDims.size();
+
+ if (!padTensorOp.getResult().getType().isa<RankedTensorType>())
+ return failure();
+ auto outputDims =
+ padTensorOp.getResult().getType().cast<RankedTensorType>().getShape();
+
+ // Extract the static info from the high and low operands.
+ // Non-constant operands are recorded as kDynamic sentinels so the two
+ // vectors stay index-aligned with the operand lists.
+ SmallVector<int64_t> constOperandsLow;
+ for (auto operand : padTensorOp.getLow()) {
+ APSInt intOp;
+ if (!matchPattern(operand, m_ConstantInt(&intOp))) {
+ constOperandsLow.push_back(ShapedType::kDynamic);
+ continue;
+ }
+ constOperandsLow.push_back(intOp.getExtValue());
+ }
+ SmallVector<int64_t> constOperandsHigh;
+ for (auto operand : padTensorOp.getHigh()) {
+ APSInt intOp;
+ if (!matchPattern(operand, m_ConstantInt(&intOp))) {
+ constOperandsHigh.push_back(ShapedType::kDynamic);
+ continue;
+ }
+ constOperandsHigh.push_back(intOp.getExtValue());
+ }
+
+ // Start from the op's static padding attributes; kDynamic entries are
+ // filled in from the constant operand values below.
+ SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
+ SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
+
+ // Verify the op is well-formed.
+ if (inputDims.size() != outputDims.size() ||
+ inputDims.size() != constLow.size() ||
+ inputDims.size() != constHigh.size())
+ return failure();
+
+ // Consume the dynamic operand values positionally: this assumes the
+ // getLow()/getHigh() operands correspond, in order, to the kDynamic
+ // entries of the static arrays (the tensor.pad operand encoding).
+ auto lowCount = 0;
+ auto highCount = 0;
+ for (size_t i = 0; i < inputRank; i++) {
+ if (constLow[i] == ShapedType::kDynamic)
+ constLow[i] = constOperandsLow[lowCount++];
+ if (constHigh[i] == ShapedType::kDynamic)
+ constHigh[i] = constOperandsHigh[highCount++];
+ }
+
+ auto staticLow = ArrayRef<int64_t>(constLow);
+ auto staticHigh = ArrayRef<int64_t>(constHigh);
+
+ // Calculate the output sizes with the static information.
+ // A dimension can only become static when the input extent and both
+ // padding amounts for that dimension are known.
+ SmallVector<int64_t> newOutDims;
+ for (size_t i = 0; i < inputRank; i++) {
+ if (outputDims[i] == ShapedType::kDynamic) {
+ newOutDims.push_back(
+ (staticLow[i] == ShapedType::kDynamic ||
+ staticHigh[i] == ShapedType::kDynamic ||
+ inputDims[i] == ShapedType::kDynamic
+ ? ShapedType::kDynamic
+ : inputDims[i] + staticLow[i] + staticHigh[i]));
+ } else {
+ newOutDims.push_back(outputDims[i]);
+ }
+ }
+
+ // Bail out when nothing was refined (prevents the pattern from being
+ // re-applied without making progress) or when no dimension is static.
+ if (SmallVector<int64_t>(outputDims) == newOutDims ||
+ llvm::all_of(newOutDims,
+ [&](int64_t x) { return x == ShapedType::kDynamic; }))
+ return failure();
+
+ // Rewrite the op using the new static type.
+ auto newResultType = RankedTensorType::get(
+ newOutDims, padTensorOp.getType().getElementType());
+ auto newOp = rewriter.create<PadOp>(
+ padTensorOp->getLoc(), newResultType, input, padTensorOp.getLow(),
+ padTensorOp.getHigh(), staticLow, staticHigh, padTensorOp.getNofold());
+
+ // The pad body region (yielding the padding value) must be cloned into
+ // the replacement op.
+ IRMapping mapper;
+ padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
+ // NOTE(review): the cast result type here is newResultType — the same
+ // refined type newOp already has — so replaced uses see a type that
+ // differs from padTensorOp's original (more dynamic) result type.
+ // Presumably this should cast back to the original result type so
+ // existing users keep type-checking; verify against the op verifier.
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, newResultType,
+ newOp);
+
+ return success();
+ }
+};
+
} // namespace
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
- FoldOrthogonalPaddings>(context);
+ FoldOrthogonalPaddings, FoldStaticPadding>(context);
}
/// Return the padding value of the PadOp if it constant. In this context,
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index f4706fc439b9e..6b7280453f975 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1111,6 +1111,29 @@ func.func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
// -----
+// CHECK-LABEL: func @pad_fold_static(
+// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?xf32> {
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[PADDING:.*]] = arith.constant 4 : index
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[INPUT]]
+// CHECK-SAME: low[0, 4, 1, 1] high[0, 4, 1, 1] {
+// CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
+// CHECK: tensor.yield %[[CST]] : f32
+// CHECK: } : tensor<?x64x?x?xf32> to tensor<?x72x?x?xf32>
+func.func @pad_fold_static(%arg0: tensor<?x64x?x?xf32>)
+ -> tensor<?xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %padding = arith.constant 4 : index
+ %padded = tensor.pad %arg0 low[0, %padding, 1, 1] high[0, %padding, 1, 1] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+ tensor.yield %cst: f32
+ } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
+ %result = tensor.collapse_shape %padded [[0, 1, 2, 3]] : tensor<?x?x?x?xf32> into tensor<?xf32>
+ return %result : tensor<?xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @pad_nofold_same_static_shape(
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
// CHECK: %[[PAD:.*]] = tensor.pad
More information about the Mlir-commits
mailing list