[Mlir-commits] [mlir] 6b5ce2c - [mlir][transform] Expose vector patterns useful for cleaning up masked vectorization and add test to showcase composition

Nicolas Vasilache llvmlistbot at llvm.org
Fri Aug 4 04:25:20 PDT 2023


Author: Nicolas Vasilache
Date: 2023-08-04T11:19:27Z
New Revision: 6b5ce2cffe96d532a498b21a55f5b4c3a0570609

URL: https://github.com/llvm/llvm-project/commit/6b5ce2cffe96d532a498b21a55f5b4c3a0570609
DIFF: https://github.com/llvm/llvm-project/commit/6b5ce2cffe96d532a498b21a55f5b4c3a0570609.diff

LOG: [mlir][transform] Expose vector patterns useful for cleaning up masked vectorization and add test to showcase composition

Differential Revision: https://reviews.llvm.org/D157047

Added: 
    mlir/test/Dialect/Linalg/masked_vectorization.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
    mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 0b3a1a5e7c73e6..e3d27cb4c71690 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -118,8 +118,11 @@ def ApplyLowerMaskedTransfersPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.lower_masked_transfers",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Indicates that masked vector.transfer and vector.gather operations should
-    be lowered to finer-grained vector primitives.
+    Apply opt-in patterns that lower vector.mask operations surrounding 
+    side-effecting ops:
+      - MaskedTransferReadOpPattern
+      - MaskedTransferWriteOpPattern
+      - MaskedGatherOpPattern
 
     This is usually a late step that is run after bufferization as part of the
     process of lowering to e.g. LLVM or NVVM.
@@ -313,4 +316,23 @@ def ApplyFoldArithExtensionPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.reduction_to_contract",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Apply opt-in patterns that convert reductions to vector.contract:
+      - MultiReduceToContract
+      - CombineContractBroadcast
+      - CombineContractABTranspose
+      - CombineContractResultTranspose
+      - ReorderCastOpsOnBroadcast
+      - ReorderElementwiseOpsOnTranspose
+
+    These patterns have the effect of rewriting a vector.multi_reduction into
+    a vector.contract.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 #endif // VECTOR_TRANSFORM_OPS

diff  --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 1866de6cb14703..8572b6df75bace 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -37,6 +37,11 @@ void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
   vector::populateFoldArithExtensionPatterns(patterns);
 }
 
+void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  mlir::vector::populateVectorReductionToContractPatterns(patterns);
+}
+
 void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::populateVectorTransferDropUnitDimsPatterns(patterns);

diff  --git a/mlir/test/Dialect/Linalg/masked_vectorization.mlir b/mlir/test/Dialect/Linalg/masked_vectorization.mlir
new file mode 100644
index 00000000000000..0be10ad28e143a
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/masked_vectorization.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @masked_matmul(
+func.func @masked_matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+
+  //      CHECK: %[[MLHS:.*]] = vector.create_mask {{.*}} : vector<8x8xi1>
+  //      CHECK: %[[LHS:.*]] = vector.transfer_read %{{.*}}, %[[MLHS]] {in_bounds = [true, true]} : memref<?x?xf32, strided<[?, 1], offset: ?>>, vector<8x8xf32>
+  //      CHECK: %[[MRHS:.*]] = vector.create_mask {{.*}} : vector<8x8xi1>
+  //      CHECK: %[[RHS:.*]] = vector.transfer_read %{{.*}}, %[[MRHS]] {in_bounds = [true, true]} : memref<?x?xf32, strided<[?, 1], offset: ?>>, vector<8x8xf32>
+  //      CHECK: %[[MACC:.*]] = vector.create_mask {{.*}} : vector<8x8xi1>
+  //      CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}}, %[[MACC]] {in_bounds = [true, true]} : memref<?x?xf32, strided<[?, 1], offset: ?>>, vector<8x8xf32>
+  //      CHECK: %[[MRES:.*]] = vector.create_mask {{.*}} : vector<8x8x8xi1>
+  //      CHECK: %[[RES:.*]] = vector.mask %[[MRES]] { vector.contract
+  // CHECK-SAME:   : vector<8x8xf32>, vector<8x8xf32> into vector<8x8xf32>
+  // CHECK-SAME:   : vector<8x8x8xi1> -> vector<8x8xf32>
+  //      CHECK: vector.transfer_write %[[RES]], %{{.*}}, %[[MACC]] {in_bounds = [true, true]} : vector<8x8xf32>, memref<?x?xf32, strided<[?, 1], offset: ?>>
+  linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>) outs(%arg2 : memref<?x?xf32>)
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%module: !transform.any_op):
+  %matmul = transform.structured.match ops{["linalg.matmul"]} in %module
+    : (!transform.any_op) -> !transform.any_op
+  %outer_tile, %outer_loops:3 = transform.structured.tile_to_scf_for %matmul[64, 128, 256]
+    : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+  %inner_tile, %inner_loops:3 = transform.structured.tile_to_scf_for %outer_tile[8, 8, 8]
+    : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+  transform.structured.masked_vectorize %inner_tile vector_sizes [8, 8, 8]
+    : !transform.any_op
+
+  %func = transform.structured.match ops{["func.func"]} in %module
+    : (!transform.any_op) -> !transform.any_op
+  apply_patterns to %func {
+    transform.apply_patterns.vector.lower_masked_transfers
+    transform.apply_patterns.vector.transfer_permutation_patterns
+    transform.apply_patterns.vector.reduction_to_contract
+  } : !transform.any_op
+}


        


More information about the Mlir-commits mailing list