[Mlir-commits] [mlir] [MLIR][Affine] Add canonicalization pattern to have constant basis attr for affine.delinearize_index/linearize_index (PR #117572)
Abhishek Varma
llvmlistbot at llvm.org
Mon Nov 25 22:38:42 PST 2024
https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/117572
>From 7dc8a7d83f875d8357cc40c80d59bdd54842839b Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Mon, 25 Nov 2024 15:47:26 +0000
Subject: [PATCH 1/2] [MLIR][Affine] Add canonicalization pattern to have CST
basis attr
-- This commit adds a canonicalization pattern that uses constant (CST)
attributes for the basis of affine.delinearize_index/linearize_index ops
wherever applicable.
-- Essentially, the patterns check whether the mixed basis (a list of
OpFoldResults) contains any constant SSA values and convert them to
constant integer attributes instead.
Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 66 +++++++++++++++++++++-
mlir/test/Dialect/Affine/canonicalize.mlir | 30 ++++++++++
2 files changed, 94 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 67d7da622a3550..3e82ec00763142 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4729,12 +4729,55 @@ struct CancelDelinearizeOfLinearizeDisjointExactTail
return success();
}
};
+
+/// Give mixed basis of affine.delinearize_index/linearize_index replace
+/// constant SSA values with constant attribute as OpFoldResult. In case no
+/// change is made to the existing mixed basis set, return failure; success
+/// otherwise.
+static LogicalResult
+fetchNewConstantBasis(PatternRewriter &rewriter,
+ SmallVector<OpFoldResult> mixedBasis,
+ SmallVector<OpFoldResult> &newBasis) {
+ // Replace all constant SSA values with the constant attribute.
+ bool hasConstantSSAVal = false;
+ for (OpFoldResult basis : mixedBasis) {
+ std::optional<int64_t> basisVal = getConstantIntValue(basis);
+ if (basisVal && !isa<Attribute>(basis)) {
+ newBasis.push_back(rewriter.getIndexAttr(*basisVal));
+ hasConstantSSAVal = true;
+ } else {
+ newBasis.push_back(basis);
+ }
+ }
+ if (hasConstantSSAVal)
+ return success();
+ return failure();
+}
+
+/// Folds away constant SSA Value with constant Attribute in basis.
+struct ConstantAttributeBasisDelinearizeIndexOpPattern
+ : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp op,
+ PatternRewriter &rewriter) const override {
+ // Replace all constant SSA values with the constant attribute.
+ SmallVector<OpFoldResult> newBasis;
+ if (failed(fetchNewConstantBasis(rewriter, op.getMixedBasis(), newBasis)))
+ return rewriter.notifyMatchFailure(op, "no constant SSA value in basis");
+
+ rewriter.replaceOpWithNewOp<affine::AffineDelinearizeIndexOp>(
+ op, op.getLinearIndex(), newBasis, op.hasOuterBound());
+ return success();
+ }
+};
} // namespace
void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.insert<CancelDelinearizeOfLinearizeDisjointExactTail,
- DropUnitExtentBasis>(context);
+ DropUnitExtentBasis,
+ ConstantAttributeBasisDelinearizeIndexOpPattern>(context);
}
//===----------------------------------------------------------------------===//
@@ -4959,12 +5002,31 @@ struct DropLinearizeLeadingZero final
return success();
}
};
+
+/// Folds away constant SSA Value with constant Attribute in basis.
+struct ConstantAttributeBasisLinearizeIndexOpPattern
+ : public OpRewritePattern<affine::AffineLinearizeIndexOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
+ PatternRewriter &rewriter) const override {
+ // Replace all constant SSA values with the constant attribute.
+ SmallVector<OpFoldResult> newBasis;
+ if (failed(fetchNewConstantBasis(rewriter, op.getMixedBasis(), newBasis)))
+ return rewriter.notifyMatchFailure(op, "no constant SSA value in basis");
+
+ rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
+ op, op.getMultiIndex(), newBasis, op.getDisjoint());
+ return success();
+ }
+};
} // namespace
void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeLeadingZero,
- DropLinearizeUnitComponentsIfDisjointOrZero>(context);
+ DropLinearizeUnitComponentsIfDisjointOrZero,
+ ConstantAttributeBasisLinearizeIndexOpPattern>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index 5384977151b47f..16cbce35aeec7e 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1946,3 +1946,33 @@ func.func @affine_leading_zero_no_outer_bound(%arg0: index, %arg1: index) -> ind
return %ret : index
}
+// -----
+
+// CHECK-LABEL: @cst_value_to_cst_attr_basis_delinearize_index
+// CHECK-SAME: (%[[ARG0:.*]]: index)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[RET:.*]]:2 = affine.delinearize_index %[[ARG0]] into (3, 4) : index, index
+// CHECK: return %[[RET]]#0, %[[RET]]#1, %[[C0]] : index, index, index
+func.func @cst_value_to_cst_attr_basis_delinearize_index(%arg0 : index) ->
+ (index, index, index) {
+ %c4 = arith.constant 4 : index
+ %c3 = arith.constant 3 : index
+ %c1 = arith.constant 1 : index
+ %0:3 = affine.delinearize_index %arg0 into (%c3, %c4, %c1)
+ : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @cst_value_to_cst_attr_basis_linearize_index
+// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+// CHECK: %[[RET:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]]] by (3, 4) : index
+// CHECK: return %[[RET]] : index
+func.func @cst_value_to_cst_attr_basis_linearize_index(%arg0 : index, %arg1 : index, %arg2 : index) ->
+ (index) {
+ %c4 = arith.constant 4 : index
+ %c1 = arith.constant 1 : index
+ %1 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%c1, 3, %c4) : index
+ return %1 : index
+}
>From a0b8fb021a7600dc86c3ac1688a65174db72d218 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Tue, 26 Nov 2024 06:38:06 +0000
Subject: [PATCH 2/2] Review comment
---
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 97 ++++++++----------------
1 file changed, 33 insertions(+), 64 deletions(-)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 3e82ec00763142..e48db6d5f521eb 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4569,9 +4569,38 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
return success();
}
+/// Given the mixed basis of affine.delinearize_index/linearize_index,
+/// replaces constant SSA values with their constant integer values and
+/// returns the new static basis.
+static SmallVector<int64_t>
+foldCstValueToCstAttrBasis(ArrayRef<OpFoldResult> mixedBasis,
+ MutableOperandRange mutableDynamicBasis,
+ ArrayRef<Attribute> dynamicBasis) {
+ SmallVector<int64_t> staticBasis;
+ for (OpFoldResult basis : mixedBasis) {
+ std::optional<int64_t> basisVal = getConstantIntValue(basis);
+ if (!basisVal)
+ staticBasis.push_back(ShapedType::kDynamic);
+ else
+ staticBasis.push_back(*basisVal);
+ }
+
+ int64_t dynamicBasisIndex = 0;
+ for (OpFoldResult basis : dynamicBasis) {
+ if (basis) {
+ mutableDynamicBasis.erase(dynamicBasisIndex);
+ } else {
+ ++dynamicBasisIndex;
+ }
+ }
+ return staticBasis;
+}
+
LogicalResult
AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &result) {
+ setStaticBasis(foldCstValueToCstAttrBasis(
+ getMixedBasis(), getDynamicBasisMutable(), adaptor.getDynamicBasis()));
// If we won't be doing any division or modulo (no basis or the one basis
// element is purely advisory), simply return the input value.
if (getNumResults() == 1) {
@@ -4729,55 +4758,12 @@ struct CancelDelinearizeOfLinearizeDisjointExactTail
return success();
}
};
-
-/// Give mixed basis of affine.delinearize_index/linearize_index replace
-/// constant SSA values with constant attribute as OpFoldResult. In case no
-/// change is made to the existing mixed basis set, return failure; success
-/// otherwise.
-static LogicalResult
-fetchNewConstantBasis(PatternRewriter &rewriter,
- SmallVector<OpFoldResult> mixedBasis,
- SmallVector<OpFoldResult> &newBasis) {
- // Replace all constant SSA values with the constant attribute.
- bool hasConstantSSAVal = false;
- for (OpFoldResult basis : mixedBasis) {
- std::optional<int64_t> basisVal = getConstantIntValue(basis);
- if (basisVal && !isa<Attribute>(basis)) {
- newBasis.push_back(rewriter.getIndexAttr(*basisVal));
- hasConstantSSAVal = true;
- } else {
- newBasis.push_back(basis);
- }
- }
- if (hasConstantSSAVal)
- return success();
- return failure();
-}
-
-/// Folds away constant SSA Value with constant Attribute in basis.
-struct ConstantAttributeBasisDelinearizeIndexOpPattern
- : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp op,
- PatternRewriter &rewriter) const override {
- // Replace all constant SSA values with the constant attribute.
- SmallVector<OpFoldResult> newBasis;
- if (failed(fetchNewConstantBasis(rewriter, op.getMixedBasis(), newBasis)))
- return rewriter.notifyMatchFailure(op, "no constant SSA value in basis");
-
- rewriter.replaceOpWithNewOp<affine::AffineDelinearizeIndexOp>(
- op, op.getLinearIndex(), newBasis, op.hasOuterBound());
- return success();
- }
-};
} // namespace
void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.insert<CancelDelinearizeOfLinearizeDisjointExactTail,
- DropUnitExtentBasis,
- ConstantAttributeBasisDelinearizeIndexOpPattern>(context);
+ DropUnitExtentBasis>(context);
}
//===----------------------------------------------------------------------===//
@@ -4832,6 +4818,8 @@ LogicalResult AffineLinearizeIndexOp::verify() {
}
OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
+ setStaticBasis(foldCstValueToCstAttrBasis(
+ getMixedBasis(), getDynamicBasisMutable(), adaptor.getDynamicBasis()));
// No indices linearizes to zero.
if (getMultiIndex().empty())
return IntegerAttr::get(getResult().getType(), 0);
@@ -5002,31 +4990,12 @@ struct DropLinearizeLeadingZero final
return success();
}
};
-
-/// Folds away constant SSA Value with constant Attribute in basis.
-struct ConstantAttributeBasisLinearizeIndexOpPattern
- : public OpRewritePattern<affine::AffineLinearizeIndexOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
- PatternRewriter &rewriter) const override {
- // Replace all constant SSA values with the constant attribute.
- SmallVector<OpFoldResult> newBasis;
- if (failed(fetchNewConstantBasis(rewriter, op.getMixedBasis(), newBasis)))
- return rewriter.notifyMatchFailure(op, "no constant SSA value in basis");
-
- rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
- op, op.getMultiIndex(), newBasis, op.getDisjoint());
- return success();
- }
-};
} // namespace
void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeLeadingZero,
- DropLinearizeUnitComponentsIfDisjointOrZero,
- ConstantAttributeBasisLinearizeIndexOpPattern>(context);
+ DropLinearizeUnitComponentsIfDisjointOrZero>(context);
}
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list