[Mlir-commits] [mlir] [mlir][Affine] Let affine.[de]linearize_index omit outer bounds (PR #116103)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 13 20:28:07 PST 2024
================
@@ -4503,62 +4504,81 @@ LogicalResult AffineVectorStoreOp::verify() {
// DelinearizeIndexOp
//===----------------------------------------------------------------------===//
-LogicalResult AffineDelinearizeIndexOp::inferReturnTypes(
- MLIRContext *context, std::optional<::mlir::Location> location,
- ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
- RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
- AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties,
- regions);
- inferredReturnTypes.assign(adaptor.getStaticBasis().size(),
- IndexType::get(context));
- return success();
+void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
+ OperationState &odsState,
+ Value linearIndex, ValueRange dynamicBasis,
+ ArrayRef<int64_t> staticBasis,
+ bool hasOuterBound) {
+ SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size()
+ : staticBasis.size() + 1,
+ linearIndex.getType());
+ build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
+ staticBasis);
}
void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
- Value linearIndex, ValueRange basis) {
+ Value linearIndex, ValueRange basis,
+ bool hasOuterBound) {
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
staticBasis);
- build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
+ build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
+ hasOuterBound);
}
void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
Value linearIndex,
- ArrayRef<OpFoldResult> basis) {
+ ArrayRef<OpFoldResult> basis,
+ bool hasOuterBound) {
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
- build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
+ build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
+ hasOuterBound);
}
void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
- Value linearIndex,
- ArrayRef<int64_t> basis) {
- build(odsBuilder, odsState, linearIndex, ValueRange{}, basis);
+ Value linearIndex, ArrayRef<int64_t> basis,
+ bool hasOuterBound) {
+ build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
}
LogicalResult AffineDelinearizeIndexOp::verify() {
- if (getStaticBasis().empty())
- return emitOpError("basis should not be empty");
- if (getNumResults() != getStaticBasis().size())
- return emitOpError("should return an index for each basis element");
- auto dynamicMarkersCount =
- llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
+ ArrayRef<int64_t> staticBasis = getStaticBasis();
+ if (getNumResults() != staticBasis.size() &&
+ getNumResults() != staticBasis.size() + 1)
+ return emitOpError("should return an index for each basis element and up "
+ "to one extra index");
+
+ auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
return emitOpError(
"mismatch between dynamic and static basis (kDynamic marker but no "
"corresponding dynamic basis entry) -- this can only happen due to an "
"incorrect fold/rewrite");
+
+ if (!llvm::all_of(staticBasis, [](int64_t v) {
+ return v > 0 || ShapedType::isDynamic(v);
+ }))
+ return emitOpError("no basis element may be statically non-positive");
+
return success();
}
LogicalResult
AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &result) {
+ // If we won't be doing any division or modulo (no basis or the one basis
+ // element is purely advisory), simply return the input value.
+ if (getStaticBasis().size() == static_cast<size_t>(hasOuterBound())) {
----------------
MaheshRavishankar wrote:
I am not sure I follow this condition?
https://github.com/llvm/llvm-project/pull/116103
More information about the Mlir-commits mailing list