[Mlir-commits] [mlir] cb9481d - [mlir][affine] Add folders for delinearize_index and linearize_index (#115766)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 13 11:05:41 PST 2024


Author: Krzysztof Drewniak
Date: 2024-11-13T13:05:37-06:00
New Revision: cb9481dbf902adc349757eca12a0a09396dc4a23

URL: https://github.com/llvm/llvm-project/commit/cb9481dbf902adc349757eca12a0a09396dc4a23
DIFF: https://github.com/llvm/llvm-project/commit/cb9481dbf902adc349757eca12a0a09396dc4a23.diff

LOG: [mlir][affine] Add folders for delinearize_index and linearize_index (#115766)

This commit adds implementations of fold() for delinearize_index and
linearize_index to constant-fold them away when they have a fully
constant basis and constant argument(s).

This commit also adds a canonicalization pattern to linearize_index that
causes it to drop leading-zero inputs.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/test/Dialect/Affine/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index c9d9202ae3cf1a..6a495e11ae1ad5 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1110,6 +1110,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
   }];
 
   let hasVerifier = 1;
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
@@ -1179,6 +1180,7 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
   }];
 
   let hasVerifier = 1;
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 

diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index f1b0fe7e645051..fbc9053a0e273b 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4556,6 +4556,26 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
   return success();
 }
 
+LogicalResult
+AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
+                               SmallVectorImpl<OpFoldResult> &result) {
+  if (adaptor.getLinearIndex() == nullptr)
+    return failure();
+
+  if (!adaptor.getDynamicBasis().empty())
+    return failure();
+
+  int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
+  Type attrType = getLinearIndex().getType();
+  for (int64_t modulus : llvm::reverse(getStaticBasis().drop_front())) {
+    result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
+    highPart = llvm::divideFloorSigned(highPart, modulus);
+  }
+  result.push_back(IntegerAttr::get(attrType, highPart));
+  std::reverse(result.begin(), result.end());
+  return success();
+}
+
 namespace {
 
 // Drops delinearization indices that correspond to unit-extent basis
@@ -4715,6 +4735,26 @@ LogicalResult AffineLinearizeIndexOp::verify() {
   return success();
 }
 
+OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
+  if (llvm::any_of(adaptor.getMultiIndex(),
+                   [](Attribute a) { return a == nullptr; }))
+    return nullptr;
+
+  if (!adaptor.getDynamicBasis().empty())
+    return nullptr;
+
+  int64_t result = 0;
+  int64_t stride = 1;
+  for (auto [indexAttr, length] :
+       llvm::zip_equal(llvm::reverse(adaptor.getMultiIndex()),
+                       llvm::reverse(getStaticBasis()))) {
+    result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
+    stride = stride * length;
+  }
+
+  return IntegerAttr::get(getResult().getType(), result);
+}
+
 namespace {
 /// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
 /// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
@@ -4820,11 +4860,39 @@ struct CancelLinearizeOfDelinearizeExact final
     return success();
   }
 };
+
+/// Strip leading zero from affine.linearize_index.
+///
+/// `affine.linearize_index [%c0, ...a] by (%x, ...b)` can be rewritten
+/// to `affine.linearize_index [...a] by (...b)` in all cases.
+struct DropLinearizeLeadingZero final
+    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    Value leadingIdx = op.getMultiIndex().front();
+    if (!matchPattern(leadingIdx, m_Zero()))
+      return failure();
+
+    if (op.getMultiIndex().size() == 1) {
+      rewriter.replaceOp(op, leadingIdx);
+      return success();
+    }
+
+    SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
+    rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
+        op, op.getMultiIndex().drop_front(),
+        ArrayRef<OpFoldResult>(mixedBasis).drop_front(), op.getDisjoint());
+    return success();
+  }
+};
 } // namespace
 
 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
-  patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeOneBasisElement,
+  patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeLeadingZero,
+               DropLinearizeOneBasisElement,
                DropLinearizeUnitComponentsIfDisjointOrZero>(context);
 }
 

diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index 8988f779ad02b4..ec00b31258d072 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1469,6 +1469,45 @@ func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () {
 
 // -----
 
+// CHECK-LABEL: @delinearize_fold_constant
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-NOT: affine.delinearize_index
+// CHECK: return %[[C1]], %[[C1]], %[[C2]]
+func.func @delinearize_fold_constant() -> (index, index, index) {
+  %c22 = arith.constant 22 : index
+  %0:3 = affine.delinearize_index %c22 into (2, 3, 5) : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @delinearize_fold_negative_constant
+// CHECK-DAG: %[[C_2:.+]] = arith.constant -2 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-NOT: affine.delinearize_index
+// CHECK: return %[[C_2]], %[[C1]], %[[C3]]
+func.func @delinearize_fold_negative_constant() -> (index, index, index) {
+  %c_22 = arith.constant -22 : index
+  %0:3 = affine.delinearize_index %c_22 into (2, 3, 5) : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @delinearize_dont_fold_constant_dynamic_basis
+// CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
+// CHECK: %[[RET:.+]]:3 = affine.delinearize_index %[[C22]]
+// CHECK: return %[[RET]]#0, %[[RET]]#1, %[[RET]]#2
+func.func @delinearize_dont_fold_constant_dynamic_basis(%arg0: index) -> (index, index, index) {
+  %c22 = arith.constant 22 : index
+  %0:3 = affine.delinearize_index %c22 into (2, %arg0, 5) : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
 func.func @drop_unit_basis_in_delinearize(%arg0 : index, %arg1 : index, %arg2 : index) ->
     (index, index, index, index, index, index) {
   %c1 = arith.constant 1 : index
@@ -1535,6 +1574,33 @@ func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index
 
 // -----
 
+// CHECK-LABEL: @linearize_fold_constants
+// CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
+// CHECK-NOT: affine.linearize
+// CHECK: return %[[C22]]
+func.func @linearize_fold_constants() -> index {
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+
+  %ret = affine.linearize_index [%c1, %c1, %c2] by (2, 3, 5) : index
+  return %ret : index
+}
+
+// -----
+
+// CHECK-LABEL: @linearize_dont_fold_dynamic_basis
+// CHECK: %[[RET:.+]] = affine.linearize_index
+// CHECK: return %[[RET]]
+func.func @linearize_dont_fold_dynamic_basis(%arg0: index) -> index {
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+
+  %ret = affine.linearize_index [%c1, %c1, %c2] by (2, %arg0, 5) : index
+  return %ret : index
+}
+
+// -----
+
 // CHECK-LABEL: func @cancel_delinearize_linearize_disjoint_exact(
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
@@ -1676,3 +1742,17 @@ func.func @no_cancel_linearize_denearize_different_basis(%arg0: index, %arg1: in
   %1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 8, %arg2) : index
   return %1 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @affine_leading_zero(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[RET:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (3, 5)
+//       CHECK:     return %[[RET]]
+func.func @affine_leading_zero(%arg0: index, %arg1: index) -> index {
+  %c0 = arith.constant 0 : index
+  %ret = affine.linearize_index [%c0, %arg0, %arg1] by (2, 3, 5) : index
+  return %ret : index
+}
+


        


More information about the Mlir-commits mailing list