[Mlir-commits] [mlir] [MLIR][Python] Add structured.fuseop to generator. (PR #120601)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 19 08:28:03 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Hugo Trachino (nujaa)
<details>
<summary>Changes</summary>
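Implements a Python interface for the structured `FuseOp` (`transform.structured.fuse`) that accepts more flexible inputs: either just a target (every loop handle then gets type `!transform.any_op`), or explicit loop type(s) followed by a target. The diff also renames `LowerVectorToLLVMPass` to `ConvertVectorToLLVMPass` and `SparsifierOptions::lowerVectorToLLVMOptions()` to `convertVectorToLLVMOptions()`, updating its callers in `SparseTensorPipelines.cpp`.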
---
Full diff: https://github.com/llvm/llvm-project/pull/120601.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h (+1-1)
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+3-3)
- (modified) mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp (+3-3)
- (modified) mlir/python/mlir/dialects/transform/structured.py (+52)
- (modified) mlir/test/python/dialects/transform_structured_ext.py (+21)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
index 90021ffa7c380bc..efbe5c56a219b37 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h
@@ -162,7 +162,7 @@ struct SparsifierOptions : public PassPipelineOptions<SparsifierOptions> {
}
/// Projects out the options for `createConvertVectorToLLVMPass`.
- ConvertVectorToLLVMPassOptions lowerVectorToLLVMOptions() const {
+ ConvertVectorToLLVMPassOptions convertVectorToLLVMOptions() const {
ConvertVectorToLLVMPassOptions opts{};
opts.reassociateFPReductions = reassociateFPReductions;
opts.force32BitVectorIndices = force32BitVectorIndices;
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 55143d5939ba257..842d239cf6a512c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -35,8 +35,8 @@ using namespace mlir;
using namespace mlir::vector;
namespace {
-struct LowerVectorToLLVMPass
- : public impl::ConvertVectorToLLVMPassBase<LowerVectorToLLVMPass> {
+struct ConvertVectorToLLVMPass
+ : public impl::ConvertVectorToLLVMPassBase<ConvertVectorToLLVMPass> {
using Base::Base;
@@ -58,7 +58,7 @@ struct LowerVectorToLLVMPass
};
} // namespace
-void LowerVectorToLLVMPass::runOnOperation() {
+void ConvertVectorToLLVMPass::runOnOperation() {
// Perform progressive lowering of operations on slices and
// all contraction operations. Also applies folding and DCE.
{
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index c5eb965884396ae..5e49252c0e57d98 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -76,16 +76,16 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm,
pm.addNestedPass<func::FuncOp>(createConvertSCFToCFPass());
pm.addPass(memref::createExpandStridedMetadataPass());
pm.addPass(createLowerAffinePass());
- pm.addPass(createConvertVectorToLLVMPass(options.lowerVectorToLLVMOptions()));
+ pm.addPass(createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
pm.addPass(createFinalizeMemRefToLLVMConversionPass());
pm.addNestedPass<func::FuncOp>(createConvertComplexToStandardPass());
pm.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass());
pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass());
pm.addPass(createConvertMathToLibmPass());
pm.addPass(createConvertComplexToLibmPass());
- pm.addPass(createConvertVectorToLLVMPass(options.lowerVectorToLLVMOptions()));
+ pm.addPass(createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
pm.addPass(createConvertComplexToLLVMPass());
- pm.addPass(createConvertVectorToLLVMPass(options.lowerVectorToLLVMOptions()));
+ pm.addPass(createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
pm.addPass(createConvertFuncToLLVMPass());
// Finalize GPU code generation.
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index 41051c0d5b2ffb6..b97a1aa8a82bfcf 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -139,6 +139,83 @@ def __init__(
             ip=ip,
         )
+@_ods_cext.register_operation(_Dialect, replace=True)
+class FuseOp(FuseOp):
+    """Specialization for FuseOp class.
+
+    Can be constructed either from a target alone, in which case every loop
+    handle gets type ``!transform.any_op``, or from explicit loop type(s)
+    followed by the target. When the loop types are derived (target-only form
+    or a single Type), one loop type is emitted per non-zero entry in ``sizes``.
+    """
+
+    @overload
+    def __init__(
+        self,
+        loop_types: Union[Type, List[Type]],
+        target: Union[Operation, Value, OpView],
+        *,
+        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+        interchange: OptionalIntList = None,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self,
+        target: Union[Operation, Value, OpView],
+        *,
+        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+        interchange: OptionalIntList = None,
+        loc=None,
+        ip=None,
+    ):
+        ...
+
+    def __init__(
+        self,
+        loop_types_or_target: Union[Type, List[Type], Operation, Value, OpView],
+        target_or_none: Optional[Union[Operation, Value, OpView]] = None,
+        *,
+        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+        interchange: OptionalIntList = None,
+        loc=None,
+        ip=None,
+    ):
+        sizes = sizes if sizes else []
+        # Only non-zero sizes yield a loop handle (a zero size presumably
+        # leaves that dimension untiled; confirm against the op definition).
+        num_loops = sum(0 if v == 0 else 1 for v in sizes)
+
+        if isinstance(loop_types_or_target, (Operation, Value, OpView)):
+            # Compact form: FuseOp(target, ...).
+            assert (
+                target_or_none is None
+            ), "Cannot construct FuseOp with two targets."
+            target = loop_types_or_target
+            loop_types = [transform.AnyOpType.get()] * num_loops
+        else:
+            # Explicit form: FuseOp(loop_types, target, ...). A single Type
+            # is repeated for every loop; a list is used as given.
+            assert target_or_none is not None, "FuseOp requires a target."
+            target = target_or_none
+            loop_types = (
+                [loop_types_or_target] * num_loops
+                if isinstance(loop_types_or_target, Type)
+                else loop_types_or_target
+            )
+        super().__init__(
+            target.type,
+            loop_types,
+            target,
+            tile_sizes=sizes,
+            tile_interchange=interchange,
+            loc=loc,
+            ip=ip,
+        )
+
@_ods_cext.register_operation(_Dialect, replace=True)
class GeneralizeOp(GeneralizeOp):
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 3ea73e8beea3688..551c2fa1e48acd3 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -101,6 +101,27 @@ def testFuseIntoContainingOpCompact(target):
# CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+@run
+@create_sequence
+def testFuseOpCompact(target):
+    structured.FuseOp(target, sizes=[4, 8], interchange=[0, 1])
+    # CHECK-LABEL: TEST: testFuseOpCompact
+    # CHECK: transform.sequence
+    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
+    # CHECK-SAME: interchange [0, 1]
+    # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+
+@run
+@create_sequence
+def testFuseOpNoArg(target):
+    structured.FuseOp(target)
+    # CHECK-LABEL: TEST: testFuseOpNoArg
+    # CHECK: transform.sequence
+    # CHECK: %{{.+}} = transform.structured.fuse %{{.*}} :
+    # CHECK-SAME: (!transform.any_op) -> !transform.any_op
+
+
@run
@create_sequence
def testGeneralize(target):
``````````
</details>
https://github.com/llvm/llvm-project/pull/120601
More information about the Mlir-commits
mailing list