[Mlir-commits] [mlir] 6b5ce2c - [mlir][transform] Expose vector patterns useful for cleaning up masked vectorization and add test to showcase composition
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Aug 4 04:25:20 PDT 2023
Author: Nicolas Vasilache
Date: 2023-08-04T11:19:27Z
New Revision: 6b5ce2cffe96d532a498b21a55f5b4c3a0570609
URL: https://github.com/llvm/llvm-project/commit/6b5ce2cffe96d532a498b21a55f5b4c3a0570609
DIFF: https://github.com/llvm/llvm-project/commit/6b5ce2cffe96d532a498b21a55f5b4c3a0570609.diff
LOG: [mlir][transform] Expose vector patterns useful for cleaning up masked vectorization and add test to showcase composition
Differential Revision: https://reviews.llvm.org/D157047
Added:
mlir/test/Dialect/Linalg/masked_vectorization.mlir
Modified:
mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 0b3a1a5e7c73e6..e3d27cb4c71690 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -118,8 +118,11 @@ def ApplyLowerMaskedTransfersPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.lower_masked_transfers",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
- Indicates that masked vector.transfer and vector.gather operations should
- be lowered to finer-grained vector primitives.
+ Apply opt-in patterns that lower vector.mask operations surrounding
+ side-effecting ops:
+ - MaskedTransferReadOpPattern
+ - MaskedTransferWriteOpPattern
+ - MaskedGatherOpPattern
This is usually a late step that is run after bufferization as part of the
process of lowering to e.g. LLVM or NVVM.
@@ -313,4 +316,23 @@ def ApplyFoldArithExtensionPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.reduction_to_contract",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Apply opt-in patterns that convert reductions into vector.contract ops:
+      - MultiReduceToContract
+      - CombineContractBroadcast
+      - CombineContractABTranspose
+      - CombineContractResultTranspose
+      - ReorderCastOpsOnBroadcast
+      - ReorderElementwiseOpsOnTranspose
+
+    These patterns have the effect of rewriting a vector.multi_reduction into
+    a vector.contract.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
#endif // VECTOR_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 1866de6cb14703..8572b6df75bace 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -37,6 +37,11 @@ void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
vector::populateFoldArithExtensionPatterns(patterns);
}
+void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  ::mlir::vector::populateVectorReductionToContractPatterns(patterns);
+}
+
void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorTransferDropUnitDimsPatterns(patterns);
diff --git a/mlir/test/Dialect/Linalg/masked_vectorization.mlir b/mlir/test/Dialect/Linalg/masked_vectorization.mlir
new file mode 100644
index 00000000000000..0be10ad28e143a
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/masked_vectorization.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s
+
+// CHECK-LABEL: masked_matmul
+func.func @masked_matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
+  // Expected: transfers take their masks directly; the contract stays in vector.mask.
+  // CHECK: %[[MLHS:.*]] = vector.create_mask {{.*}} : vector<8x8xi1>
+  // CHECK: %[[LHS:.*]] = vector.transfer_read %{{.*}}, %[[MLHS]] {in_bounds = [true, true]} : memref<?x?xf32, strided<[?, 1], offset: ?>>, vector<8x8xf32>
+  // CHECK: %[[MRHS:.*]] = vector.create_mask {{.*}} : vector<8x8xi1>
+  // CHECK: %[[RHS:.*]] = vector.transfer_read %{{.*}}, %[[MRHS]] {in_bounds = [true, true]} : memref<?x?xf32, strided<[?, 1], offset: ?>>, vector<8x8xf32>
+  // CHECK: %[[MACC:.*]] = vector.create_mask {{.*}} : vector<8x8xi1>
+  // CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}}, %[[MACC]] {in_bounds = [true, true]} : memref<?x?xf32, strided<[?, 1], offset: ?>>, vector<8x8xf32>
+  // CHECK: %[[MRES:.*]] = vector.create_mask {{.*}} : vector<8x8x8xi1>
+  // CHECK: %[[RES:.*]] = vector.mask %[[MRES]] { vector.contract
+  // CHECK-SAME: : vector<8x8xf32>, vector<8x8xf32> into vector<8x8xf32>
+  // CHECK-SAME: : vector<8x8x8xi1> -> vector<8x8xf32>
+  // CHECK: vector.transfer_write %[[RES]], %{{.*}}, %[[MACC]] {in_bounds = [true, true]} : vector<8x8xf32>, memref<?x?xf32, strided<[?, 1], offset: ?>>
+  linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>) outs(%arg2 : memref<?x?xf32>)
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%root: !transform.any_op):
+  %matmul = transform.structured.match ops{["linalg.matmul"]} in %root
+    : (!transform.any_op) -> !transform.any_op
+  %outer_tile, %outer_loops:3 = transform.structured.tile_to_scf_for %matmul[64, 128, 256]
+    : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+  %inner_tile, %inner_loops:3 = transform.structured.tile_to_scf_for %outer_tile[8, 8, 8]
+    : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+  transform.structured.masked_vectorize %inner_tile vector_sizes [8, 8, 8]
+    : !transform.any_op
+  // Clean up the masked IR produced by vectorization, function by function.
+  %funcs = transform.structured.match ops{["func.func"]} in %root
+    : (!transform.any_op) -> !transform.any_op
+  apply_patterns to %funcs {
+    transform.apply_patterns.vector.lower_masked_transfers
+    transform.apply_patterns.vector.transfer_permutation_patterns
+    transform.apply_patterns.vector.reduction_to_contract
+  } : !transform.any_op
+}
More information about the Mlir-commits
mailing list