[Mlir-commits] [mlir] b2826c0 - [mlir][NFC] Move offsets/sizes/strides helper to dialect utils and interface header
Matthias Springer
llvmlistbot at llvm.org
Mon Jul 31 05:58:21 PDT 2023
Author: Matthias Springer
Date: 2023-07-31T14:53:14+02:00
New Revision: b2826c020975a77cb24f1f69a5460e5bf136b837
URL: https://github.com/llvm/llvm-project/commit/b2826c020975a77cb24f1f69a5460e5bf136b837
DIFF: https://github.com/llvm/llvm-project/commit/b2826c020975a77cb24f1f69a5460e5bf136b837.diff
LOG: [mlir][NFC] Move offsets/sizes/strides helper to dialect utils and interface header
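* Move `foldDynamicIndexList` to `StaticValueUtils` (mlir/Dialect/Utils) and simplify the function: it no longer takes a `Builder` and now folds any constant value matched by `m_Constant`, not only values produced by `arith::ConstantIndexOp`.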
* Move `OpWithOffsetSizesAndStridesConstantArgumentFolder` from `Arith/Utils/Utils.h` to `ViewLikeInterface.h` and add documentation.
Differential Revision: https://reviews.llvm.org/D156581
Added:
Modified:
mlir/include/mlir/Dialect/Arith/Utils/Utils.h
mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
mlir/include/mlir/Interfaces/ViewLikeInterface.h
mlir/lib/Dialect/Arith/Utils/Utils.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/lib/Dialect/Utils/StaticValueUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index d0dd5de078c81f..62a84ee7903d76 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -26,50 +26,9 @@ namespace mlir {
/// Matches a ConstantIndexOp.
detail::op_matcher<arith::ConstantIndexOp> matchConstantIndex();
-/// Returns `success` when any of the elements in `ofrs` was produced by
-/// arith::ConstantIndexOp. In that case the constant attribute replaces the
-/// Value. Returns `failure` when no folding happened.
-LogicalResult foldDynamicIndexList(Builder &b,
- SmallVectorImpl<OpFoldResult> &ofrs);
-
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank,
ArrayRef<int64_t> shape);
-/// Pattern to rewrite a subview op with constant arguments.
-template <typename OpType, typename ResultTypeFunc, typename CastOpFunc>
-class OpWithOffsetSizesAndStridesConstantArgumentFolder final
- : public OpRewritePattern<OpType> {
-public:
- using OpRewritePattern<OpType>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(OpType op,
- PatternRewriter &rewriter) const override {
- SmallVector<OpFoldResult> mixedOffsets(op.getMixedOffsets());
- SmallVector<OpFoldResult> mixedSizes(op.getMixedSizes());
- SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides());
-
- // No constant operands were folded, just return;
- if (failed(foldDynamicIndexList(rewriter, mixedOffsets)) &&
- failed(foldDynamicIndexList(rewriter, mixedSizes)) &&
- failed(foldDynamicIndexList(rewriter, mixedStrides)))
- return failure();
-
- // Create the new op in canonical form.
- ResultTypeFunc resultTypeFunc;
- auto resultType =
- resultTypeFunc(op, mixedOffsets, mixedSizes, mixedStrides);
- if (!resultType)
- return failure();
- auto newOp =
- rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(),
- mixedOffsets, mixedSizes, mixedStrides);
- CastOpFunc func;
- func(rewriter, op, newOp);
-
- return success();
- }
-};
-
/// Converts an OpFoldResult to a Value. Returns the fold result if it casts to
/// a Value or creates a ConstantIndexOp if it casts to an IntegerAttribute.
/// Other attribute types are not supported.
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 91c43a81feea44..23a366036b9dd6 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -139,6 +139,11 @@ SmallVector<int64_t>
getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
llvm::function_ref<bool(Attribute, Attribute)> compare);
+/// Returns "success" when any of the elements in `ofrs` is a constant value. In
+/// that case the value is replaced by an attribute. Returns "failure" when no
+/// folding happened.
+LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs);
+
/// Return the number of iterations for a loop with a lower bound `lb`, upper
/// bound `ub` and step `step`.
std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 65ef514908d181..cbc034ca9958a4 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -18,6 +18,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
namespace mlir {
@@ -39,6 +40,47 @@ bool sameOffsetsSizesAndStrides(
namespace mlir {
+/// Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as
+/// constant arguments. This pattern assumes that the op has a suitable builder
+/// that takes a result type, a "source" operand and mixed offsets, sizes and
+/// strides.
+///
+/// `OpType` is the type of op to which this pattern is applied. `ResultTypeFn`
+/// returns the new result type of the op, based on the new offsets, sizes and
+/// strides. `CastOpFunc` is used to generate a cast op if the result type of
+/// the op has changed.
+template <typename OpType, typename ResultTypeFn, typename CastOpFunc>
+class OpWithOffsetSizesAndStridesConstantArgumentFolder final
+ : public OpRewritePattern<OpType> {
+public:
+ using OpRewritePattern<OpType>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpType op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<OpFoldResult> mixedOffsets(op.getMixedOffsets());
+ SmallVector<OpFoldResult> mixedSizes(op.getMixedSizes());
+ SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides());
+
+ // No constant operands were folded, just return;
+ if (failed(foldDynamicIndexList(mixedOffsets)) &&
+ failed(foldDynamicIndexList(mixedSizes)) &&
+ failed(foldDynamicIndexList(mixedStrides)))
+ return failure();
+
+ // Create the new op in canonical form.
+ auto resultType =
+ ResultTypeFn()(op, mixedOffsets, mixedSizes, mixedStrides);
+ if (!resultType)
+ return failure();
+ auto newOp =
+ rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(),
+ mixedOffsets, mixedSizes, mixedStrides);
+ CastOpFunc()(rewriter, op, newOp);
+
+ return success();
+ }
+};
+
/// Printer hook for custom directive in assemblyFormat.
///
/// custom<DynamicIndexList>($values, $integers)
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index d3e61030d8183a..d5d337a6aa35ee 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -25,25 +25,6 @@ detail::op_matcher<arith::ConstantIndexOp> mlir::matchConstantIndex() {
return detail::op_matcher<arith::ConstantIndexOp>();
}
-// Returns `success` when any of the elements in `ofrs` was produced by
-// arith::ConstantIndexOp. In that case the constant attribute replaces the
-// Value. Returns `failure` when no folding happened.
-LogicalResult mlir::foldDynamicIndexList(Builder &b,
- SmallVectorImpl<OpFoldResult> &ofrs) {
- bool valuesChanged = false;
- for (OpFoldResult &ofr : ofrs) {
- if (ofr.is<Attribute>())
- continue;
- // Newly static, move from Value to constant.
- if (auto cstOp = llvm::dyn_cast_if_present<Value>(ofr)
- .getDefiningOp<arith::ConstantIndexOp>()) {
- ofr = b.getIndexAttr(cstOp.value());
- valuesChanged = true;
- }
- }
- return success(valuesChanged);
-}
-
llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
ArrayRef<int64_t> shape) {
llvm::SmallBitVector dimsToProject(shape.size());
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index aaa5e39cd2f3d2..cfa09b5d4e11de 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1513,9 +1513,9 @@ class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
- if (failed(foldDynamicIndexList(rewriter, mixedLowerBound)) &&
- failed(foldDynamicIndexList(rewriter, mixedUpperBound)) &&
- failed(foldDynamicIndexList(rewriter, mixedStep)))
+ if (failed(foldDynamicIndexList(mixedLowerBound)) &&
+ failed(foldDynamicIndexList(mixedUpperBound)) &&
+ failed(foldDynamicIndexList(mixedStep)))
return failure();
rewriter.updateRootInPlace(op, [&]() {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 8fc3494f2d2e83..acd6a7271bce41 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2317,9 +2317,9 @@ class InsertSliceOpConstantArgumentFolder final
SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
// No constant operands were folded, just return;
- if (failed(foldDynamicIndexList(rewriter, mixedOffsets)) &&
- failed(foldDynamicIndexList(rewriter, mixedSizes)) &&
- failed(foldDynamicIndexList(rewriter, mixedStrides)))
+ if (failed(foldDynamicIndexList(mixedOffsets)) &&
+ failed(foldDynamicIndexList(mixedSizes)) &&
+ failed(foldDynamicIndexList(mixedStrides)))
return failure();
// Create the new op in canonical form.
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 2e0bafb4fc6545..8a4ccc990331a7 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -256,4 +256,18 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
}
+LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs) {
+ bool valuesChanged = false;
+ for (OpFoldResult &ofr : ofrs) {
+ if (ofr.is<Attribute>())
+ continue;
+ Attribute attr;
+ if (matchPattern(ofr.get<Value>(), m_Constant(&attr))) {
+ ofr = attr;
+ valuesChanged = true;
+ }
+ }
+ return success(valuesChanged);
+}
+
} // namespace mlir
More information about the Mlir-commits
mailing list