[llvm-branch-commits] [mlir] [mlir][linalg] Migrate Detensorize pass to new dialect conversion driver (PR #152912)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sun Aug 10 04:43:17 PDT 2025
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/152912
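The pass's type converter used to access erased operations and block arguments, which the new conversion driver no longer supports. This patch opts the pass into the new driver (`allowPatternRollback = false`) and registers a listener that removes erased operations and the arguments of erased blocks from the pass's bookkeeping sets (`opsToDetensor`, `detensorableBranchOps`, `blockArgsToDetensor`).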
>From c2e90f3a39148223619497eeff16ed810e3cab95 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 10 Aug 2025 11:41:51 +0000
Subject: [PATCH] [mlir][linalg] Migrate Detensorize pass to new dialect
conversion driver
---
.../Dialect/Linalg/Transforms/Detensorize.cpp | 34 +++++++++++++++++--
mlir/test/Dialect/Linalg/detensorize_0d.mlir | 7 ++--
2 files changed, 36 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 830905495e759..221f95a8d8f33 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -458,6 +458,22 @@ struct LinalgDetensorize
}
};
+ /// A RewriterBase::Listener whose notifyOperationErased / notifyBlockErased
+ /// hooks forward to the callbacks supplied at construction.
+ struct CallbackListener : public RewriterBase::Listener {
+ CallbackListener(std::function<void(Operation *op)> onOperationErased,
+ std::function<void(Block *block)> onBlockErased)
+ : onOperationErased(onOperationErased), onBlockErased(onBlockErased) {}
+
+ void notifyBlockErased(Block *block) override { onBlockErased(block); }
+ void notifyOperationErased(Operation *op) override {
+ onOperationErased(op);
+ }
+
+ std::function<void(Operation *op)> onOperationErased;
+ std::function<void(Block *block)> onBlockErased;
+ };
+
void runOnOperation() override {
MLIRContext *context = &getContext();
DetensorizeTypeConverter typeConverter;
@@ -551,8 +567,22 @@ struct LinalgDetensorize
populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
shouldConvertBranchOperand);
- if (failed(
- applyFullConversion(getOperation(), target, std::move(patterns))))
+ ConversionConfig config;
+ auto onOperationErased = [&](Operation *op) {
+ opsToDetensor.erase(op);
+ detensorableBranchOps.erase(op);
+ };
+ auto onBlockErased = [&](Block *block) {
+ for (BlockArgument arg : block->getArguments()) {
+ blockArgsToDetensor.erase(arg);
+ }
+ };
+ CallbackListener listener(onOperationErased, onBlockErased);
+
+ config.listener = &listener;
+ config.allowPatternRollback = false;
+ if (failed(applyFullConversion(getOperation(), target, std::move(patterns),
+ config)))
signalPassFailure();
RewritePatternSet canonPatterns(context);
diff --git a/mlir/test/Dialect/Linalg/detensorize_0d.mlir b/mlir/test/Dialect/Linalg/detensorize_0d.mlir
index 74931cb0830bc..5c29b04630cad 100644
--- a/mlir/test/Dialect/Linalg/detensorize_0d.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_0d.mlir
@@ -53,10 +53,11 @@ func.func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tenso
}
// CHECK-LABEL: func @detensor_op_sequence
// CHECK-SAME: (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
-// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
+// CHECK-DAG: %[[arg1_val_1:.*]] = tensor.extract %[[arg1]]
// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
-// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val]], %[[arg2_val]]
-// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val]], %[[detensored_res]]
+// CHECK-DAG: %[[arg1_val_2:.*]] = tensor.extract %[[arg1]]
+// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val_2]], %[[arg2_val]]
+// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val_1]], %[[detensored_res]]
// CHECK: %[[detensored_res3:.*]] = arith.divf %[[detensored_res]], %[[detensored_res2]]
// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]]
// CHECK: return %[[new_tensor_res]]
More information about the llvm-branch-commits
mailing list