[Mlir-commits] [mlir] e4e0bf6 - [mlir][Vector] Split transform.vector.lower_mask into 2 ops.
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Apr 13 00:14:07 PDT 2023
Author: Nicolas Vasilache
Date: 2023-04-13T00:14:01-07:00
New Revision: e4e0bf63d0b3615b9a2481be6769a3c876763ec6
URL: https://github.com/llvm/llvm-project/commit/e4e0bf63d0b3615b9a2481be6769a3c876763ec6
DIFF: https://github.com/llvm/llvm-project/commit/e4e0bf63d0b3615b9a2481be6769a3c876763ec6.diff
LOG: [mlir][Vector] Split transform.vector.lower_mask into 2 ops.
This gives us better control: masked operations can now be lowered independently of the mask-creation operations (vector.create_mask / vector.constant_mask).
It is often useful to keep high-level mask information around instead of lowering it too early to
an overly fine-grained form.
Differential Revision: https://reviews.llvm.org/D148162
Added:
Modified:
mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index dcf3056d3e4ac..9b693554d7ed1 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -122,10 +122,31 @@ def LowerContractionOp : TransformWithPatternsOp<"vector.lower_contraction"> {
}];
}
-def LowerMaskOp : TransformWithPatternsOp<"vector.lower_mask"> {
+def LowerMasksOp : TransformWithPatternsOp<"vector.lower_masks"> {
let description = [{
- Indicates that the vector mask operations nested under the isolated from
- above op `target` should be lowered to finer-grained vector primitives.
+ Indicates that the vector.create_mask and vector.constant_mask operations
+ nested under the isolated from above op `target` should be lowered to
+ finer-grained vector primitives.
+
+ This is usually a late step that is run after bufferization as part of the
+ process of lowering to e.g. LLVM or NVVM.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs TransformHandleTypeInterface:$results);
+
+ let assemblyFormat = [{
+ $target
+ attr-dict
+ `:` functional-type($target, results)
+ }];
+}
+
+def LowerMaskedTransfersOp : TransformWithPatternsOp<"vector.lower_masked_transfers"> {
+ let description = [{
+ Indicates that masked vector.transfer and vector.gather operations nested
+ under the isolated from above op `target` should be lowered to finer-grained
+ vector primitives.
This is usually a late step that is run after bufferization as part of the
process of lowering to e.g. LLVM or NVVM.
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index e7a0ba7892c85..679df80f8e82b 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -63,11 +63,19 @@ void transform::LowerContractionOp::populatePatterns(
}
//===----------------------------------------------------------------------===//
-// LowerMaskOp
+// LowerMasksOp
//===----------------------------------------------------------------------===//
-void transform::LowerMaskOp::populatePatterns(RewritePatternSet &patterns) {
+void transform::LowerMasksOp::populatePatterns(RewritePatternSet &patterns) {
populateVectorMaskOpLoweringPatterns(patterns);
+}
+
+//===----------------------------------------------------------------------===//
+// LowerMaskedTransfersOp
+//===----------------------------------------------------------------------===//
+
+void transform::LowerMaskedTransfersOp::populatePatterns(
+ RewritePatternSet &patterns) {
populateVectorMaskLoweringPatternsForSideEffectingOps(patterns);
}
diff --git a/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir
index 46faf99bf0ea1..58184b8467e43 100644
--- a/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir
@@ -96,6 +96,36 @@ transform.sequence failures(propagate) {
%f = transform.structured.match ops{["func.func"]} in %module_op
: (!pdl.operation) -> !pdl.operation
- transform.vector.lower_mask %f
+ transform.vector.lower_masks %f
+ : (!pdl.operation) -> !pdl.operation
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_3d(
+func.func @transfer_read_3d(
+ %t: tensor<?x?x?xf32>, %arg0: index, %arg1: index, %arg2: index)
+ -> vector<2x1x7xf32> {
+ %c0 = arith.constant 0 : index
+ %f0 = arith.constant 0.0 : f32
+ // CHECK: %[[mask:.*]] = vector.create_mask
+ // CHECK-NOT: vector.mask
+ // CHECK: vector.transfer_read {{.*}}, %[[mask]] {in_bounds = [true, true, true]}
+ // CHECK-SAME: : tensor<?x?x?xf32>, vector<2x1x7xf32>
+ %0 = vector.create_mask %arg0, %arg1, %arg2 : vector<2x1x7xi1>
+ %1 = vector.mask %0 {
+ vector.transfer_read %t[%c0, %c0, %c0], %f0 {in_bounds = [true, true, true]}
+ : tensor<?x?x?xf32>, vector<2x1x7xf32>
+ } : vector<2x1x7xi1> -> vector<2x1x7xf32>
+
+ return %1: vector<2x1x7xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!pdl.operation) -> !pdl.operation
+
+ transform.vector.lower_masked_transfers %f
: (!pdl.operation) -> !pdl.operation
}
More information about the Mlir-commits mailing list