[Mlir-commits] [mlir] b2826c0 - [mlir][NFC] Move offsets/sizes/strides helper to dialect utils and interface header

Matthias Springer llvmlistbot at llvm.org
Mon Jul 31 05:58:21 PDT 2023


Author: Matthias Springer
Date: 2023-07-31T14:53:14+02:00
New Revision: b2826c020975a77cb24f1f69a5460e5bf136b837

URL: https://github.com/llvm/llvm-project/commit/b2826c020975a77cb24f1f69a5460e5bf136b837
DIFF: https://github.com/llvm/llvm-project/commit/b2826c020975a77cb24f1f69a5460e5bf136b837.diff

LOG: [mlir][NFC] Move offsets/sizes/strides helper to dialect utils and interface header

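* Move `foldDynamicIndexList` to `Dialect/Utils/StaticValueUtils` and simplify the function: it no longer takes a `Builder`, and it now folds any value matched by `m_Constant` rather than only values produced by `arith::ConstantIndexOp`.
* Move `OpWithOffsetSizesAndStridesConstantArgumentFolder` to `ViewLikeInterface.h` and add documentation.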

Differential Revision: https://reviews.llvm.org/D156581

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arith/Utils/Utils.h
    mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
    mlir/include/mlir/Interfaces/ViewLikeInterface.h
    mlir/lib/Dialect/Arith/Utils/Utils.cpp
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/Dialect/Utils/StaticValueUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index d0dd5de078c81f..62a84ee7903d76 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -26,50 +26,9 @@ namespace mlir {
 /// Matches a ConstantIndexOp.
 detail::op_matcher<arith::ConstantIndexOp> matchConstantIndex();
 
-/// Returns `success` when any of the elements in `ofrs` was produced by
-/// arith::ConstantIndexOp. In that case the constant attribute replaces the
-/// Value. Returns `failure` when no folding happened.
-LogicalResult foldDynamicIndexList(Builder &b,
-                                   SmallVectorImpl<OpFoldResult> &ofrs);
-
 llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank,
                                             ArrayRef<int64_t> shape);
 
-/// Pattern to rewrite a subview op with constant arguments.
-template <typename OpType, typename ResultTypeFunc, typename CastOpFunc>
-class OpWithOffsetSizesAndStridesConstantArgumentFolder final
-    : public OpRewritePattern<OpType> {
-public:
-  using OpRewritePattern<OpType>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(OpType op,
-                                PatternRewriter &rewriter) const override {
-    SmallVector<OpFoldResult> mixedOffsets(op.getMixedOffsets());
-    SmallVector<OpFoldResult> mixedSizes(op.getMixedSizes());
-    SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides());
-
-    // No constant operands were folded, just return;
-    if (failed(foldDynamicIndexList(rewriter, mixedOffsets)) &&
-        failed(foldDynamicIndexList(rewriter, mixedSizes)) &&
-        failed(foldDynamicIndexList(rewriter, mixedStrides)))
-      return failure();
-
-    // Create the new op in canonical form.
-    ResultTypeFunc resultTypeFunc;
-    auto resultType =
-        resultTypeFunc(op, mixedOffsets, mixedSizes, mixedStrides);
-    if (!resultType)
-      return failure();
-    auto newOp =
-        rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(),
-                                mixedOffsets, mixedSizes, mixedStrides);
-    CastOpFunc func;
-    func(rewriter, op, newOp);
-
-    return success();
-  }
-};
-
 /// Converts an OpFoldResult to a Value. Returns the fold result if it casts to
 /// a Value or creates a ConstantIndexOp if it casts to an IntegerAttribute.
 /// Other attribute types are not supported.

diff  --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 91c43a81feea44..23a366036b9dd6 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -139,6 +139,11 @@ SmallVector<int64_t>
 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
                      llvm::function_ref<bool(Attribute, Attribute)> compare);
 
+/// Returns "success" when any of the elements in `ofrs` is a constant value. In
+/// that case the value is replaced by an attribute. Returns "failure" when no
+/// folding happened.
+LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs);
+
 /// Return the number of iterations for a loop with a lower bound `lb`, upper
 /// bound `ub` and step `step`.
 std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,

diff  --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 65ef514908d181..cbc034ca9958a4 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -18,6 +18,7 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
 
@@ -39,6 +40,47 @@ bool sameOffsetsSizesAndStrides(
 
 namespace mlir {
 
+/// Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as
+/// constant arguments. This pattern assumes that the op has a suitable builder
+/// that takes a result type, a "source" operand and mixed offsets, sizes and
+/// strides.
+///
+/// `OpType` is the type of op to which this pattern is applied. `ResultTypeFn`
+/// returns the new result type of the op, based on the new offsets, sizes and
+/// strides. `CastOpFunc` is used to generate a cast op if the result type of
+/// the op has changed.
+template <typename OpType, typename ResultTypeFn, typename CastOpFunc>
+class OpWithOffsetSizesAndStridesConstantArgumentFolder final
+    : public OpRewritePattern<OpType> {
+public:
+  using OpRewritePattern<OpType>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpType op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<OpFoldResult> mixedOffsets(op.getMixedOffsets());
+    SmallVector<OpFoldResult> mixedSizes(op.getMixedSizes());
+    SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides());
+
+    // No constant operands were folded, just return;
+    if (failed(foldDynamicIndexList(mixedOffsets)) &&
+        failed(foldDynamicIndexList(mixedSizes)) &&
+        failed(foldDynamicIndexList(mixedStrides)))
+      return failure();
+
+    // Create the new op in canonical form.
+    auto resultType =
+        ResultTypeFn()(op, mixedOffsets, mixedSizes, mixedStrides);
+    if (!resultType)
+      return failure();
+    auto newOp =
+        rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(),
+                                mixedOffsets, mixedSizes, mixedStrides);
+    CastOpFunc()(rewriter, op, newOp);
+
+    return success();
+  }
+};
+
 /// Printer hook for custom directive in assemblyFormat.
 ///
 ///   custom<DynamicIndexList>($values, $integers)

diff  --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index d3e61030d8183a..d5d337a6aa35ee 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -25,25 +25,6 @@ detail::op_matcher<arith::ConstantIndexOp> mlir::matchConstantIndex() {
   return detail::op_matcher<arith::ConstantIndexOp>();
 }
 
-// Returns `success` when any of the elements in `ofrs` was produced by
-// arith::ConstantIndexOp. In that case the constant attribute replaces the
-// Value. Returns `failure` when no folding happened.
-LogicalResult mlir::foldDynamicIndexList(Builder &b,
-                                         SmallVectorImpl<OpFoldResult> &ofrs) {
-  bool valuesChanged = false;
-  for (OpFoldResult &ofr : ofrs) {
-    if (ofr.is<Attribute>())
-      continue;
-    // Newly static, move from Value to constant.
-    if (auto cstOp = llvm::dyn_cast_if_present<Value>(ofr)
-                         .getDefiningOp<arith::ConstantIndexOp>()) {
-      ofr = b.getIndexAttr(cstOp.value());
-      valuesChanged = true;
-    }
-  }
-  return success(valuesChanged);
-}
-
 llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
                                                   ArrayRef<int64_t> shape) {
   llvm::SmallBitVector dimsToProject(shape.size());

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index aaa5e39cd2f3d2..cfa09b5d4e11de 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1513,9 +1513,9 @@ class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
     SmallVector<OpFoldResult> mixedLowerBound(op.getMixedLowerBound());
     SmallVector<OpFoldResult> mixedUpperBound(op.getMixedUpperBound());
     SmallVector<OpFoldResult> mixedStep(op.getMixedStep());
-    if (failed(foldDynamicIndexList(rewriter, mixedLowerBound)) &&
-        failed(foldDynamicIndexList(rewriter, mixedUpperBound)) &&
-        failed(foldDynamicIndexList(rewriter, mixedStep)))
+    if (failed(foldDynamicIndexList(mixedLowerBound)) &&
+        failed(foldDynamicIndexList(mixedUpperBound)) &&
+        failed(foldDynamicIndexList(mixedStep)))
       return failure();
 
     rewriter.updateRootInPlace(op, [&]() {

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 8fc3494f2d2e83..acd6a7271bce41 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2317,9 +2317,9 @@ class InsertSliceOpConstantArgumentFolder final
     SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
 
     // No constant operands were folded, just return;
-    if (failed(foldDynamicIndexList(rewriter, mixedOffsets)) &&
-        failed(foldDynamicIndexList(rewriter, mixedSizes)) &&
-        failed(foldDynamicIndexList(rewriter, mixedStrides)))
+    if (failed(foldDynamicIndexList(mixedOffsets)) &&
+        failed(foldDynamicIndexList(mixedSizes)) &&
+        failed(foldDynamicIndexList(mixedStrides)))
       return failure();
 
     // Create the new op in canonical form.

diff  --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 2e0bafb4fc6545..8a4ccc990331a7 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -256,4 +256,18 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
   return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
 }
 
+LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs) {
+  bool valuesChanged = false;
+  for (OpFoldResult &ofr : ofrs) {
+    if (ofr.is<Attribute>())
+      continue;
+    Attribute attr;
+    if (matchPattern(ofr.get<Value>(), m_Constant(&attr))) {
+      ofr = attr;
+      valuesChanged = true;
+    }
+  }
+  return success(valuesChanged);
+}
+
 } // namespace mlir


        


More information about the Mlir-commits mailing list