[Mlir-commits] [mlir] [mlir][VectorOps] Add fold `ExtractOp(CreateMask) -> CreateMask` (PR #69456)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Oct 18 07:29:42 PDT 2023


================
@@ -67,6 +67,112 @@ func.func @create_mask_transpose_to_transposed_create_mask(
 
 // -----
 
+// CHECK-LABEL: extract_from_create_mask
+//  CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
+func.func @extract_from_create_mask(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> {
+  %c2 = arith.constant 2 : index
+  %mask = vector.create_mask %c2, %dim0, %dim1 : vector<4x[4]x[4]xi1>
+  // CHECK: vector.create_mask %[[DIM0]], %[[DIM1]] : vector<[4]x[4]xi1>
+  // CHECK-NOT: vector.extract
+  %extract = vector.extract %mask[1] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
+  return %extract : vector<[4]x[4]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_create_mask_all_false
+func.func @extract_from_create_mask_all_false(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> {
+  %c2 = arith.constant 2 : index
+  %mask = vector.create_mask %c2, %dim0, %dim1 : vector<4x[4]x[4]xi1>
+  // CHECK: arith.constant dense<false> : vector<[4]x[4]xi1>
+  // CHECK-NOT: vector.extract
+  %extract = vector.extract %mask[2] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
+  return %extract : vector<[4]x[4]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_create_mask_leading_scalable
+//  CHECK-SAME: %[[DIM0:.*]]: index
+func.func @extract_from_create_mask_leading_scalable(%dim0: index) -> vector<8xi1> {
+  %c3 = arith.constant 3 : index
+  %mask = vector.create_mask %c3, %dim0 : vector<[4]x8xi1>
+  // CHECK: vector.create_mask %[[DIM0]] : vector<8xi1>
+  // CHECK-NOT: vector.extract
+  %extract = vector.extract %mask[1] : vector<8xi1> from vector<[4]x8xi1>
+  return %extract : vector<8xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_create_mask_dynamic_position
+//  CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
+func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index) -> vector<6xi1> {
+  %c4 = arith.constant 4 : index
+  %c3 = arith.constant 3 : index
+  %mask = vector.create_mask %c3, %c4, %dim0 : vector<4x4x6xi1>
+  // CHECK: vector.create_mask %[[DIM0]] : vector<6xi1>
+  // CHECK-NOT: vector.extract
+  %extract = vector.extract %mask[2, %index] : vector<6xi1> from vector<4x4x6xi1>
+  return %extract : vector<6xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_create_mask_dynamic_position_all_false
+//  CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
+func.func @extract_from_create_mask_dynamic_position_all_false(%dim0: index, %index: index) -> vector<6xi1> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %mask = vector.create_mask %c1, %c0, %dim0 : vector<1x4x6xi1>
----------------
banach-space wrote:

What happens if `%dim0` is a constant less than `6`?
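
To make the question concrete, here is a minimal sketch of such a case. It is a hypothetical variant of the complete `extract_from_create_mask_dynamic_position` test from the hunk above, with `%dim0` swapped for a constant smaller than the trailing dimension size. The function name and the choice of `2` are illustrative assumptions and not part of the patch. No `CHECK` lines are given, since the expected output is exactly what the question is asking about.

```mlir
// Hypothetical test case: not in the patch, name and constants are illustrative.
func.func @extract_from_create_mask_constant_trailing_dim(%index: index) -> vector<6xi1> {
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %c4 = arith.constant 4 : index
  // Every mask operand is now a constant; the trailing one (2) is less than
  // the size of the trailing dimension (6).
  %mask = vector.create_mask %c3, %c4, %c2 : vector<4x4x6xi1>
  // Same extract as in the original test: one static and one dynamic position.
  %extract = vector.extract %mask[2, %index] : vector<6xi1> from vector<4x4x6xi1>
  return %extract : vector<6xi1>
}
```

One thing to watch for with an input like this is that other canonicalization patterns may rewrite an all-constant `vector.create_mask` before or after this fold applies, so the final IR could differ from the dynamic-operand cases.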

https://github.com/llvm/llvm-project/pull/69456

