[Mlir-commits] [mlir] [mlir][gpu] Allow WarpOpDeadResult, WarpOpForwardOperand patterns to be used in isolation. (PR #132860)
Charitha Saumya
llvmlistbot at llvm.org
Mon Mar 24 19:19:37 PDT 2025
https://github.com/charithaintc created https://github.com/llvm/llvm-project/pull/132860
This PR moves the `WarpOpDeadResult` and `WarpOpForwardOperand` patterns into `populateWarpSimplificationPatterns` so that they can be reused without having to pull in all the other vector distribution patterns inside `populatePropagateWarpVectorDistributionPatterns`.
>From 55c272c367ad296631db90740ed736e1eb7ea1e4 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 25 Mar 2025 02:16:03 +0000
Subject: [PATCH] save work
---
.../Vector/Transforms/VectorDistribution.h | 4 ++++
.../Vector/Transforms/VectorDistribute.cpp | 16 +++++++++++-----
.../lib/Dialect/Vector/TestVectorTransforms.cpp | 2 ++
3 files changed, 17 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index dda45219b2acc..082d990cee8a4 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -98,6 +98,10 @@ void populatePropagateWarpVectorDistributionPatterns(
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
PatternBenefit benefit = 1, PatternBenefit readBenefit = 0);
+/// Patterns for simplification of WarpExecuteOnLane0Op during distribution.
+void populateWarpSimplificationPatterns(RewritePatternSet &pattern,
+ PatternBenefit benefit = 1);
+
/// Lambda signature to compute a reduction of a distributed value for the given
/// reduction kind and size.
using DistributedReductionFn =
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e214257de2cdf..f0d771142e307 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1761,17 +1761,23 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
PatternBenefit readBenefit) {
patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
- patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
- WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
- WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
- WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
- patterns.getContext(), benefit);
+ patterns
+ .add<WarpOpElementwise, WarpOpBroadcast, WarpOpShapeCast, WarpOpExtract,
+ WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
+ WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
+ patterns.getContext(), benefit);
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
benefit);
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
benefit);
}
+void mlir::vector::populateWarpSimplificationPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<WarpOpDeadResult, WarpOpForwardOperand>(patterns.getContext(),
+ benefit);
+}
+
void mlir::vector::populateDistributeReduction(
RewritePatternSet &patterns,
const DistributedReductionFn &distributedReductionFn,
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a54ae816570a8..feec10e6492f7 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -660,6 +660,7 @@ struct TestVectorDistribution
vector::populatePropagateWarpVectorDistributionPatterns(
patterns, distributionFn, shuffleFn, /*benefit=*/1,
/*readBenefit=*/0);
+ vector::populateWarpSimplificationPatterns(patterns);
vector::populateDistributeReduction(patterns, warpReduction, 1);
populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
@@ -672,6 +673,7 @@ struct TestVectorDistribution
RewritePatternSet patterns(ctx);
vector::populatePropagateWarpVectorDistributionPatterns(
patterns, distributionFn, shuffleFn);
+ vector::populateWarpSimplificationPatterns(patterns);
vector::populateDistributeReduction(patterns, warpReduction);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
More information about the Mlir-commits
mailing list