[Mlir-commits] [mlir] c6f67b8 - [mlir][affine] Add ValueBoundsOpInterface to [de]linearize_index (#121833)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 7 14:28:18 PST 2025
Author: Krzysztof Drewniak
Date: 2025-01-07T16:28:14-06:00
New Revision: c6f67b8e39a907fb96b715cae3ee90e4c1b248aa
URL: https://github.com/llvm/llvm-project/commit/c6f67b8e39a907fb96b715cae3ee90e4c1b248aa
DIFF: https://github.com/llvm/llvm-project/commit/c6f67b8e39a907fb96b715cae3ee90e4c1b248aa.diff
LOG: [mlir][affine] Add ValueBoundsOpInterface to [de]linearize_index (#121833)
Since a need for it came up downstream (in proving that loops run at
least once), this commit implements the ValueBoundsOpInterface for
affine.delinearize_index and affine.linearize_index, using affine map
representations of the operations they perform.
These implementations also use information from outer bounds to impose
additional constraints when those are available.
Added:
Modified:
mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index 82a9fb0d490882..e93b99b4f49866 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -91,6 +91,64 @@ struct AffineMaxOpInterface
};
};
+struct AffineDelinearizeIndexOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<
+          AffineDelinearizeIndexOpInterface, AffineDelinearizeIndexOp> {
+  /// Adds constraints for one result (`value`) of an affine.delinearize_index.
+  ///
+  /// With linear index L, padded basis (B_0, ..., B_{n-1}) where B_0 may be
+  /// absent, and D = B_{k+1} * ... * B_{n-1} for result number k:
+  ///   result_0 == L floordiv D            (and result_0 < B_0 if B_0 exists)
+  ///   result_k == (L mod (B_k * D)) floordiv D                      (k > 0)
+  /// For k > 0 the range 0 <= result_k < B_k is also added explicitly.
+  void populateBoundsForIndexValue(Operation *rawOp, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto op = cast<AffineDelinearizeIndexOp>(rawOp);
+    auto result = cast<OpResult>(value);
+    assert(result.getOwner() == rawOp &&
+           "bounded value isn't a result of this delinearize_index");
+    unsigned resIdx = result.getResultNumber();
+
+    AffineExpr linearIdx = cstr.getExpr(op.getLinearIndex());
+
+    // Divisor D: product of the basis elements after this result's position.
+    SmallVector<OpFoldResult> basis = op.getPaddedBasis();
+    AffineExpr divisor = cstr.getExpr(1);
+    for (OpFoldResult basisElem : llvm::drop_begin(basis, resIdx + 1))
+      divisor = divisor * cstr.getExpr(basisElem);
+
+    if (resIdx == 0) {
+      // The outermost result is not reduced modulo anything; it is only
+      // bounded above when the op carries an explicit outer bound.
+      cstr.bound(value) == linearIdx.floorDiv(divisor);
+      if (!basis.front().isNull())
+        cstr.bound(value) < cstr.getExpr(basis.front());
+      return;
+    }
+    AffineExpr thisBasis = cstr.getExpr(basis[resIdx]);
+    cstr.bound(value) == (linearIdx % (thisBasis * divisor)).floorDiv(divisor);
+    // With a dynamic basis the equality above multiplies symbols (it is
+    // semi-affine), which the constraint set may be unable to represent.
+    // State the range explicitly so the result stays bounded in that case.
+    // It holds assuming positive basis elements (as the op requires): the mod
+    // lies in [0, B_k * D), so its floordiv by D lies in [0, B_k).
+    cstr.bound(value) >= cstr.getExpr(0);
+    cstr.bound(value) < thisBasis;
+  }
+};
+
+struct AffineLinearizeIndexOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<
+          AffineLinearizeIndexOpInterface, AffineLinearizeIndexOp> {
+  /// Adds constraints for the result of an affine.linearize_index.
+  ///
+  /// For indices (i_0, ..., i_{n-1}) and padded basis (B_0, ..., B_{n-1}),
+  /// where B_0 may be absent, the result is bound to
+  ///   i_{n-1} + i_{n-2} * B_{n-1} + ... + i_0 * (B_1 * ... * B_{n-1}).
+  /// If the op is `disjoint` and B_0 is present, the result is additionally
+  /// bounded above (exclusively) by B_0 * B_1 * ... * B_{n-1}.
+  void populateBoundsForIndexValue(Operation *rawOp, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto op = cast<AffineLinearizeIndexOp>(rawOp);
+    assert(value == op.getResult() &&
+           "value isn't the result of this linearize");
+
+    AffineExpr linearized = cstr.getExpr(0);
+    // Product of the basis elements to the right of the index being added.
+    AffineExpr scale = cstr.getExpr(1);
+    SmallVector<OpFoldResult> basis = op.getPaddedBasis();
+    OperandRange multiIndex = op.getMultiIndex();
+
+    // Walk the indices from innermost to outermost, stopping before i_0 so
+    // that the (possibly absent) outer bound B_0 is never turned into an
+    // expression here.
+    for (size_t remaining = multiIndex.size(); remaining > 1; --remaining) {
+      size_t pos = remaining - 1;
+      AffineExpr index = cstr.getExpr(getAsOpFoldResult(multiIndex[pos]));
+      linearized = linearized + index * scale;
+      scale = scale * cstr.getExpr(basis[pos]);
+    }
+    linearized = linearized + cstr.getExpr(multiIndex.front()) * scale;
+    cstr.bound(value) == linearized;
+
+    if (op.getDisjoint() && !basis.front().isNull())
+      cstr.bound(value) < scale * cstr.getExpr(basis.front());
+  }
+};
} // namespace
} // namespace mlir
@@ -100,6 +158,10 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels(
AffineApplyOp::attachInterface<AffineApplyOpInterface>(*ctx);
AffineMaxOp::attachInterface<AffineMaxOpInterface>(*ctx);
AffineMinOp::attachInterface<AffineMinOpInterface>(*ctx);
+ AffineDelinearizeIndexOp::attachInterface<
+ AffineDelinearizeIndexOpInterface>(*ctx);
+ AffineLinearizeIndexOp::attachInterface<AffineLinearizeIndexOpInterface>(
+ *ctx);
});
}
diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
index 935c08aceff548..5354eb38d7b039 100644
--- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
@@ -155,3 +155,84 @@ func.func @compare_maps(%a: index, %b: index) {
: (index, index, index, index) -> ()
return
}
+
+// -----
+
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 floordiv 15)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0] -> ((s0 mod 15) floordiv 5)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<()[s0] -> (s0 mod 5)>
+// CHECK-LABEL: func.func @delinearize_static
+// CHECK-SAME: (%[[arg0:.+]]: index)
+// CHECK-DAG: %[[v1:.+]] = affine.apply #[[$map1]]()[%[[arg0]]]
+// CHECK-DAG: %[[v2:.+]] = affine.apply #[[$map2]]()[%[[arg0]]]
+// CHECK-DAG: %[[v3:.+]] = affine.apply #[[$map3]]()[%[[arg0]]]
+// CHECK: return %[[v1]], %[[v2]], %[[v3]]
+func.func @delinearize_static(%arg0: index) -> (index, index, index) {
+ // Delinearization with an explicit outer bound (basis (2, 3, 5)): each
+ // result is reified as an affine map of %arg0 (see the CHECK lines above),
+ // and both %0#0 < 2 and %0#1 < 3 are expected to be provable.
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %0:3 = affine.delinearize_index %arg0 into (2, 3, 5) : index, index, index
+ %1 = "test.reify_bound"(%0#0) {type = "EQ"} : (index) -> (index)
+ %2 = "test.reify_bound"(%0#1) {type = "EQ"} : (index) -> (index)
+ %3 = "test.reify_bound"(%0#2) {type = "EQ"} : (index) -> (index)
+ // The outer basis element 2 bounds the first result.
+ // expected-remark @below{{true}}
+ "test.compare"(%0#0, %c2) {cmp = "LT"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare"(%0#1, %c3) {cmp = "LT"} : (index, index) -> ()
+ return %1, %2, %3 : index, index, index
+}
+
+// -----
+
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 floordiv 15)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0] -> ((s0 mod 15) floordiv 5)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<()[s0] -> (s0 mod 5)>
+// CHECK-LABEL: func.func @delinearize_static_no_outer_bound
+// CHECK-SAME: (%[[arg0:.+]]: index)
+// CHECK-DAG: %[[v1:.+]] = affine.apply #[[$map1]]()[%[[arg0]]]
+// CHECK-DAG: %[[v2:.+]] = affine.apply #[[$map2]]()[%[[arg0]]]
+// CHECK-DAG: %[[v3:.+]] = affine.apply #[[$map3]]()[%[[arg0]]]
+// CHECK: return %[[v1]], %[[v2]], %[[v3]]
+func.func @delinearize_static_no_outer_bound(%arg0: index) -> (index, index, index) {
+ // Same delinearization as @delinearize_static but without an outer bound
+ // (basis (3, 5) for three results): the EQ bounds are unchanged, yet
+ // nothing limits %0#0 = %arg0 floordiv 15 from above.
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %0:3 = affine.delinearize_index %arg0 into (3, 5) : index, index, index
+ %1 = "test.reify_bound"(%0#0) {type = "EQ"} : (index) -> (index)
+ %2 = "test.reify_bound"(%0#1) {type = "EQ"} : (index) -> (index)
+ %3 = "test.reify_bound"(%0#2) {type = "EQ"} : (index) -> (index)
+ // With no outer bound, %0#0 < 2 can be neither proven nor disproven.
+ // expected-error @below{{unknown}}
+ "test.compare"(%0#0, %c2) {cmp = "LT"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare"(%0#1, %c3) {cmp = "LT"} : (index, index) -> ()
+ return %1, %2, %3 : index, index, index
+}
+
+// -----
+
+// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)>
+// CHECK-LABEL: func.func @linearize_static
+// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index)
+// CHECK: %[[v1:.+]] = affine.apply #[[$map]]()[%[[arg1]], %[[arg0]]]
+// CHECK: return %[[v1]]
+func.func @linearize_static(%arg0: index, %arg1: index) -> index {
+ // Disjoint linearization with an outer bound (basis (2, 3)): the EQ bound
+ // is %arg1 + %arg0 * 3 (see the CHECK lines above), and %0 < 2 * 3 = 6 is
+ // expected to be provable.
+ %c6 = arith.constant 6 : index
+ %0 = affine.linearize_index disjoint [%arg0, %arg1] by (2, 3) : index
+ %1 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index)
+ // expected-remark @below{{true}}
+ "test.compare"(%0, %c6) {cmp = "LT"} : (index, index) -> ()
+ return %1 : index
+}
+
+// -----
+
+// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)>
+// CHECK-LABEL: func.func @linearize_static_no_outer_bound
+// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index)
+// CHECK: %[[v1:.+]] = affine.apply #[[$map]]()[%[[arg1]], %[[arg0]]]
+// CHECK: return %[[v1]]
+func.func @linearize_static_no_outer_bound(%arg0: index, %arg1: index) -> index {
+ // Same linearization as @linearize_static but without an outer bound
+ // (basis (3)): the EQ bound is unchanged, but even with `disjoint` there is
+ // no upper limit on the result, so %0 < 6 is expected to be unknown.
+ %c6 = arith.constant 6 : index
+ %0 = affine.linearize_index disjoint [%arg0, %arg1] by (3) : index
+ %1 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index)
+ // expected-error @below{{unknown}}
+ "test.compare"(%0, %c6) {cmp = "LT"} : (index, index) -> ()
+ return %1 : index
+}
More information about the Mlir-commits
mailing list