[Mlir-commits] [mlir] [MLIR][Transform] FuseOp: accept transform params, add use_forall argument (PR #161883)
Frank Schlimbach
llvmlistbot at llvm.org
Tue Oct 7 04:24:30 PDT 2025
================
@@ -665,24 +759,69 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
}
LogicalResult transform::FuseOp::verify() {
- SmallVector<int64_t> permutation =
- extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
- auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
- if (!std::is_permutation(sequence.begin(), sequence.end(),
- permutation.begin(), permutation.end())) {
- return emitOpError() << "expects interchange to be a permutation, found "
- << getTileInterchange();
+ // The interchange can only be checked statically when every entry is a
+ // constant; entries marked ShapedType::kDynamic are supplied as operands
+ // and their values are not known here.
+ ArrayRef<int64_t> permutation = getStaticTileInterchange();
+ if (llvm::none_of(permutation, ShapedType::isDynamic)) {
+ auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
+ if (!std::is_permutation(sequence.begin(), sequence.end(),
+ permutation.begin(), permutation.end())) {
+ // getTileInterchange() now yields the dynamic operands, which hold no
+ // values in this branch; print the static entries so the diagnostic
+ // shows the offending interchange (same "[a, b]" form as before).
+ return emitOpError() << "expects interchange to be a permutation, found ["
+ << permutation << "]";
+ }
}
- SmallVector<int64_t> sizes =
- extractFromIntegerArrayAttr<int64_t>(getTileSizes());
- size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
+ // With use_forall a single loop result is expected; otherwise one loop per
+ // non-zero tile size. Dynamic sizes (kDynamic != 0) are counted as loops.
+ // NOTE(review): a dynamic size that resolves to 0 at apply time would yield
+ // fewer loops than verified here -- confirm apply() rejects/handles that.
+ ArrayRef<int64_t> sizes = getStaticTileSizes();
+ size_t numExpectedLoops =
+ getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
if (numExpectedLoops != getNumResults() - 1)
return emitOpError() << "expects " << numExpectedLoops << " loop results";
return success();
}
+/// Returns the tile sizes as a list of OpFoldResults, one per entry of
+/// getStaticTileSizes(): static entries become index attributes, and each
+/// entry equal to ShapedType::kDynamic is replaced, in order, by the next
+/// value from getTileSizes().
+/// NOTE(review): mlir::getMixedValues (suggested in review for the
+/// interchange getter) could replace this loop; confirm which integer
+/// attribute type it materializes before switching, since this getter
+/// produces index attributes.
+SmallVector<OpFoldResult> transform::FuseOp::getMixedTileSizes() {
+ ValueRange dynamicValues = getTileSizes();
+ ArrayRef<int64_t> staticValues = getStaticTileSizes();
+ // Each kDynamic placeholder must be backed by exactly one dynamic operand;
+ // otherwise the indexing below would run past the end of dynamicValues.
+ assert(static_cast<size_t>(llvm::count_if(
+ staticValues, ShapedType::isDynamic)) == dynamicValues.size() &&
+ "expected one dynamic tile size operand per kDynamic entry");
+ SmallVector<OpFoldResult> results;
+ results.reserve(staticValues.size());
+ unsigned dynamicPos = 0;
+ Builder builder(getContext());
+ for (int64_t size : staticValues) {
+ if (ShapedType::isDynamic(size))
+ results.push_back(dynamicValues[dynamicPos++]);
+ else
+ results.push_back(builder.getIndexAttr(size));
+ }
+ return results;
+}
+
+SmallVector<OpFoldResult> transform::FuseOp::getMixedTileInterchange() {
+ ValueRange dynamicValues = getTileInterchange();
+ ArrayRef<int64_t> staticValues = getStaticTileInterchange();
+ SmallVector<OpFoldResult> results;
+ results.reserve(staticValues.size());
+ unsigned dynamicPos = 0;
+ Builder builder(getContext());
+ for (int64_t size : staticValues) {
+ if (size == ShapedType::kDynamic) {
+ results.push_back(dynamicValues[dynamicPos++]);
+ } else {
+ results.push_back(builder.getIndexAttr(size));
+ }
+ }
+ return results;
----------------
fschlimb wrote:
return getMixedValues(getStaticTileInterchange(), getTileInterchange(), getContext());
https://github.com/llvm/llvm-project/pull/161883
More information about the Mlir-commits mailing list