[Mlir-commits] [mlir] [mlir][vector] Add a pattern to fuse extract(constant_mask) (PR #81057)

Hsiangkai Wang llvmlistbot at llvm.org
Mon Feb 26 07:38:49 PST 2024


================
@@ -2567,3 +2567,74 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
       tensor<4x4x4xf32>, vector<1x100x4x5xf32>
   return %r : vector<1x100x4x5xf32>
 }
+
+// -----
+
+// Scalar extract at a static position that lies inside the masked region.
+// The mask sizes are [2, 2, 3] and the extracted position is [1, 1, 2]; each
+// index is smaller than the mask size of its dimension, so the expected
+// output is just a `true` constant being returned.
+// CHECK-LABEL: func.func @extract_true_from_constant_mask() -> i1 {
+func.func @extract_true_from_constant_mask() -> i1 {
+// CHECK:      %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT: return %[[TRUE]] : i1
+  %mask = vector.constant_mask [2, 2, 3] : vector<4x4x4xi1>
+  %extract = vector.extract %mask[1, 1, 2] : i1 from vector<4x4x4xi1>
+  return %extract : i1
+}
+
+// -----
+
+// Scalar extract at a static position that lies outside the masked region.
+// The mask sizes are [2, 2, 3] and the extracted position is [1, 2, 2]; the
+// middle index (2) is not smaller than the mask size of its dimension (2), so
+// the expected output is just a `false` constant being returned.
+// CHECK-LABEL: func.func @extract_false_from_constant_mask() -> i1 {
+func.func @extract_false_from_constant_mask() -> i1 {
+// CHECK:      %[[FALSE:.*]] = arith.constant false
+// CHECK-NEXT: return %[[FALSE]] : i1
+  %mask = vector.constant_mask [2, 2, 3] : vector<4x4x4xi1>
+  %extract = vector.extract %mask[1, 2, 2] : i1 from vector<4x4x4xi1>
+  return %extract : i1
+}
+
+// -----
+
+// Same scenario as the in-bounds scalar case, but the mask is built with
+// vector.create_mask whose operands are the constants 2, 2 and 3. The
+// extracted position [1, 1, 2] is inside those bounds, and the expected
+// output is a `true` constant being returned.
+// NOTE(review): presumably this relies on create_mask with constant operands
+// being rewritten to constant_mask before the extract is folded -- confirm
+// against the pattern implementation.
+// CHECK-LABEL: func.func @extract_from_create_mask() -> i1 {
+func.func @extract_from_create_mask() -> i1 {
+// CHECK:      %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT: return %[[TRUE]] : i1
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %mask = vector.create_mask %c2, %c2, %c3 : vector<4x4x4xi1>
+  %extract = vector.extract %mask[1, 1, 2] : i1 from vector<4x4x4xi1>
+  return %extract : i1
+}
+
+// -----
+
+// Sub-vector extract with static positions on the two leading dimensions.
+// The mask sizes are [2, 3, 4] and the extracted position is [1, 2]; both
+// indices are inside the mask bounds of their dimensions, so the expected
+// result is a 1-D constant mask that keeps only the trailing mask size: [4]
+// on vector<6xi1>.
+// CHECK-LABEL: func.func @extract_subvector_from_constant_mask() ->
+// CHECK-SAME:  vector<6xi1> {
+func.func @extract_subvector_from_constant_mask() -> vector<6xi1> {
+// CHECK:      %[[S0:.*]] = vector.constant_mask [4] : vector<6xi1>
+// CHECK-NEXT: return %[[S0]] : vector<6xi1>
+  %mask = vector.constant_mask [2, 3, 4] : vector<4x5x6xi1>
+  %extract = vector.extract %mask[1, 2] : vector<6xi1> from vector<4x5x6xi1>
+  return %extract : vector<6xi1>
+}
+
+// -----
+
+// Negative test: the innermost extract position is the dynamic SSA value
+// %index rather than a static integer. The CHECK lines expect both the
+// constant_mask and the extract to appear unchanged in the output, i.e. no
+// folding takes place for this input.
+// CHECK-LABEL: func.func @extract_scalar_with_dynamic_positions(
+// CHECK-SAME:    %[[INDEX:.*]]: index) -> i1 {
+func.func @extract_scalar_with_dynamic_positions(%index: index) -> i1 {
+// CHECK:       %[[S0:.*]] = vector.constant_mask [2, 2, 3] : vector<4x4x4xi1>
+// CHECK-NEXT:  %[[S1:.*]] = vector.extract %[[S0]][1, 1, %[[INDEX]]] : i1 from vector<4x4x4xi1>
+// CHECK-NEXT:  return %[[S1]] : i1
+  %mask = vector.constant_mask [2, 2, 3] : vector<4x4x4xi1>
+  %extract = vector.extract %mask[1, 1, %index] : i1 from vector<4x4x4xi1>
+  return %extract : i1
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_subvector_with_dynamic_positions
+func.func @extract_subvector_with_dynamic_positions(%index: index) -> vector<6xi1> {
+// CHECK:      %[[S0:.*]] = vector.constant_mask [4] : vector<6xi1>
+// CHECK-NEXT: return %[[S0]] : vector<6xi1>
+  %mask = vector.constant_mask [2, 3, 4] : vector<4x5x6xi1>
+  %extract = vector.extract %mask[1, %index] : vector<6xi1> from vector<4x5x6xi1>
----------------
Hsiangkai wrote:

You are right. Fixed.

https://github.com/llvm/llvm-project/pull/81057


More information about the Mlir-commits mailing list