[Mlir-commits] [mlir] [mlir][VectorOps] Add fold `ExtractOp(CreateMask) -> CreateMask` (PR #69456)

Benjamin Maxwell llvmlistbot at llvm.org
Wed Oct 18 05:34:53 PDT 2023


https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/69456

This allows folding `vector.extract` ops whose source is a `vector.create_mask`, when the value of the extracted sub-mask is known. Currently there's no fold for this; you only get the same effect from the unrolling in LowerVectorMask (part of `-convert-vector-to-llvm`) followed by the folds that run after it. However, a future patch needs this simplification to happen before lowering to LLVM, hence this fold.

E.g.:

```
%0 = vector.create_mask %c1, %dimA, %dimB : vector<1x[4]x[4]xi1>
%1 = vector.extract %0[0] : vector<[4]x[4]xi1> from vector<1x[4]x[4]xi1>
```
->
```
%1 = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
```

>From 62602c2a441fb9cd208ba7f6693fe39a5d1843e5 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 18 Oct 2023 12:10:03 +0000
Subject: [PATCH] [mlir][VectorOps] Add fold `ExtractOp(CreateMask) ->
 CreateMask`

This allows folding `vector.extract` ops whose source is a
`vector.create_mask`, when the value of the extracted sub-mask is
known. Currently there's no fold for this; you only get the same effect
from the unrolling in LowerVectorMask (part of -convert-vector-to-llvm)
followed by the folds that run after it. However, a future patch needs
this simplification to happen before lowering to LLVM, hence this fold.

E.g.:

```
%0 = vector.create_mask %c1, %dimA, %dimB : vector<1x[4]x[4]xi1>
%1 = vector.extract %0[0] : vector<[4]x[4]xi1> from vector<1x[4]x[4]xi1>
```
->
```
%1 = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
```
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   |  58 ++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir | 106 +++++++++++++++++++++
 2 files changed, 163 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 68a5cf209f2fb49..82e05f56bd4ed6a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1983,6 +1983,62 @@ class ExtractOpNonSplatConstantFolder final
   }
 };
 
+// Pattern to rewrite ExtractOp(CreateMask) -> CreateMask, or to an all-false
+// constant when the extracted position is provably outside the mask region.
+class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto createMaskOp =
+        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+    if (!createMaskOp)
+      return failure();
+
+    // A full-rank extract yields scalar i1; no vector-typed op can replace it.
+    auto resultType = llvm::dyn_cast<VectorType>(extractOp.getType());
+    if (!resultType)
+      return failure();
+
+    ArrayRef<int64_t> position = extractOp.getStaticPosition();
+    auto maskOperands = createMaskOp.getOperands();
+    VectorType maskType = createMaskOp.getVectorType();
+    bool allFalseMask = false;
+    bool containsUnknownDims = false;
+    for (auto [i, pos] : llvm::enumerate(position)) {
+      auto constantOp = maskOperands[i].getDefiningOp<arith::ConstantOp>();
+      if (!constantOp) {
+        containsUnknownDims = true; // A later dim may still prove all-false.
+        continue;
+      }
+      int64_t createMaskBound =
+          llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
+      if (pos != ShapedType::kDynamic) {
+        // Positions outside [0, bound) select an all-false sub-mask.
+        allFalseMask |= pos >= createMaskBound;
+      } else if (createMaskBound <= 0) {
+        allFalseMask = true; // No in-bounds index is below a bound <= 0.
+      } else if (maskType.getScalableDims()[i] ||
+                 createMaskBound < maskType.getDimSize(i)) {
+        // Not provably all-true (scalable dims have vscale * N elements).
+        containsUnknownDims = true;
+      }
+    }
+
+    if (allFalseMask) {
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+          extractOp, DenseElementsAttr::get(resultType, false));
+      return success();
+    }
+    if (containsUnknownDims)
+      return failure();
+    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
+        extractOp, resultType, maskOperands.drop_front(position.size()));
+    return success();
+  }
+};
+
 // Folds extract(shape_cast(..)) into shape_cast when the total element count
 // does not change.
 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
@@ -2009,7 +2065,7 @@ LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
   results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
-              ExtractOpFromBroadcast>(context);
+              ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
 }
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 924886c50030967..dd2c78eb44e9f9e 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -67,6 +67,112 @@ func.func @create_mask_transpose_to_transposed_create_mask(
 
 // -----
 
+// CHECK-LABEL: extract_from_create_mask
+//  CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
+func.func @extract_from_create_mask(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> {
+  %c2 = arith.constant 2 : index
+  %mask = vector.create_mask %c2, %dim0, %dim1 : vector<4x[4]x[4]xi1>
+  // CHECK: vector.create_mask %[[DIM0]], %[[DIM1]] : vector<[4]x[4]xi1>
+  // CHECK-NOT: vector.extract
+  %extract = vector.extract %mask[1] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
+  return %extract : vector<[4]x[4]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_create_mask_all_false
+func.func @extract_from_create_mask_all_false(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> {
+  %c2 = arith.constant 2 : index
+  %mask = vector.create_mask %c2, %dim0, %dim1 : vector<4x[4]x[4]xi1>
+  // CHECK: arith.constant dense<false> : vector<[4]x[4]xi1>
+  // CHECK-NOT: vector.extract
+  %extract = vector.extract %mask[2] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
+  return %extract : vector<[4]x[4]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_create_mask_leading_scalable
+//  CHECK-SAME: %[[DIM0:.*]]: index
+func.func @extract_from_create_mask_leading_scalable(%dim0: index) -> vector<8xi1> {
+  %c3 = arith.constant 3 : index
+  %mask = vector.create_mask %c3, %dim0 : vector<[4]x8xi1>
+  // CHECK: vector.create_mask %[[DIM0]] : vector<8xi1>
+  // CHECK-NOT: vector.extract
+  %extract = vector.extract %mask[1] : vector<8xi1> from vector<[4]x8xi1>
+  return %extract : vector<8xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_create_mask_dynamic_position
+//  CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
+func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index) -> vector<6xi1> {
+  %c4 = arith.constant 4 : index
+  %c3 = arith.constant 3 : index
+  %mask = vector.create_mask %c3, %c4, %dim0 : vector<4x4x6xi1>
+  // CHECK: vector.create_mask %[[DIM0]] : vector<6xi1>
+  // CHECK-NOT: vector.extract
+  %extract = vector.extract %mask[2, %index] : vector<6xi1> from vector<4x4x6xi1>
+  return %extract : vector<6xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_create_mask_dynamic_position_all_false
+//  CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
+func.func @extract_from_create_mask_dynamic_position_all_false(%dim0: index, %index: index) -> vector<6xi1> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %mask = vector.create_mask %c1, %c0, %dim0 : vector<1x4x6xi1>
+  // CHECK: arith.constant dense<false> : vector<6xi1>
+  // CHECK-NOT: vector.extract
+  %extract = vector.extract %mask[0, %index] : vector<6xi1> from vector<1x4x6xi1>
+  return %extract : vector<6xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_create_mask_dynamic_position_unknown
+//  CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
+func.func @extract_from_create_mask_dynamic_position_unknown(%dim0: index, %index: index) -> vector<6xi1> {
+  %c2 = arith.constant 2 : index
+  %mask = vector.create_mask %c2, %dim0 : vector<4x6xi1>
+  // CHECK: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C2]], %[[DIM0]] : vector<4x6xi1>
+  // CHECK-NEXT: vector.extract %[[MASK]][%[[INDEX]]] : vector<6xi1> from vector<4x6xi1>
+  %extract = vector.extract %mask[%index] : vector<6xi1> from vector<4x6xi1>
+  return %extract : vector<6xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_create_mask_mixed_position_unknown
+//  CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
+func.func @extract_from_create_mask_mixed_position_unknown(%dim0: index, %index: index) -> vector<4xi1> {
+  %c2 = arith.constant 2 : index
+  %mask = vector.create_mask %c2, %c2, %dim0 : vector<2x4x4xi1>
+  // CHECK: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C2]], %[[C2]], %[[DIM0]] : vector<2x4x4xi1>
+  // CHECK-NEXT: vector.extract %[[MASK]][1, %[[INDEX]]] : vector<4xi1> from vector<2x4x4xi1>
+  %extract = vector.extract %mask[1, %index] : vector<4xi1> from vector<2x4x4xi1>
+  return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_non_constant_create_mask
+//  CHECK-SAME: %[[DIM0:.*]]: index
+func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1> {
+  %mask = vector.create_mask %dim0, %dim0 : vector<[2]x[2]xi1>
+  // CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM0]] : vector<[2]x[2]xi1>
+  // CHECK-NEXT: vector.extract %[[MASK]][0] : vector<[2]xi1> from vector<[2]x[2]xi1>
+  %extract = vector.extract %mask[0] : vector<[2]xi1> from vector<[2]x[2]xi1>
+  return %extract : vector<[2]xi1>
+}
+
+// -----
+
 // CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
 func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
   //     CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>



More information about the Mlir-commits mailing list