[Mlir-commits] [mlir] [mlir][gpu] Allow WarpOpDeadResult, WarpOpForwardOperand patterns to be used in isolation. (PR #132860)

Charitha Saumya llvmlistbot at llvm.org
Mon Mar 24 19:19:37 PDT 2025


https://github.com/charithaintc created https://github.com/llvm/llvm-project/pull/132860

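This PR moves the `WarpOpDeadResult` and `WarpOpForwardOperand` patterns into a new `populateWarpSimplificationPatterns` function. They can then be used on their own, without also pulling in all the other vector distribution patterns registered by `populatePropagateWarpVectorDistributionPatterns`. `populatePropagateWarpVectorDistributionPatterns` no longer adds these two patterns. Existing callers that relied on them must therefore also call `populateWarpSimplificationPatterns`, as the updated `TestVectorDistribution` pass does.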

>From 55c272c367ad296631db90740ed736e1eb7ea1e4 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 25 Mar 2025 02:16:03 +0000
Subject: [PATCH] save work

---
 .../Vector/Transforms/VectorDistribution.h       |  4 ++++
 .../Vector/Transforms/VectorDistribute.cpp       | 16 +++++++++++-----
 .../lib/Dialect/Vector/TestVectorTransforms.cpp  |  2 ++
 3 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index dda45219b2acc..082d990cee8a4 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -98,6 +98,10 @@ void populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
     PatternBenefit benefit = 1, PatternBenefit readBenefit = 0);
 
+/// Patterns for simplification of WarpExecuteOnLane0Op during distribution.
+void populateWarpSimplificationPatterns(RewritePatternSet &patterns,
+                                        PatternBenefit benefit = 1);
+
 /// Lambda signature to compute a reduction of a distributed value for the given
 /// reduction kind and size.
 using DistributedReductionFn =
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e214257de2cdf..f0d771142e307 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1761,17 +1761,23 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
     PatternBenefit readBenefit) {
   patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
-  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
-               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
-               WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
-               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
-      patterns.getContext(), benefit);
+  patterns
+      .add<WarpOpElementwise, WarpOpBroadcast, WarpOpShapeCast, WarpOpExtract,
+           WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
+           WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
+          patterns.getContext(), benefit);
   patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                     benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
                                benefit);
 }
 
+void mlir::vector::populateWarpSimplificationPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  auto *ctx = patterns.getContext();
+  patterns.add<WarpOpDeadResult, WarpOpForwardOperand>(ctx, benefit);
+}
+
 void mlir::vector::populateDistributeReduction(
     RewritePatternSet &patterns,
     const DistributedReductionFn &distributedReductionFn,
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a54ae816570a8..feec10e6492f7 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -660,6 +660,7 @@ struct TestVectorDistribution
       vector::populatePropagateWarpVectorDistributionPatterns(
           patterns, distributionFn, shuffleFn, /*benefit=*/1,
           /*readBenefit=*/0);
+      vector::populateWarpSimplificationPatterns(patterns);
       vector::populateDistributeReduction(patterns, warpReduction, 1);
       populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2);
       (void)applyPatternsGreedily(getOperation(), std::move(patterns));
@@ -672,6 +673,7 @@ struct TestVectorDistribution
       RewritePatternSet patterns(ctx);
       vector::populatePropagateWarpVectorDistributionPatterns(
           patterns, distributionFn, shuffleFn);
+      vector::populateWarpSimplificationPatterns(patterns);
       vector::populateDistributeReduction(patterns, warpReduction);
       (void)applyPatternsGreedily(getOperation(), std::move(patterns));
     }



More information about the Mlir-commits mailing list