[Mlir-commits] [mlir] Revert "[mlir][vector] Add a pattern to fuse extract(constant_mask) (#81057)" (PR #83275)

Quinn Dawkins llvmlistbot at llvm.org
Wed Feb 28 07:25:44 PST 2024


https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/83275

This reverts commit 5cdb8c0c8854d08ac7ca131ce3e8d78a32eb6716.

This pattern is producing incorrect IR. For example,

```mlir
func.func @extract_subvector_from_constant_mask() -> vector<16xi1> {
  %mask = vector.constant_mask [2, 3] : vector<16x16xi1>
  %extract = vector.extract %mask[8] : vector<16xi1> from vector<16x16xi1>
  return %extract : vector<16xi1>
}
```

Canonicalizes to

```mlir
func.func @extract_subvector_from_constant_mask() -> vector<16xi1> {
  %0 = vector.constant_mask [3] : vector<16xi1>
  return %0 : vector<16xi1>
}
```

The result should instead be an all-zero mask, because the extraction index (8) lies outside the constant mask's size along that dimension (2).

From 7a88c33c6347446725a266304795a15d3e292a10 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Wed, 28 Feb 2024 10:22:55 -0500
Subject: [PATCH] Revert "[mlir][vector] Add a pattern to fuse
 extract(constant_mask) (#81057)"

This reverts commit 5cdb8c0c8854d08ac7ca131ce3e8d78a32eb6716.

This pattern is producing incorrect IR. For example,

```mlir
func.func @extract_subvector_from_constant_mask() -> vector<16xi1> {
  %mask = vector.constant_mask [2, 3] : vector<16x16xi1>
  %extract = vector.extract %mask[8] : vector<16xi1> from vector<16x16xi1>
  return %extract : vector<16xi1>
}
```

Canonicalizes to

```mlir
func.func @extract_subvector_from_constant_mask() -> vector<16xi1> {
  %0 = vector.constant_mask [3] : vector<16xi1>
  return %0 : vector<16xi1>
}
```

Where it should be a zero mask because the extraction index (8) is
greater than the constant mask size along that dim (2).
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 74 +---------------------
 mlir/test/Dialect/Vector/canonicalize.mlir | 73 ---------------------
 2 files changed, 1 insertion(+), 146 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8c341a347ff6628..5be6a628904cdf7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2040,77 +2040,6 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
   }
 };
 
-// Patterns to rewrite ExtractOp(ConstantMaskOp)
-//
-// When the result of ExtractOp is a subvector of input, we can rewrite it as
-// a ConstantMaskOp with subvector ranks.
-//
-// ExtractOp(ConstantMaskOp) -> ConstantMaskOp
-//
-// When the result of ExtractOp is a scalar, we can get the scalar value
-// directly.
-//
-// ExtractOp(ConstantMaskOp) -> ConstantOp
-class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractOp extractOp,
-                                PatternRewriter &rewriter) const override {
-    auto constantMaskOp =
-        extractOp.getVector().getDefiningOp<vector::ConstantMaskOp>();
-    if (!constantMaskOp)
-      return failure();
-
-    // All indices must be static.
-    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
-    unsigned dynamicPosCount =
-        llvm::count_if(extractOpPos, ShapedType::isDynamic);
-    // If there is any dynamic position in ExtractOp, we cannot determine the
-    // scalar value.
-    if (dynamicPosCount)
-      return failure();
-
-    ArrayRef<Attribute> maskDimSizes =
-        constantMaskOp.getMaskDimSizes().getValue();
-    Type resultTy = extractOp.getResult().getType();
-    if (resultTy.isa<mlir::VectorType>()) {
-      auto resultVectorTy = resultTy.cast<mlir::VectorType>();
-      int64_t resultRank = resultVectorTy.getRank();
-      int64_t n = maskDimSizes.size();
-      std::vector<int64_t> indices;
-      for (auto i = n - resultRank; i < n; ++i)
-        indices.push_back(cast<IntegerAttr>(maskDimSizes[i]).getInt());
-
-      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
-          extractOp, resultVectorTy,
-          vector::getVectorSubscriptAttr(rewriter, indices));
-
-      return success();
-    } else if (resultTy.isa<mlir::IntegerType>()) {
-      // ConstantMaskOp creates and returns a vector mask where elements of the
-      // result vector are set to ‘0’ or ‘1’, based on whether the element
-      // indices are contained within a hyper-rectangular region.
-      // We go through ExtractOp static positions to determine the position is
-      // within the hyper-rectangular region or not.
-      Type boolType = rewriter.getI1Type();
-      IntegerAttr setAttr = IntegerAttr::get(boolType, 1);
-      for (size_t i = 0, end = extractOpPos.size(); i < end; ++i) {
-        if (cast<IntegerAttr>(maskDimSizes[i]).getInt() <= extractOpPos[i]) {
-          setAttr = IntegerAttr::get(boolType, 0);
-          break;
-        }
-      }
-
-      rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, boolType,
-                                                     setAttr);
-      return success();
-    }
-
-    return failure();
-  }
-};
-
 // Folds extract(shape_cast(..)) into shape_cast when the total element count
 // does not change.
 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
@@ -2137,8 +2066,7 @@ LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
   results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
-              ExtractOpFromBroadcast, ExtractOpFromCreateMask,
-              ExtractOpFromConstantMask>(context);
+              ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
 }
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index dc486181ebe9605..e6f045e12e51973 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2567,76 +2567,3 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
       tensor<4x4x4xf32>, vector<1x100x4x5xf32>
   return %r : vector<1x100x4x5xf32>
 }
-
-// -----
-
-// CHECK-LABEL: func.func @extract_true_from_constant_mask() -> i1 {
-func.func @extract_true_from_constant_mask() -> i1 {
-// CHECK:      %[[TRUE:.*]] = arith.constant true
-// CHECK-NEXT: return %[[TRUE]] : i1
-  %mask = vector.constant_mask [2, 2, 3] : vector<4x4x4xi1>
-  %extract = vector.extract %mask[1, 1, 2] : i1 from vector<4x4x4xi1>
-  return %extract : i1
-}
-
-// -----
-
-// CHECK-LABEL: func.func @extract_false_from_constant_mask() -> i1 {
-func.func @extract_false_from_constant_mask() -> i1 {
-// CHECK:      %[[FALSE:.*]] = arith.constant false
-// CHECK-NEXT: return %[[FALSE]] : i1
-  %mask = vector.constant_mask [2, 2, 3] : vector<4x4x4xi1>
-  %extract = vector.extract %mask[1, 2, 2] : i1 from vector<4x4x4xi1>
-  return %extract : i1
-}
-
-// -----
-
-// CHECK-LABEL: func.func @extract_from_create_mask() -> i1 {
-func.func @extract_from_create_mask() -> i1 {
-// CHECK:      %[[TRUE:.*]] = arith.constant true
-// CHECK-NEXT: return %[[TRUE]] : i1
-  %c2 = arith.constant 2 : index
-  %c3 = arith.constant 3 : index
-  %mask = vector.create_mask %c2, %c2, %c3 : vector<4x4x4xi1>
-  %extract = vector.extract %mask[1, 1, 2] : i1 from vector<4x4x4xi1>
-  return %extract : i1
-}
-
-// -----
-
-// CHECK-LABEL: func.func @extract_subvector_from_constant_mask() ->
-// CHECK-SAME:  vector<6xi1> {
-func.func @extract_subvector_from_constant_mask() -> vector<6xi1> {
-// CHECK:      %[[S0:.*]] = vector.constant_mask [4] : vector<6xi1>
-// CHECK-NEXT: return %[[S0]] : vector<6xi1>
-  %mask = vector.constant_mask [2, 3, 4] : vector<4x5x6xi1>
-  %extract = vector.extract %mask[1, 2] : vector<6xi1> from vector<4x5x6xi1>
-  return %extract : vector<6xi1>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @extract_scalar_with_dynamic_positions(
-// CHECK-SAME:    %[[INDEX:.*]]: index) -> i1 {
-func.func @extract_scalar_with_dynamic_positions(%index: index) -> i1 {
-// CHECK:       %[[S0:.*]] = vector.constant_mask [2, 2, 3] : vector<4x4x4xi1>
-// CHECK-NEXT:  %[[S1:.*]] = vector.extract %[[S0]][1, 1, %[[INDEX]]] : i1 from vector<4x4x4xi1>
-// CHECK-NEXT:  return %[[S1]] : i1
-  %mask = vector.constant_mask [2, 2, 3] : vector<4x4x4xi1>
-  %extract = vector.extract %mask[1, 1, %index] : i1 from vector<4x4x4xi1>
-  return %extract : i1
-}
-
-// -----
-
-// CHECK-LABEL: func.func @extract_subvector_with_dynamic_positions
-// CHECK-SAME:    %[[INDEX:.*]]: index) -> vector<6xi1> {
-func.func @extract_subvector_with_dynamic_positions(%index: index) -> vector<6xi1> {
-// CHECK:      %[[S0:.*]] = vector.constant_mask [2, 3, 4] : vector<4x5x6xi1>
-// CHECK-NEXT: %[[S1:.*]] = vector.extract %[[S0]][1, %[[INDEX]]] : vector<6xi1> from vector<4x5x6xi1>
-// CHECK-NEXT: return %[[S1]] : vector<6xi1>
-  %mask = vector.constant_mask [2, 3, 4] : vector<4x5x6xi1>
-  %extract = vector.extract %mask[1, %index] : vector<6xi1> from vector<4x5x6xi1>
-  return %extract : vector<6xi1>
-}



More information about the Mlir-commits mailing list