[Mlir-commits] [mlir] [mlir][Affine] Extend linearize/delinearize cancelation to partial tails (PR #116872)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Nov 19 12:38:48 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Krzysztof Drewniak (krzysz00)
<details>
<summary>Changes</summary>
Existing patterns would cancel out `linearize_index` / `delinearize_index` pairs that had exactly the same basis, like
%0 = affine.linearize_index [%w, %x, %y, %z] by (X, Y, Z) : index
%1:4 = affine.delinearize_index %0 into (W, X, Y, Z) : index, ...
This commit extends the canonicalization to handle cases where the two bases don't match in their entirety but share one or more trailing elements, as in
%0 = affine.linearize_index [%w, %x, %y, %z] by (X, Y, Z) : index
%1:3 = affine.delinearize_index %0 into (WX, Y, Z) : index, ...
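where we can replace the last two results of the `delinearize_index` operation with the last two inputs of the `linearize_index`. The remaining leading inputs and basis elements are kept in a smaller `linearize_index` / `delinearize_index` pair. This produces a more canonical result, one with fewer total computations to perform.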
---
Full diff: https://github.com/llvm/llvm-project/pull/116872.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+45-11)
- (modified) mlir/test/Dialect/Affine/canonicalize.mlir (+18)
``````````diff
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 4cf07bc167eab9..67d7da622a3550 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4666,14 +4666,16 @@ struct DropUnitExtentBasis
};
/// If a `affine.delinearize_index`'s input is a `affine.linearize_index
-/// disjoint` and the two operations have the same basis, replace the
-/// delinearizeation results with the inputs of the `affine.linearize_index`
-/// since they are exact inverses of each other.
+/// disjoint` and the two operations end with the same basis elements,
+/// cancel those parts of the operations out because they are inverses
+/// of each other.
+///
+/// If the operations have the same basis, cancel them entirely.
///
/// The `disjoint` flag is needed on the `affine.linearize_index` because
/// otherwise, there is no guarantee that the inputs to the linearization are
/// in-bounds the way the outputs of the delinearization would be.
-struct CancelDelinearizeOfLinearizeDisjointExact
+struct CancelDelinearizeOfLinearizeDisjointExactTail
: public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
using OpRewritePattern::OpRewritePattern;
@@ -4685,12 +4687,45 @@ struct CancelDelinearizeOfLinearizeDisjointExact
return rewriter.notifyMatchFailure(delinearizeOp,
"index doesn't come from linearize");
- if (!linearizeOp.getDisjoint() ||
- linearizeOp.getEffectiveBasis() != delinearizeOp.getEffectiveBasis())
+ if (!linearizeOp.getDisjoint())
+ return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");
+
+ ValueRange linearizeIns = linearizeOp.getMultiIndex();
+ // Note: we use the full basis so we don't lose outer bounds later.
+ SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis();
+ SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis();
+ size_t numMatches = 0;
+ for (auto [linSize, delinSize] : llvm::zip(
+ llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
+ if (linSize != delinSize)
+ break;
+ ++numMatches;
+ }
+
+ if (numMatches == 0)
return rewriter.notifyMatchFailure(
- linearizeOp, "not disjoint or basis doesn't match delinearize");
+ delinearizeOp, "final basis element doesn't match linearize");
+
+ // The easy case: everything lines up and the basis match sup completely.
+ if (numMatches == linearizeBasis.size() &&
+ numMatches == delinearizeBasis.size() &&
+ linearizeIns.size() == delinearizeOp.getNumResults()) {
+ rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
+ return success();
+ }
- rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
+ Value newLinearize = rewriter.create<affine::AffineLinearizeIndexOp>(
+ linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
+ ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches),
+ linearizeOp.getDisjoint());
+ auto newDelinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+ delinearizeOp.getLoc(), newLinearize,
+ ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches),
+ delinearizeOp.hasOuterBound());
+ SmallVector<Value> mergedResults(newDelinearize.getResults());
+ mergedResults.append(linearizeIns.take_back(numMatches).begin(),
+ linearizeIns.take_back(numMatches).end());
+ rewriter.replaceOp(delinearizeOp, mergedResults);
return success();
}
};
@@ -4698,9 +4733,8 @@ struct CancelDelinearizeOfLinearizeDisjointExact
void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
- patterns
- .insert<CancelDelinearizeOfLinearizeDisjointExact, DropUnitExtentBasis>(
- context);
+ patterns.insert<CancelDelinearizeOfLinearizeDisjointExactTail,
+ DropUnitExtentBasis>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index b54a13cffe7771..5384977151b47f 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1739,6 +1739,24 @@ func.func @cancel_delinearize_linearize_disjoint_delinearize_extra_bound(%arg0:
// -----
+// CHECK-LABEL: func @cancel_delinearize_linearize_disjoint_partial(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (%[[ARG3]], 4) : index
+// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[LIN]] into (8) : index, index
+// CHECK: return %[[DELIN]]#0, %[[DELIN]]#1, %[[ARG2]]
+func.func @cancel_delinearize_linearize_disjoint_partial(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
+ %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%arg3, 4, %arg4) : index
+ %1:3 = affine.delinearize_index %0 into (8, %arg4)
+ : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
// Without `disjoint`, the cancelation isn't guaranteed to be the identity.
// CHECK-LABEL: func @no_cancel_delinearize_linearize_exact(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
``````````
</details>
https://github.com/llvm/llvm-project/pull/116872
More information about the Mlir-commits
mailing list