[Mlir-commits] [mlir] [mlir][affine] Add folders for delinearize_index and linearize_index (PR #115766)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 11 12:51:10 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-affine

Author: Krzysztof Drewniak (krzysz00)

<details>
<summary>Changes</summary>

This commit adds implementations of fold() for delinearize_index and linearize_index to constant-fold them away when they have a fully static basis (no dynamic basis operands) and constant argument(s).

This commit also adds a canonicalization pattern to linearize_index that drops a leading constant-zero index together with its corresponding basis element.

---
Full diff: https://github.com/llvm/llvm-project/pull/115766.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (+2) 
- (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+68-1) 
- (modified) mlir/test/Dialect/Affine/canonicalize.mlir (+80) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 1dd9b9a440ecc8..753b8951fb084b 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1110,6 +1110,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
   }];
 
   let hasVerifier = 1;
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
@@ -1179,6 +1180,7 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
   }];
 
   let hasVerifier = 1;
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 3d38de4bf1068e..37316632a6a06f 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4556,6 +4556,26 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
   return success();
 }
 
+LogicalResult
+AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
+                               SmallVectorImpl<OpFoldResult> &result) {
+  if (adaptor.getLinearIndex() == nullptr)
+    return failure();
+
+  if (!adaptor.getDynamicBasis().empty())
+    return failure();
+
+  int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
+  Type attrType = getLinearIndex().getType();
+  for (int64_t modulus : llvm::reverse(getStaticBasis().drop_front())) {
+    result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
+    highPart = llvm::divideFloorSigned(highPart, modulus);
+  }
+  result.push_back(IntegerAttr::get(attrType, highPart));
+  std::reverse(result.begin(), result.end());
+  return success();
+}
+
 namespace {
 
 // Drops delinearization indices that correspond to unit-extent basis
@@ -4683,6 +4703,26 @@ LogicalResult AffineLinearizeIndexOp::verify() {
   return success();
 }
 
+OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
+  if (llvm::any_of(adaptor.getMultiIndex(),
+                   [](Attribute a) { return a == nullptr; }))
+    return nullptr;
+
+  if (!adaptor.getDynamicBasis().empty())
+    return nullptr;
+
+  int64_t result = 0;
+  int64_t stride = 1;
+  for (auto [indexAttr, length] :
+       llvm::zip_equal(llvm::reverse(adaptor.getMultiIndex()),
+                       llvm::reverse(getStaticBasis()))) {
+    result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
+    stride = stride * length;
+  }
+
+  return IntegerAttr::get(getResult().getType(), result);
+}
+
 namespace {
 /// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
 /// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
@@ -4751,12 +4791,39 @@ struct DropLinearizeOneBasisElement final
     return success();
   }
 };
+
+/// Strip leading zero from affine.linearize_index.
+///
+/// `affine.linearize_index [%c0, ...a] by (%x, ...b)` can be rewritten
+/// to `affine.linearize_index [...a] by (...b)` in all cases.
+struct DropLinearizeLeadingZero final
+    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    Value leadingIdx = op.getMultiIndex().front();
+    if (!matchPattern(leadingIdx, m_Zero()))
+      return failure();
+
+    if (op.getMultiIndex().size() == 1) {
+      rewriter.replaceOp(op, leadingIdx);
+      return success();
+    }
+
+    SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
+    rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
+        op, op.getMultiIndex().drop_front(),
+        ArrayRef<OpFoldResult>(mixedBasis).drop_front(), op.getDisjoint());
+    return success();
+  }
+};
 } // namespace
 
 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
   patterns.add<DropLinearizeUnitComponentsIfDisjointOrZero,
-               DropLinearizeOneBasisElement>(context);
+               DropLinearizeOneBasisElement, DropLinearizeLeadingZero>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index fa179744094c67..8ed6ab5c965ca0 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1469,6 +1469,45 @@ func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () {
 
 // -----
 
+// CHECK-LABEL: @delinearize_fold_constant
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-NOT: affine.delinearize_index
+// CHECK: return %[[C1]], %[[C1]], %[[C2]]
+func.func @delinearize_fold_constant() -> (index, index, index) {
+  %c22 = arith.constant 22 : index
+  %0:3 = affine.delinearize_index %c22 into (2, 3, 5) : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @delinearize_fold_negative_constant
+// CHECK-DAG: %[[C_2:.+]] = arith.constant -2 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-NOT: affine.delinearize_index
+// CHECK: return %[[C_2]], %[[C1]], %[[C3]]
+func.func @delinearize_fold_negative_constant() -> (index, index, index) {
+  %c_22 = arith.constant -22 : index
+  %0:3 = affine.delinearize_index %c_22 into (2, 3, 5) : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @delinearize_dont_fold_constant_dynamic_basis
+// CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
+// CHECK: %[[RET:.+]]:3 = affine.delinearize_index %[[C22]]
+// CHECK: return %[[RET]]#0, %[[RET]]#1, %[[RET]]#2
+func.func @delinearize_dont_fold_constant_dynamic_basis(%arg0: index) -> (index, index, index) {
+  %c22 = arith.constant 22 : index
+  %0:3 = affine.delinearize_index %c22 into (2, %arg0, 5) : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
 func.func @drop_unit_basis_in_delinearize(%arg0 : index, %arg1 : index, %arg2 : index) ->
     (index, index, index, index, index, index) {
   %c1 = arith.constant 1 : index
@@ -1535,6 +1574,33 @@ func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index
 
 // -----
 
+// CHECK-LABEL: @linearize_fold_constants
+// CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
+// CHECK-NOT: affine.linearize
+// CHECK: return %[[C22]]
+func.func @linearize_fold_constants() -> index {
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+
+  %ret = affine.linearize_index [%c1, %c1, %c2] by (2, 3, 5) : index
+  return %ret : index
+}
+
+// -----
+
+// CHECK-LABEL: @linearize_dont_fold_dynamic_basis
+// CHECK: %[[RET:.+]] = affine.linearize_index
+// CHECK: return %[[RET]]
+func.func @linearize_dont_fold_dynamic_basis(%arg0: index) -> index {
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+
+  %ret = affine.linearize_index [%c1, %c1, %c2] by (2, %arg0, 5) : index
+  return %ret : index
+}
+
+// -----
+
 // CHECK-LABEL: @linearize_unit_basis_disjoint
 // CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
 // CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index
@@ -1577,3 +1643,17 @@ func.func @linearize_one_element_basis(%arg0: index, %arg1: index) -> index {
   %ret = affine.linearize_index [%arg0] by (%arg1) : index
   return %ret : index
 }
+
+// -----
+
+// CHECK-LABEL: func @affine_leading_zero(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[RET:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (3, 5)
+//       CHECK:     return %[[RET]]
+func.func @affine_leading_zero(%arg0: index, %arg1: index) -> index {
+  %c0 = arith.constant 0 : index
+  %ret = affine.linearize_index [%c0, %arg0, %arg1] by (2, 3, 5) : index
+  return %ret : index
+}
+

``````````

</details>


https://github.com/llvm/llvm-project/pull/115766


More information about the Mlir-commits mailing list