[Mlir-commits] [mlir] [mlir][Affine] Let affine.[de]linearize_index omit outer bounds (PR #116103)
llvmlistbot at llvm.org
Wed Nov 13 12:25:57 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Krzysztof Drewniak (krzysz00)
<details>
<summary>Changes</summary>
The affine.delinearize_index and affine.linearize_index operations, as currently defined, require providing a length-N basis to [de]linearize N values. The first value in this basis is never used during lowering. (Note that, even though it isn't used during lowering, it can still be used to, for example, remove length-1 outputs from a delinearize.)
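To make the point about the unused first element concrete, here is a minimal C++ sketch of the integer arithmetic a delinearization performs. It is illustrative only, not MLIR's lowering: the helpers `floorDiv`, `floorMod`, and `delinearize` are hypothetical, only static basis values are modeled, and the loop mirrors the floor-division/modulo walk in the `AffineDelinearizeIndexOp::fold` hunk of the diff. `basis[0]` is never read.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: floor division for a strictly positive divisor.
int64_t floorDiv(int64_t a, int64_t b) {
  assert(b > 0 && "basis elements must be strictly positive");
  int64_t q = a / b;
  // C++ division truncates toward zero; round down for inexact negative a.
  if (a % b != 0 && a < 0)
    --q;
  return q;
}

// Hypothetical helper: remainder always in [0, b) for b > 0.
int64_t floorMod(int64_t a, int64_t b) { return a - floorDiv(a, b) * b; }

// Delinearize `linear` against a full basis (b_0, ..., b_{N-1}).
// b_0 is the outer bound: it is accepted but never read.
std::vector<int64_t> delinearize(int64_t linear,
                                 const std::vector<int64_t> &basis) {
  std::vector<int64_t> result(basis.size());
  int64_t highPart = linear;
  // Walk from the innermost basis element outward, stopping before b_0.
  for (size_t i = basis.size(); i-- > 1;) {
    result[i] = floorMod(highPart, basis[i]);
    highPart = floorDiv(highPart, basis[i]);
  }
  // Whatever remains becomes the outermost index, unconstrained by b_0.
  if (!result.empty())
    result[0] = highPart;
  return result;
}
```

Under these assumptions, `delinearize(100, {7, 4, 5})` and `delinearize(100, {1000, 4, 5})` both give `{5, 0, 0}`: the outer bound only matters to analyses that reason about the range of the input.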
This dead value makes sense in the original context of these operations, which is linearizing or delinearizing indexes to memref<>s, vector<>s, and other shaped types, where that outer bound is available and may be useful for analysis.
However, other use cases exist where the outer bound is not known. For example:
%thread_id_x = gpu.thread_id x : index
%0:3 = affine.delinearize_index %thread_id_x into (4, 16) : index, index, index
In this code, we don't know the upper bound of the thread ID, but we do want to construct the ?x4x16 grid of delinearized values in order to further partition the GPU threads.
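For the thread-ID example above, the three results can be sketched in plain C++. This is an illustration under stated assumptions: `threadGridCoords` is a hypothetical name, and because GPU thread IDs are non-negative, C++'s truncating `/` and `%` agree with the floor semantics the op uses.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical sketch of
//   %0:3 = affine.delinearize_index %tid into (4, 16) : index, index, index
// The basis has no outer bound, so two basis elements yield three results.
std::array<int64_t, 3> threadGridCoords(int64_t tid) {
  assert(tid >= 0 && "GPU thread IDs are non-negative");
  int64_t inner = tid % 16;        // position inside a group of 16
  int64_t middle = (tid / 16) % 4; // which of the 4 groups of 16
  int64_t outer = tid / (4 * 16);  // unbounded: the `?` in ?x4x16
  return {outer, middle, inner};
}
```

For instance, thread 70 lands at `{1, 0, 6}` and thread 1000 at `{15, 2, 8}`; the outermost coordinate simply keeps growing with the thread ID, which is the `?` in the `?x4x16` grid.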
To support such use cases, we broaden the definition of affine.delinearize_index and affine.linearize_index to make the outer bound optional.
In the case of affine.delinearize_index, where the number of results is a function of the size of the passed-in basis, we augment all existing builders with a `hasOuterBound` argument, which, for backwards compatibility and to preserve the natural usage of the op, defaults to `true`. If this flag is true, the op returns one result per basis element; if it is false, it returns one extra result in position 0.
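The relationship between `hasOuterBound`, the basis length, and the result count can be summarized in a small C++ sketch. The function names and the `kDynamic` stand-in are hypothetical; the checks are modeled on the updated `AffineDelinearizeIndexOp::verify()` in the diff (minus its dynamic-operand count check), not copied from it.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Stand-in for ShapedType::kDynamic (assumed to be the minimum int64_t).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Result count implied by the builders: one per basis element with an outer
// bound, plus one extra leading result without it.
size_t numDelinearizeResults(size_t basisSize, bool hasOuterBound = true) {
  return hasOuterBound ? basisSize : basisSize + 1;
}

// Checks modeled on the updated verifier: the result count equals the basis
// size or exceeds it by one, and no static basis element is non-positive.
bool isValidDelinearize(size_t numResults,
                        const std::vector<int64_t> &staticBasis) {
  if (numResults != staticBasis.size() &&
      numResults != staticBasis.size() + 1)
    return false;
  for (int64_t v : staticBasis)
    if (v != kDynamic && v <= 0)
      return false;
  return true;
}

// The flag is not stored; like hasOuterBound(), it is recovered by counting.
bool hasOuterBound(size_t numResults, size_t basisSize) {
  return numResults == basisSize;
}
```

Because the flag is derived from the counts rather than stored, `into (4, 16)` with three results and `into (7, 4, 16)` with three results are both valid, and only the second one has an outer bound.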
We also update the existing canonicalization patterns (and move one of them into the folder) to handle these cases. Note that disagreements about the outer bound no longer prevent delinearize/linearize cancellations.
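Outer-bound disagreements no longer block cancellation because the pattern now compares effective bases. The following static-only C++ sketch illustrates that comparison; `effectiveBasis` and `canCancel` are hypothetical names loosely modeled on `getEffectiveBasis()` and `CancelDelinearizeOfLinearizeDisjointExact` (the real pattern compares `OpFoldResult`s, so dynamic entries match by SSA value).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Analogue of getEffectiveBasis(): when the basis has one entry per index,
// its first entry is the outer bound and is dropped.
std::vector<int64_t> effectiveBasis(const std::vector<int64_t> &basis,
                                    size_t numIndices) {
  if (!basis.empty() && basis.size() == numIndices)
    return std::vector<int64_t>(basis.begin() + 1, basis.end());
  return basis;
}

// Cancellation test loosely modeled on the canonicalization pattern: the
// linearize must be `disjoint` and the effective bases must match; outer
// bounds, present or absent, are ignored.
bool canCancel(bool linearizeIsDisjoint, const std::vector<int64_t> &linBasis,
               size_t linNumIndices, const std::vector<int64_t> &delinBasis,
               size_t delinNumResults) {
  return linearizeIsDisjoint &&
         effectiveBasis(linBasis, linNumIndices) ==
             effectiveBasis(delinBasis, delinNumResults);
}
```

Under this sketch, a disjoint `linearize_index ... by (7, 3, 5)` feeding a `delinearize_index ... into (2, 3, 5)` still cancels, because both reduce to the effective basis `(3, 5)`.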
---
Patch is 41.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116103.diff
8 Files Affected:
- (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (+75-9)
- (modified) mlir/include/mlir/Dialect/Affine/Utils.h (+10-4)
- (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+95-86)
- (modified) mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp (+9-2)
- (modified) mlir/lib/Dialect/Affine/Utils/Utils.cpp (+19-5)
- (modified) mlir/test/Dialect/Affine/affine-expand-index-ops.mlir (+7-6)
- (modified) mlir/test/Dialect/Affine/canonicalize.mlir (+172)
- (modified) mlir/test/Dialect/Affine/invalid.mlir (+5-13)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 6a495e11ae1ad5..53bc7ce0349241 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1059,8 +1059,7 @@ def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
// AffineDelinearizeIndexOp
//===----------------------------------------------------------------------===//
-def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
- [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
let summary = "delinearize an index";
let description = [{
The `affine.delinearize_index` operation takes a single index value and
@@ -1082,6 +1081,25 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
%indices_1 = affine.apply #map1()[%linear_index]
%indices_2 = affine.apply #map2()[%linear_index]
```
+
+ The basis may contain either `N` or `N-1` elements, where `N` is the number of results.
+ If there are `N` basis elements, the first one will not be used during computations,
+ but may be used during analysis and canonicalization to eliminate terms from
+ the `affine.delinearize_index` or to enable conclusions about the total size of
+ `%linear_index`.
+
+ If the basis is fully provided, the delinearize_index operation is said to "have
+ an outer bound". The builders assume that an `affine.delinearize_index` has
+ an outer bound by default, as this is how the operation was initially defined.
+
+ That is, the example above could also have been written
+ ```mlir
+ %0:3 = affine.delinearize_index %linear_index into (244, 244) : index, index, index
+ ```
+
+ Note that, due to the constraints of affine maps, all the basis elements must
+ be strictly positive. A dynamic basis element being 0 or negative causes
+ undefined behavior.
}];
let arguments = (ins Index:$linear_index,
@@ -1096,17 +1114,37 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
}];
let builders = [
- OpBuilder<(ins "Value":$linear_index, "ValueRange":$basis)>,
- OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis)>,
- OpBuilder<(ins "Value":$linear_index, "ArrayRef<int64_t>":$basis)>
+ OpBuilder<(ins "Value":$linear_index, "ValueRange":$dynamic_basis, "ArrayRef<int64_t>":$static_basis, CArg<"bool", "true">:$hasOuterBound)>,
+ OpBuilder<(ins "Value":$linear_index, "ValueRange":$basis, CArg<"bool", "true">:$hasOuterBound)>,
+ OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis, CArg<"bool", "true">:$hasOuterBound)>,
+ OpBuilder<(ins "Value":$linear_index, "ArrayRef<int64_t>":$basis, CArg<"bool", "true">:$hasOuterBound)>
];
let extraClassDeclaration = [{
+ /// Return true if the basis includes a bound on the first result index.
+ bool hasOuterBound() {
+ return getMultiIndex().size() == getStaticBasis().size();
+ }
+
/// Returns a vector with all the static and dynamic basis values.
SmallVector<OpFoldResult> getMixedBasis() {
OpBuilder builder(getContext());
return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
}
+
+ /// Return a vector that contains the basis of the operation, removing
+ /// the outer bound if one is present.
+ SmallVector<OpFoldResult> getEffectiveBasis() {
+ OpBuilder builder(getContext());
+ if (hasOuterBound()) {
+ if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
+ return ::mlir::getMixedValues(getStaticBasis().drop_front(), getDynamicBasis().drop_front(), builder);
+
+ return ::mlir::getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(), builder);
+ }
+
+ return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+ }
}];
let hasVerifier = 1;
@@ -1124,13 +1162,21 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
The `affine.linearize_index` operation takes a sequence of index values and a
basis of the same length and linearizes the indices using that basis.
- That is, for indices `%idx_1` through `%idx_N` and basis elements `b_1` through `b_N`,
- it computes
+ That is, for indices `%idx_0` to `%idx_{N-1}` and basis elements `B_0`
+ (or `B_1`) up to `B_{N-1}`, it computes
```
- sum(i = 1 to N) %idx_i * product(j = i + 1 to N) B_j
+ sum(i = 0 to N-1) %idx_i * product(j = i + 1 to N-1) B_j
```
+ The basis may have either `N` or `N-1` elements, where `N` is the number of
+ inputs to linearize_index. If `N` basis elements are provided, the first one is
+ not used in computation, but may be used during analysis or canonicalization as
+ a bound on `%idx_0`.
+
+ If all `N` basis elements are provided, the linearize_index operation is said to
+ "have an outer bound".
+
If the `disjoint` property is present, this is an optimization hint that,
for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index,
except that `%idx_0` may be negative to make the index as a whole negative.
@@ -1140,7 +1186,9 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
Example:
```mlir
- %linear_index = affine.linearize_index [%index_0, %index_1, %index_2] (2, 3, 5) : index
+ %linear_index = affine.linearize_index [%index_0, %index_1, %index_2] by (2, 3, 5) : index
+ // Same effect
+ %linear_index = affine.linearize_index [%index_0, %index_1, %index_2] by (3, 5) : index
```
In the above example, `%linear_index` conceptually holds the following:
@@ -1171,12 +1219,30 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
];
let extraClassDeclaration = [{
+ /// Return true if the basis includes a bound on the first index input.
+ bool hasOuterBound() {
+ return getMultiIndex().size() == getStaticBasis().size();
+ }
+
/// Return a vector with all the static and dynamic basis values.
SmallVector<OpFoldResult> getMixedBasis() {
OpBuilder builder(getContext());
return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
}
+ /// Return a vector that contains the basis of the operation, removing
+ /// the outer bound if one is present.
+ SmallVector<OpFoldResult> getEffectiveBasis() {
+ OpBuilder builder(getContext());
+ if (hasOuterBound()) {
+ if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
+ return ::mlir::getMixedValues(getStaticBasis().drop_front(), getDynamicBasis().drop_front(), builder);
+
+ return ::mlir::getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(), builder);
+ }
+
+ return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+ }
}];
let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index 0e98223969e08c..0f801ebb6f5898 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -307,17 +307,23 @@ struct DivModValue {
DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
/// Generate the IR to delinearize `linearIndex` given the `basis` and return
-/// the multi-index.
+/// the multi-index. `hasOuterBound` indicates whether `basis` has an entry
+/// giving the size of the first multi-index result. If it is true, the function
+/// returns `basis.size()` values; otherwise, it returns `basis.size() + 1`
+/// values.
FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
Value linearIndex,
- ArrayRef<Value> basis);
+ ArrayRef<Value> basis,
+ bool hasOuterBound = true);
FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
Value linearIndex,
- ArrayRef<OpFoldResult> basis);
+ ArrayRef<OpFoldResult> basis,
+ bool hasOuterBound = true);
// Generate IR that extracts the linear index from a multi-index according to
-// a basis/shape.
+// a basis/shape. The basis may contain either `multiIndex.size()` or
+// `multiIndex.size() - 1` elements.
OpFoldResult linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
ArrayRef<OpFoldResult> basis,
ImplicitLocOpBuilder &builder);
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index fbc9053a0e273b..3693195c39fecb 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -20,6 +20,7 @@
#include "mlir/Interfaces/ShapedOpInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
@@ -4503,62 +4504,81 @@ LogicalResult AffineVectorStoreOp::verify() {
// DelinearizeIndexOp
//===----------------------------------------------------------------------===//
-LogicalResult AffineDelinearizeIndexOp::inferReturnTypes(
- MLIRContext *context, std::optional<::mlir::Location> location,
- ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
- RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
- AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties,
- regions);
- inferredReturnTypes.assign(adaptor.getStaticBasis().size(),
- IndexType::get(context));
- return success();
+void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
+ OperationState &odsState,
+ Value linearIndex, ValueRange dynamicBasis,
+ ArrayRef<int64_t> staticBasis,
+ bool hasOuterBound) {
+ SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size()
+ : staticBasis.size() + 1,
+ linearIndex.getType());
+ build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
+ staticBasis);
}
void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
- Value linearIndex, ValueRange basis) {
+ Value linearIndex, ValueRange basis,
+ bool hasOuterBound) {
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
staticBasis);
- build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
+ build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
+ hasOuterBound);
}
void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
Value linearIndex,
- ArrayRef<OpFoldResult> basis) {
+ ArrayRef<OpFoldResult> basis,
+ bool hasOuterBound) {
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
- build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
+ build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
+ hasOuterBound);
}
void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
- Value linearIndex,
- ArrayRef<int64_t> basis) {
- build(odsBuilder, odsState, linearIndex, ValueRange{}, basis);
+ Value linearIndex, ArrayRef<int64_t> basis,
+ bool hasOuterBound) {
+ build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
}
LogicalResult AffineDelinearizeIndexOp::verify() {
- if (getStaticBasis().empty())
- return emitOpError("basis should not be empty");
- if (getNumResults() != getStaticBasis().size())
- return emitOpError("should return an index for each basis element");
- auto dynamicMarkersCount =
- llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
+ ArrayRef<int64_t> staticBasis = getStaticBasis();
+ if (getNumResults() != staticBasis.size() &&
+ getNumResults() != staticBasis.size() + 1)
+ return emitOpError("should return an index for each basis element and up "
+ "to one extra index");
+
+ auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
return emitOpError(
"mismatch between dynamic and static basis (kDynamic marker but no "
"corresponding dynamic basis entry) -- this can only happen due to an "
"incorrect fold/rewrite");
+
+ if (!llvm::all_of(staticBasis, [](int64_t v) {
+ return v > 0 || ShapedType::isDynamic(v);
+ }))
+ return emitOpError("no basis element may be statically non-positive");
+
return success();
}
LogicalResult
AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &result) {
+ // If we won't be doing any division or modulo (no basis or the one basis
+ // element is purely advisory), simply return the input value.
+ if (getStaticBasis().size() == static_cast<size_t>(hasOuterBound())) {
+ result.push_back(getLinearIndex());
+ return success();
+ }
+
if (adaptor.getLinearIndex() == nullptr)
return failure();
@@ -4567,7 +4587,11 @@ AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
Type attrType = getLinearIndex().getType();
- for (int64_t modulus : llvm::reverse(getStaticBasis().drop_front())) {
+
+ ArrayRef<int64_t> staticBasis = getStaticBasis();
+ if (hasOuterBound())
+ staticBasis = staticBasis.drop_front();
+ for (int64_t modulus : llvm::reverse(staticBasis)) {
result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
highPart = llvm::divideFloorSigned(highPart, modulus);
}
@@ -4594,24 +4618,25 @@ struct DropUnitExtentBasis
return zero.value();
};
+ bool hasOuterBound = delinearizeOp.hasOuterBound();
// Replace all indices corresponding to unit-extent basis with 0.
// Remaining basis can be used to get a new `affine.delinearize_index` op.
- SmallVector<OpFoldResult> newOperands;
+ SmallVector<OpFoldResult> newBasis;
for (auto [index, basis] : llvm::enumerate(delinearizeOp.getMixedBasis())) {
std::optional<int64_t> basisVal = getConstantIntValue(basis);
if (basisVal && *basisVal == 1)
- replacements[index] = getZero();
+ replacements[index + (hasOuterBound ? 0 : 1)] = getZero();
else
- newOperands.push_back(basis);
+ newBasis.push_back(basis);
}
- if (newOperands.size() == delinearizeOp.getStaticBasis().size())
+ if (newBasis.size() == delinearizeOp.getStaticBasis().size())
return rewriter.notifyMatchFailure(delinearizeOp,
"no unit basis elements");
- if (!newOperands.empty()) {
+ if (!newBasis.empty() || !hasOuterBound) {
auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
- loc, delinearizeOp.getLinearIndex(), newOperands);
+ loc, delinearizeOp.getLinearIndex(), newBasis, hasOuterBound);
int newIndex = 0;
// Map back the new delinearized indices to the values they replace.
for (auto &replacement : replacements) {
@@ -4626,27 +4651,6 @@ struct DropUnitExtentBasis
}
};
-/// Drop delinearization with a single basis element
-///
-/// By definition, `delinearize_index %linear into (%basis)` is
-/// `%linear floorDiv 1` (since `1` is the product of the basis elememts,
-/// ignoring the 0th one, and since there is no previous division we need
-/// to use the remainder of). Therefore, a single-element `delinearize`
-/// can be replaced by the underlying linear index.
-struct DropDelinearizeOneBasisElement
- : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
- PatternRewriter &rewriter) const override {
- if (delinearizeOp.getStaticBasis().size() != 1)
- return rewriter.notifyMatchFailure(delinearizeOp,
- "doesn't have a length-1 basis");
- rewriter.replaceOp(delinearizeOp, delinearizeOp.getLinearIndex());
- return success();
- }
-};
-
/// If an `affine.delinearize_index`'s input is an `affine.linearize_index
/// disjoint` and the two operations have the same basis (ignoring outer
/// bounds), replace the delinearization results with the inputs of the `affine.linearize_index`
@@ -4668,7 +4672,7 @@ struct CancelDelinearizeOfLinearizeDisjointExact
"index doesn't come from linearize");
if (!linearizeOp.getDisjoint() ||
- linearizeOp.getMixedBasis() != delinearizeOp.getMixedBasis())
+ linearizeOp.getEffectiveBasis() != delinearizeOp.getEffectiveBasis())
return rewriter.notifyMatchFailure(
linearizeOp, "not disjoint or basis doesn't match delinearize");
@@ -4680,8 +4684,9 @@ struct CancelDelinearizeOfLinearizeDisjointExact
void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
- patterns.insert<CancelDelinearizeOfLinearizeDisjointExact,
- DropDelinearizeOneBasisElement, DropUnitExtentBasis>(context);
+ patterns
+ .insert<CancelDelinearizeOfLinearizeDisjointExact, DropUnitExtentBasis>(
+ context);
}
//===----------------------------------------------------------------------===//
@@ -4718,11 +4723,11 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
}
LogicalResult AffineLinearizeIndexOp::verify() {
- if (getStaticBasis().empty())
- return emitOpError("basis should not be empty");
-
- if (getMultiIndex().size() != getStaticBasis().size())
- return emitOpError("should be passed an index for each basis element");
+ size_t numIndexes = getMultiIndex().size();
+ size_t numBasisElems = getStaticBasis().size();
+ if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
+ return emitOpError("should be passed a basis element for each index except "
+ "possibly the first");
auto dynamicMarkersCount =
llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
@@ -4736,6 +4741,14 @@ LogicalResult AffineLinearizeIndexOp::verify() {
}
OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
+ // No indices linearizes to zero.
+ if (getMultiIndex().empty())
+ return IntegerAttr::get(getResult().getType(), 0);
+
+ // One single index linearizes to itself.
+ if (getMultiIndex().size() == 1)
+ return getMultiIndex().front();
+
if (llvm::any_of(adaptor.getMultiIndex(),
[](Attribute a) { return a == nullptr; }))
return nullptr;
@@ -4745,12 +4758,17 @@ OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
int64_t result = 0;
int64_t stride = 1;
- for (auto [indexAttr, length] :
- llvm::zip_equal(llvm::reverse(adaptor.getMultiIndex()),
- llvm::reverse(getStaticBasis()))) {
+ for (auto [length, indexAttr] :
+ llvm::zip_first(llvm::reverse(getStaticBasis()),
+ llvm::reverse(adaptor.getMultiIndex()))) {
result = ...
[truncated]
``````````
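The `linearize_index` formula in the diff, together with the new folder cases for zero and one index, can be checked with a minimal C++ sketch. `linearize` is a hypothetical helper over static values; like the folder's `zip_first` over reversed ranges, it pairs indices and basis elements from the back, so an outer bound `B_0`, when present, is never read.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch of sum(i = 0 to N-1) idx_i * product(j = i + 1 to N-1) B_j,
// accepting a basis of N or N-1 elements.
int64_t linearize(const std::vector<int64_t> &indices,
                  const std::vector<int64_t> &basis) {
  assert((basis.size() == indices.size() ||
          basis.size() + 1 == indices.size()) &&
         "one basis element per index, except possibly the first");
  size_t n = indices.size();
  int64_t result = 0;
  int64_t stride = 1;
  // Walk from the innermost index outward, growing the stride as we go.
  for (size_t k = 0; k < n; ++k) {
    result += indices[n - 1 - k] * stride;
    // The outermost index needs no further stride, so B_0 (if any) is skipped.
    if (k + 1 < n)
      stride *= basis[basis.size() - 1 - k];
  }
  return result;
}
```

With this sketch, `linearize({1, 2, 3}, {2, 3, 5})` and `linearize({1, 2, 3}, {3, 5})` both return 28 (1 * 15 + 2 * 5 + 3), matching the "Same effect" example, while an empty index list yields 0 and a single index is returned unchanged.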
</details>
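The index shift in the updated `DropUnitExtentBasis` pattern (`index + (hasOuterBound ? 0 : 1)`) is easy to misread, so here is a hypothetical C++ sketch of the decision it makes for a static basis. `UnitExtentPlan` and `planDropUnitExtents` are illustrative names, not part of MLIR: without an outer bound, basis element `i` describes result `i + 1`, so that is the result replaced by zero.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical summary of what the pattern decides for a static basis.
struct UnitExtentPlan {
  std::vector<int64_t> newBasis;     // basis for the replacement op
  std::vector<size_t> zeroedResults; // result positions replaced by 0
};

UnitExtentPlan planDropUnitExtents(const std::vector<int64_t> &basis,
                                   bool hasOuterBound) {
  UnitExtentPlan plan;
  // Without an outer bound, result 0 has no basis entry, so basis element
  // `i` describes result `i + 1`.
  size_t shift = hasOuterBound ? 0 : 1;
  for (size_t i = 0; i < basis.size(); ++i) {
    if (basis[i] == 1)
      plan.zeroedResults.push_back(i + shift);
    else
      plan.newBasis.push_back(basis[i]);
  }
  return plan;
}
```

For a basis of `(4, 1, 16)` without an outer bound, the sketch keeps `(4, 16)` and zeroes result 2, since result 0 is the unbounded leading index.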
https://github.com/llvm/llvm-project/pull/116103