[Mlir-commits] [mlir] [mlir][VectorOps] Add fold `ExtractOp(CreateMask) -> CreateMask` (PR #69456)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Oct 18 07:29:42 PDT 2023
================
@@ -67,6 +67,112 @@ func.func @create_mask_transpose_to_transposed_create_mask(
// -----
+// CHECK-LABEL: extract_from_create_mask
+// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
+func.func @extract_from_create_mask(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> {
+ %c2 = arith.constant 2 : index
+ %mask = vector.create_mask %c2, %dim0, %dim1 : vector<4x[4]x[4]xi1>
+ // CHECK: vector.create_mask %[[DIM0]], %[[DIM1]] : vector<[4]x[4]xi1>
+ // CHECK-NOT: vector.extract
+ %extract = vector.extract %mask[1] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
+ return %extract : vector<[4]x[4]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_create_mask_all_false
+func.func @extract_from_create_mask_all_false(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> {
+ %c2 = arith.constant 2 : index
+ %mask = vector.create_mask %c2, %dim0, %dim1 : vector<4x[4]x[4]xi1>
+ // CHECK: arith.constant dense<false> : vector<[4]x[4]xi1>
+ // CHECK-NOT: vector.extract
+ %extract = vector.extract %mask[2] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
+ return %extract : vector<[4]x[4]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_create_mask_leading_scalable
+// CHECK-SAME: %[[DIM0:.*]]: index
+func.func @extract_from_create_mask_leading_scalable(%dim0: index) -> vector<8xi1> {
+ %c3 = arith.constant 3 : index
+ %mask = vector.create_mask %c3, %dim0 : vector<[4]x8xi1>
+ // CHECK: vector.create_mask %[[DIM0]] : vector<8xi1>
+ // CHECK-NOT: vector.extract
+ %extract = vector.extract %mask[1] : vector<8xi1> from vector<[4]x8xi1>
+ return %extract : vector<8xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_create_mask_dynamic_position
+// CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
+func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index) -> vector<6xi1> {
+ %c4 = arith.constant 4 : index
+ %c3 = arith.constant 3 : index
+ %mask = vector.create_mask %c3, %c4, %dim0 : vector<4x4x6xi1>
+ // CHECK: vector.create_mask %[[DIM0]] : vector<6xi1>
+ // CHECK-NOT: vector.extract
+ %extract = vector.extract %mask[2, %index] : vector<6xi1> from vector<4x4x6xi1>
+ return %extract : vector<6xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_create_mask_dynamic_position_all_false
+// CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
+func.func @extract_from_create_mask_dynamic_position_all_false(%dim0: index, %index: index) -> vector<6xi1> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %mask = vector.create_mask %c1, %c0, %dim0 : vector<1x4x6xi1>
----------------
banach-space wrote:
What happens if `%dim0` is a constant less than `6`?
https://github.com/llvm/llvm-project/pull/69456
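The Python sketch below roughly models the fold that these tests exercise. It is a hypothetical model, not the C++ implementation in this PR. The following are all assumptions made for illustration:

- the function name `fold_extract_of_create_mask`
- encoding known constants as `int` and unknown SSA values (e.g. `%dim0`, `%index`) as `str`
- the `("all_false",)` / `("create_mask", ...)` result tuples

The model treats an element as set when its index in every dimension k is below mask operand k. It treats scalable dimensions such as `[4]` as fixed-size, so it ignores the runtime `vscale` multiplier.

```python
from typing import Optional, Sequence, Tuple, Union

# Hypothetical encoding: int = known constant (e.g. from arith.constant),
# str = SSA value whose runtime value is unknown (e.g. "%dim0").
Value = Union[int, str]
FoldResult = Optional[Tuple]  # ("all_false",), ("create_mask", [...]) or None


def fold_extract_of_create_mask(
    shape: Sequence[int],
    mask_sizes: Sequence[Value],
    position: Sequence[Value],
) -> FoldResult:
    """Sketch of folding `vector.extract %m[position]`, where
    `%m = vector.create_mask mask_sizes : vector<shape x i1>`.

    Scalable dims are modelled as fixed-size; this is NOT the MLIR C++ fold.
    Returns None when the result cannot be decided at compile time.
    """
    leading = list(zip(shape, mask_sizes, position))

    # Pass 1: an extracted index provably outside [0, mask_size) in any
    # leading dimension means the whole extracted slice is false.
    for _dim, size, pos in leading:
        if isinstance(size, int) and (
            size <= 0 or (isinstance(pos, int) and pos >= size)
        ):
            return ("all_false",)

    # Pass 2: every extracted index must be provably inside [0, mask_size).
    for dim, size, pos in leading:
        if not isinstance(size, int):
            return None  # unknown mask size: cannot decide
        if isinstance(pos, str) and size < dim:
            return None  # a dynamic index might land outside the mask
        # Constant pos: pass 1 already established pos < size.

    # The slice is exactly the mask built from the trailing operands.
    return ("create_mask", list(mask_sizes[len(position):]))


# Cases mirroring the tests in the diff (scalable [4] written as 4).
print(fold_extract_of_create_mask([4, 4, 4], [2, "%dim0", "%dim1"], [1]))
# -> ('create_mask', ['%dim0', '%dim1'])
print(fold_extract_of_create_mask([4, 4, 4], [2, "%dim0", "%dim1"], [2]))
# -> ('all_false',)
print(fold_extract_of_create_mask([4, 4, 6], [3, 4, "%dim0"], [2, "%index"]))
# -> ('create_mask', ['%dim0'])
```

In this model, a dynamic index is accepted only when the matching mask operand is a constant covering the whole dimension. That is the situation in `extract_from_create_mask_dynamic_position`, where `%c4` spans the full `4`. Anything weaker is conservatively left unfolded.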