[Mlir-commits] [mlir] [mlir][vector] Group re-order patterns together (PR #102856)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Aug 15 07:26:40 PDT 2024


================
@@ -2030,8 +2030,7 @@ void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
 void mlir::vector::populateVectorReductionToContractPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<MultiReduceToContract, CombineContractBroadcast,
-               CombineContractABTranspose, CombineContractResultTranspose,
-               ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
----------------
banach-space wrote:

> So is the expectation that you run the sink before you run contract conversion patterns?

See the summary :) 

> NOTES FOR DOWNSTREAM USERS
>
> In order to preserve the current functionality, please make sure to add
> 
> * populateSinkVectorOpsPatterns,
> 
> wherever you are using populateVectorReductionToContractPatterns.
> Also, rename populateSinkVectorBroadcastPatterns to
> populateSinkVectorOpsPatterns.
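
The two migration rules in the notes are mechanical enough to write down as code. The following is a minimal, purely illustrative C++ sketch that applies them to a list of `populate*` call names. `migratePopulateCalls` is a hypothetical helper, not part of MLIR or IREE, and it assumes the ordering used in most of the IREE hunks (sink patterns appended right after the reduction-to-contract patterns).

```cpp
#include <string>
#include <vector>

// Hypothetical helper (not part of MLIR or IREE): applies the two migration
// rules from the downstream-users notes to a list of populate* call names.
//  1. Rename populateSinkVectorBroadcastPatterns to
//     populateSinkVectorOpsPatterns.
//  2. Wherever populateVectorReductionToContractPatterns is used, make sure
//     populateSinkVectorOpsPatterns is also present (appended right after it
//     if the list does not already contain it).
std::vector<std::string> migratePopulateCalls(std::vector<std::string> calls) {
  const std::string oldSink = "populateSinkVectorBroadcastPatterns";
  const std::string newSink = "populateSinkVectorOpsPatterns";
  const std::string toContract = "populateVectorReductionToContractPatterns";

  bool hasSink = false;
  for (std::string &call : calls) {
    if (call == oldSink)
      call = newSink;  // Rule 1: rename the old entry point.
    if (call == newSink)
      hasSink = true;
  }

  if (!hasSink) {
    for (auto it = calls.begin(); it != calls.end(); ++it) {
      if (*it == toContract) {
        // Rule 2: add the sink patterns; stop before `it` is invalidated.
        calls.insert(it + 1, newSink);
        break;
      }
    }
  }
  return calls;
}
```

The sketch only checks that the sink patterns are present, not exactly where they are added; the SPIRV hunk in the diff, for instance, adds them before `populateVectorReductionToContractPatterns` in the same pattern set.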

I checked MLIR and IREE, and in both cases these were required _after_. In IREE, I ran `ctest -R Codegen/SPIRV/` and `ctest -R Codegen/LLVMCPU`, and both pass 100%. This is my diff:
```diff
diff --git a/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp b/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp
index 1824fb08bb..e2c8805d27 100644
--- a/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp
@@ -146,7 +146,7 @@ struct EmulateNarrowTypePass final
     }

     RewritePatternSet sinkBroadcast(ctx);
-    vector::populateSinkVectorBroadcastPatterns(sinkBroadcast);
+    vector::populateSinkVectorOpsPatterns(sinkBroadcast);
     if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                             std::move(sinkBroadcast)))) {
       getOperation()->emitOpError("failed in sinking of broadcasts");
diff --git a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
index 98814d1342..8aee5ba2c0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
@@ -409,6 +409,7 @@ void GenericVectorizationPass::runOnOperation() {
     vector::populateVectorTransferPermutationMapLoweringPatterns(
         vectorizationPatterns);
     vector::populateVectorReductionToContractPatterns(vectorizationPatterns);
+    vector::populateSinkVectorOpsPatterns(vectorizationPatterns);
   }
   if (foldCastIntoContract) {
     vector::populateFoldArithExtensionPatterns(vectorizationPatterns);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp
index aa72280e5f..f00af42845 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTensorCoreVectorization.cpp
@@ -96,6 +96,7 @@ public:
       vector::populateVectorTransferPermutationMapLoweringPatterns(
           contractionPatterns);
       vector::populateVectorReductionToContractPatterns(contractionPatterns);
+      vector::populateSinkVectorOpsPatterns(contractionPatterns);
       if (failed(applyPatternsAndFoldGreedily(
               funcOp, std::move(contractionPatterns)))) {
         return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp
index 493894675c..c24da49689 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVInitialVectorLowering.cpp
@@ -314,6 +314,7 @@ public:
       // cancel them or embed into contract ops. Embedding in the flexible
       // contract ops will help to sustain the structure through various
       // transformations.
+      vector::populateSinkVectorOpsPatterns(patterns);
       vector::populateVectorReductionToContractPatterns(patterns);
       // Pull in patterns to canonicalize transfer ops.
       vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 6b7afaa9db..dc678f1dfe 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 6b7afaa9db8f904ebf0262774e38e54b36598782
+Subproject commit dc678f1dfe49cd4d0b9136ff2490a482ae91e786
```

@MaheshRavishankar, is there anything else I should check?

Btw, this re-grouping has been a bit tricky to verify 100%: there are no tests that require these patterns to be run together (otherwise I wouldn't have been able to move the tests around as I did).

https://github.com/llvm/llvm-project/pull/102856

