[Mlir-commits] [mlir] cb40d52 - [mlir][Linalg] Avoid using `tensor.cast` by default while folding `fill` with `pad`.
Mahesh Ravishankar
llvmlistbot at llvm.org
Fri Nov 11 15:17:19 PST 2022
Author: Mahesh Ravishankar
Date: 2022-11-11T23:17:07Z
New Revision: cb40d5291e6113fddcfa5fc3861b7320f78dbfe8
URL: https://github.com/llvm/llvm-project/commit/cb40d5291e6113fddcfa5fc3861b7320f78dbfe8
DIFF: https://github.com/llvm/llvm-project/commit/cb40d5291e6113fddcfa5fc3861b7320f78dbfe8.diff
LOG: [mlir][Linalg] Avoid using `tensor.cast` by default while folding `fill` with `pad`.
The cast is unnecessary if the type of the generated operation already
matches the type of the replaced value. Also use `OpFoldResult` to reduce
the number of cases where the cast is needed.
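
The effect described above can be sketched without MLIR. The toy model below is an illustration only, not the code from this commit: `Dim`, `Shape`, `allDynamicShape`, `foldedShape`, and `needsCast` are hypothetical names. `Dim` stands in for an `OpFoldResult`-like value that is either a known constant or a run-time value. The sketch assumes that a dimension known to be constant ends up static in the new tensor type.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a shape dimension: a value means the extent is a
// known constant (static); std::nullopt means it is only known at run time.
using Dim = std::optional<int64_t>;
using Shape = std::vector<Dim>;

// Models the old behaviour: the new tensor is created with every dimension
// dynamic, regardless of what is known about the reified shape.
Shape allDynamicShape(const Shape &reified) {
  return Shape(reified.size(), std::nullopt);
}

// Models the new behaviour: constant dimensions stay static, so the produced
// type is as precise as the reified shape allows.
Shape foldedShape(const Shape &reified) { return reified; }

// A cast is only required when the produced type differs from the type of
// the value being replaced.
bool needsCast(const Shape &produced, const Shape &expected) {
  return produced != expected;
}
```

In this model, a fully static result type needs a cast under the all-dynamic scheme but not when constants are kept static. A cast can still be needed when the reified shape is more precise than the original result type, which is why the type comparison is kept.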
Reviewed By: springerm, hanchung, antiagainst
Differential Revision: https://reviews.llvm.org/D137479
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5ebb81bed132d..32c8dd65678bf 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -496,17 +496,20 @@ struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
       return rewriter.notifyMatchFailure(
           padOp, "failed to reify tensor.pad op result shape");
 
-    auto oldResultType = padOp.getResultType();
-    SmallVector<int64_t, 4> staticShape(oldResultType.getRank(),
-                                        ShapedType::kDynamicSize);
+    SmallVector<OpFoldResult> newShape =
+        getAsOpFoldResult(reifiedShape.front());
     auto emptyTensor = rewriter.create<tensor::EmptyOp>(
-        padOp.getLoc(), staticShape, oldResultType.getElementType(),
-        reifiedShape.front());
-    auto newFillOp = rewriter.create<FillOp>(
-        fillOp.getLoc(), ValueRange{padValue}, ValueRange{emptyTensor});
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldResultType,
-                                                newFillOp.result());
-
+        padOp.getLoc(), newShape, padOp.getResultType().getElementType());
+    Value replacement =
+        rewriter
+            .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
+                            ValueRange{emptyTensor})
+            .getResult(0);
+    if (replacement.getType() != padOp.getResultType()) {
+      replacement = rewriter.create<tensor::CastOp>(
+          fillOp.getLoc(), padOp.getResultType(), replacement);
+    }
+    rewriter.replaceOp(padOp, replacement);
     return success();
   }
 };