[Mlir-commits] [mlir] e4e0bf6 - [mlir][Vector] Split transform.vector.lower_mask into 2 ops.

Nicolas Vasilache llvmlistbot at llvm.org
Thu Apr 13 00:14:07 PDT 2023


Author: Nicolas Vasilache
Date: 2023-04-13T00:14:01-07:00
New Revision: e4e0bf63d0b3615b9a2481be6769a3c876763ec6

URL: https://github.com/llvm/llvm-project/commit/e4e0bf63d0b3615b9a2481be6769a3c876763ec6
DIFF: https://github.com/llvm/llvm-project/commit/e4e0bf63d0b3615b9a2481be6769a3c876763ec6.diff

LOG: [mlir][Vector] Split transform.vector.lower_mask into 2 ops.

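This gives better control over lowering masked operations independently of the mask-creation operations.
The former `transform.vector.lower_mask` is replaced by two ops:
- `transform.vector.lower_masks` lowers `vector.create_mask` and `vector.constant_mask`.
- `transform.vector.lower_masked_transfers` lowers masked `vector.transfer` and `vector.gather` operations.
It is often useful to keep high-level mask information instead of lowering it too early to a
fine-grained form.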

Differential Revision: https://reviews.llvm.org/D148162

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
    mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
    mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index dcf3056d3e4ac..9b693554d7ed1 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -122,10 +122,31 @@ def LowerContractionOp : TransformWithPatternsOp<"vector.lower_contraction"> {
   }];
 }
 
-def LowerMaskOp : TransformWithPatternsOp<"vector.lower_mask"> {
+def LowerMasksOp : TransformWithPatternsOp<"vector.lower_masks"> {
   let description = [{
-    Indicates that the vector mask operations nested under the isolated from
-    above op `target` should be lowered to finer-grained vector primitives.
+    Indicates that the vector.create_mask and vector.constant_mask operations
+    nested under the isolated from above op `target` should be lowered to
+    finer-grained vector primitives.
+
+    This is usually a late step that is run after bufferization as part of the
+    process of lowering to e.g. LLVM or NVVM.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$results);
+
+  let assemblyFormat = [{
+    $target
+    attr-dict
+    `:` functional-type($target, results)
+  }];
+}
+
+def LowerMaskedTransfersOp : TransformWithPatternsOp<"vector.lower_masked_transfers"> {
+  let description = [{
+    Indicates that masked vector.transfer and vector.gather operations nested
+    under the isolated from above op `target` should be lowered to finer-grained
+    vector primitives.
 
     This is usually a late step that is run after bufferization as part of the
     process of lowering to e.g. LLVM or NVVM.

diff  --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index e7a0ba7892c85..679df80f8e82b 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -63,11 +63,19 @@ void transform::LowerContractionOp::populatePatterns(
 }
 
 //===----------------------------------------------------------------------===//
-// LowerMaskOp
+// LowerMasksOp
 //===----------------------------------------------------------------------===//
 
-void transform::LowerMaskOp::populatePatterns(RewritePatternSet &patterns) {
+void transform::LowerMasksOp::populatePatterns(RewritePatternSet &patterns) {
   populateVectorMaskOpLoweringPatterns(patterns);
+}
+
+//===----------------------------------------------------------------------===//
+// LowerMaskedTransfersOp
+//===----------------------------------------------------------------------===//
+
+void transform::LowerMaskedTransfersOp::populatePatterns(
+    RewritePatternSet &patterns) {
   populateVectorMaskLoweringPatternsForSideEffectingOps(patterns);
 }
 

diff  --git a/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir
index 46faf99bf0ea1..58184b8467e43 100644
--- a/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir
@@ -96,6 +96,36 @@ transform.sequence failures(propagate) {
   %f = transform.structured.match ops{["func.func"]} in %module_op 
     : (!pdl.operation) -> !pdl.operation
 
-  transform.vector.lower_mask %f
+  transform.vector.lower_masks %f
+      : (!pdl.operation) -> !pdl.operation
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_read_3d(
+func.func @transfer_read_3d(
+    %t: tensor<?x?x?xf32>, %arg0: index, %arg1: index, %arg2: index)
+  -> vector<2x1x7xf32> {
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0.0 : f32
+  //      CHECK: %[[mask:.*]] = vector.create_mask
+  //  CHECK-NOT: vector.mask
+  //      CHECK: vector.transfer_read {{.*}}, %[[mask]] {in_bounds = [true, true, true]}
+  // CHECK-SAME:   : tensor<?x?x?xf32>, vector<2x1x7xf32>
+  %0 = vector.create_mask %arg0, %arg1, %arg2 : vector<2x1x7xi1>
+  %1 = vector.mask %0 { 
+    vector.transfer_read %t[%c0, %c0, %c0], %f0 {in_bounds = [true, true, true]}
+      : tensor<?x?x?xf32>, vector<2x1x7xf32> 
+  } : vector<2x1x7xi1> -> vector<2x1x7xf32>
+
+  return %1: vector<2x1x7xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %f = transform.structured.match ops{["func.func"]} in %module_op 
+    : (!pdl.operation) -> !pdl.operation
+
+  transform.vector.lower_masked_transfers %f
       : (!pdl.operation) -> !pdl.operation
 }


        


More information about the Mlir-commits mailing list