[Mlir-commits] [mlir] [MLIR][Transform] FuseOp: accept transform params, add use_forall argument (PR #161883)

Frank Schlimbach llvmlistbot at llvm.org
Tue Oct 7 04:24:30 PDT 2025


================
@@ -665,24 +759,69 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
 }
 
 LogicalResult transform::FuseOp::verify() {
-  SmallVector<int64_t> permutation =
-      extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
-  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
-  if (!std::is_permutation(sequence.begin(), sequence.end(),
-                           permutation.begin(), permutation.end())) {
-    return emitOpError() << "expects interchange to be a permutation, found "
-                         << getTileInterchange();
+  ArrayRef<int64_t> permutation = getStaticTileInterchange();
+  if (!llvm::any_of(permutation,
+                    [](int64_t v) { return ShapedType::isDynamic(v); })) {
+    auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
+    if (!std::is_permutation(sequence.begin(), sequence.end(),
+                             permutation.begin(), permutation.end())) {
+      return emitOpError() << "expects interchange to be a permutation, found "
+                           << getTileInterchange();
+    }
   }
 
-  SmallVector<int64_t> sizes =
-      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
-  size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
+  ArrayRef<int64_t> sizes = getStaticTileSizes();
+  size_t numExpectedLoops =
+      getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
   if (numExpectedLoops != getNumResults() - 1)
     return emitOpError() << "expects " << numExpectedLoops << " loop results";
 
   return success();
 }
 
+SmallVector<OpFoldResult> transform::FuseOp::getMixedTileSizes() {
+  ValueRange dynamicValues = getTileSizes();
+  ArrayRef<int64_t> staticValues = getStaticTileSizes();
+  SmallVector<OpFoldResult> results;
+  results.reserve(staticValues.size());
+  unsigned dynamicPos = 0;
+  Builder builder(getContext());
+  for (int64_t size : staticValues) {
+    if (size == ShapedType::kDynamic) {
+      results.push_back(dynamicValues[dynamicPos++]);
+    } else {
+      results.push_back(builder.getIndexAttr(size));
+    }
+  }
+  return results;
+}
+
+SmallVector<OpFoldResult> transform::FuseOp::getMixedTileInterchange() {
+  ValueRange dynamicValues = getTileInterchange();
+  ArrayRef<int64_t> staticValues = getStaticTileInterchange();
+  SmallVector<OpFoldResult> results;
+  results.reserve(staticValues.size());
+  unsigned dynamicPos = 0;
+  Builder builder(getContext());
+  for (int64_t size : staticValues) {
+    if (size == ShapedType::kDynamic) {
+      results.push_back(dynamicValues[dynamicPos++]);
+    } else {
+      results.push_back(builder.getIndexAttr(size));
+    }
+  }
+  return results;
----------------
fschlimb wrote:

Suggestion: the hand-written loop can be replaced by the existing `getMixedValues` helper:

return getMixedValues(getStaticTileInterchange(), getTileInterchange(), getContext());

https://github.com/llvm/llvm-project/pull/161883


More information about the Mlir-commits mailing list