[Mlir-commits] [mlir] [mlir][vector] Add mask elimination transform (PR #99314)

Benjamin Maxwell llvmlistbot at llvm.org
Tue Aug 6 10:20:29 PDT 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/99314

>From 8a861c47f96c05a8d93d72d05eb2986104098f16 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 24 Jun 2024 14:42:18 +0000
Subject: [PATCH 1/6] [mlir][vector] Add mask elimination transform

This adds a new transform `eliminateVectorMasks()` which aims at
removing scalable `vector.create_mask` ops that will be all-true at
runtime. It first attempts to do this by simply pattern-matching the
mask operands (similar to some canonicalizations). If that does not lead
to an answer (is the mask all-true? yes/no), then value bounds analysis
is used to find the lower bound of the unknown operands. If the lower
bound is >= the corresponding mask vector type dim, then that dimension
of the mask is all-true.

Note: Eliminating `vector.create_mask` ops here means replacing them
with all-true constants (which then allows the masks to fold away).
---
 .../Vector/Transforms/VectorTransforms.h      |  17 +++
 .../Dialect/Vector/Transforms/CMakeLists.txt  |   1 +
 .../Transforms/VectorMaskElimination.cpp      | 117 +++++++++++++++
 mlir/test/Dialect/Vector/eliminate-masks.mlir | 138 ++++++++++++++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   |  34 +++++
 5 files changed, 307 insertions(+)
 create mode 100644 mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
 create mode 100644 mlir/test/Dialect/Vector/eliminate-masks.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
index 1f7d6411cd5a46..847f333d6a9310 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 
 namespace mlir {
 class MLIRContext;
@@ -115,6 +116,22 @@ castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
                                  MaskingOpInterface maskingOp,
                                  RewriterBase &rewriter);
 
+/// Structure to hold the range [vscaleMin, vscaleMax] `vector.vscale` can take.
+struct VscaleRange {
+  unsigned vscaleMin;
+  unsigned vscaleMax;
+};
+
+/// Attempts to eliminate redundant vector masks by replacing them with all-true
+/// constants at the top of the function (which results in the masks folding
+/// away). Note: Currently, this only runs for vector.create_mask ops and
+/// requires `vscaleRange`. If `vscaleRange` is not provided this transform does
+/// nothing. This is because these redundant masks are much more likely for
+/// scalable code which requires memref/tensor dynamic sizes, whereas fixed-size
+/// code has static sizes, so simpler folds remove the masks.
+void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
+                          std::optional<VscaleRange> vscaleRange = {});
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 723b2f62d65d4f..2639a67e1c8b31 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   VectorTransferSplitRewritePatterns.cpp
   VectorTransforms.cpp
   VectorUnroll.cpp
+  VectorMaskElimination.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Transforms
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
new file mode 100644
index 00000000000000..abec8c75b8fc91
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
@@ -0,0 +1,117 @@
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+namespace {
+
+/// If `value` is a constant multiple of `vector.vscale` return the multiplier.
+std::optional<int64_t> getConstantVscaleMultiplier(Value value) {
+  if (value.getDefiningOp<vector::VectorScaleOp>())
+    return 1;
+  auto mul = value.getDefiningOp<arith::MulIOp>();
+  if (!mul)
+    return {};
+  auto lhs = mul.getLhs();
+  auto rhs = mul.getRhs();
+  if (lhs.getDefiningOp<vector::VectorScaleOp>())
+    return getConstantIntValue(rhs);
+  if (rhs.getDefiningOp<vector::VectorScaleOp>())
+    return getConstantIntValue(lhs);
+  return {};
+}
+
+/// Attempts to resolve a (scalable) CreateMaskOp to an all-true constant mask.
+/// All-true masks can then be eliminated by simple folds.
+LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,
+                                         vector::CreateMaskOp createMaskOp,
+                                         VscaleRange vscaleRange) {
+  auto maskType = createMaskOp.getVectorType();
+  auto maskTypeDimScalableFlags = maskType.getScalableDims();
+  auto maskTypeDimSizes = maskType.getShape();
+
+  struct UnknownMaskDim {
+    size_t position;
+    Value dimSize;
+  };
+
+  // Check for any dims that could be (partially) false before doing the more
+  // expensive value bounds computations.
+  SmallVector<UnknownMaskDim> unknownDims;
+  for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
+    if (auto intSize = getConstantIntValue(dimSize)) {
+      // Mask not all-true for this dim.
+      if (maskTypeDimScalableFlags[i] || intSize < maskTypeDimSizes[i])
+        return failure();
+    } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
+      // Mask not all-true for this dim.
+      if (vscaleMultiplier < maskTypeDimSizes[i])
+        return failure();
+    } else {
+      // Unknown (without further analysis).
+      unknownDims.push_back(UnknownMaskDim{i, dimSize});
+    }
+  }
+
+  for (auto [i, dimSize] : unknownDims) {
+    // Compute the lower bound for the unknown dimension (i.e. the smallest
+    // value it could be).
+    auto lowerBound =
+        vector::ScalableValueBoundsConstraintSet::computeScalableBound(
+            dimSize, {}, vscaleRange.vscaleMin, vscaleRange.vscaleMax,
+            presburger::BoundType::LB);
+    if (failed(lowerBound))
+      return failure();
+    auto boundSize = lowerBound->getSize();
+    if (failed(boundSize))
+      return failure();
+    if (boundSize->scalable) {
+      // If the lower bound is scalable and >= to the mask dim size then this
+      // dim is all-true.
+      if (boundSize->baseSize < maskTypeDimSizes[i])
+        return failure();
+    } else {
+      // If the lower bound is a constant and >= to the _fixed-size_ mask dim
+      // size then this dim is all-true.
+      if (maskTypeDimScalableFlags[i])
+        return failure();
+      if (boundSize->baseSize < maskTypeDimSizes[i])
+        return failure();
+    }
+  }
+
+  // Replace createMaskOp with an all-true constant. This should result in the
+  // mask being removed in most cases (as xfer ops + vector.mask have folds to
+  // remove all-true masks).
+  auto allTrue = rewriter.create<arith::ConstantOp>(
+      createMaskOp.getLoc(), maskType, DenseElementsAttr::get(maskType, true));
+  rewriter.replaceAllUsesWith(createMaskOp, allTrue);
+  return success();
+}
+
+} // namespace
+
+namespace mlir::vector {
+
+void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
+                          std::optional<VscaleRange> vscaleRange) {
+  // TODO: Support fixed-size case. This is less likely to be useful as for
+  // fixed-size code dimensions are all static so masks tend to fold away.
+  if (!vscaleRange)
+    return;
+
+  OpBuilder::InsertionGuard g(rewriter);
+  SmallVector<vector::CreateMaskOp> worklist;
+  function.walk([&](vector::CreateMaskOp createMaskOp) {
+    worklist.push_back(createMaskOp);
+  });
+  rewriter.setInsertionPointToStart(&function.front());
+  for (auto mask : worklist)
+    (void)resolveAllTrueCreateMaskOp(rewriter, mask, *vscaleRange);
+}
+
+} // namespace mlir::vector
diff --git a/mlir/test/Dialect/Vector/eliminate-masks.mlir b/mlir/test/Dialect/Vector/eliminate-masks.mlir
new file mode 100644
index 00000000000000..99c9a60a09facb
--- /dev/null
+++ b/mlir/test/Dialect/Vector/eliminate-masks.mlir
@@ -0,0 +1,138 @@
+// RUN: mlir-opt %s -split-input-file -test-eliminate-vector-masks  | FileCheck %s
+
+// This tests a general pattern the vectorizer tends to emit.
+
+// CHECK-LABEL: @eliminate_redundant_masks_through_insert_and_extracts
+// CHECK: %[[ALL_TRUE_MASK:.*]] = arith.constant dense<true> : vector<[4]xi1>
+// CHECK: vector.transfer_read {{.*}} %[[ALL_TRUE_MASK]]
+// CHECK: vector.transfer_write {{.*}} %[[ALL_TRUE_MASK]]
+func.func @eliminate_redundant_masks_through_insert_and_extracts(%tensor: tensor<1x1000xf32>) {
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c1000 = arith.constant 1000 : index
+  %c0_f32 = arith.constant 0.0 : f32
+  %vscale = vector.vscale
+  %c4_vscale = arith.muli %vscale, %c4 : index
+  %extracted_slice_0 = tensor.extract_slice %tensor[0, 0] [1, %c4_vscale] [1, 1] : tensor<1x1000xf32> to tensor<1x?xf32>
+  %output_tensor = scf.for %i = %c0 to %c1000 step %c4_vscale iter_args(%arg = %extracted_slice_0) -> tensor<1x?xf32> {
+    // 1. Extract a slice.
+    %extracted_slice_1 = tensor.extract_slice %arg[0, 0] [1, %c4_vscale] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
+
+    // 2. Create a mask for the slice.
+    %dim_1 = tensor.dim %extracted_slice_1, %c0 : tensor<?xf32>
+    %mask = vector.create_mask %dim_1 : vector<[4]xi1>
+
+    // 3. Read the slice and do some computation.
+    %vec = vector.transfer_read %extracted_slice_1[%c0], %c0_f32, %mask {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32>
+    %new_vec = "test.some_computation"(%vec) : (vector<[4]xf32>) -> (vector<[4]xf32>)
+
+    // 4. Write the new value.
+    %write = vector.transfer_write %new_vec, %extracted_slice_1[%c0], %mask {in_bounds = [true]} : vector<[4]xf32>, tensor<?xf32>
+
+    // 5. Insert and yield the new tensor value.
+    %result = tensor.insert_slice %write into %arg[0, 0] [1, %c4_vscale] [1, 1] : tensor<?xf32> into tensor<1x?xf32>
+    scf.yield %result : tensor<1x?xf32>
+  }
+  "test.some_use"(%output_tensor) : (tensor<1x?xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @negative_extract_slice_size_shrink
+// CHECK-NOT: arith.constant dense<true> : vector<[4]xi1>
+// CHECK: %[[MASK:.*]] = vector.create_mask
+// CHECK: "test.some_use"(%[[MASK]]) : (vector<[4]xi1>) -> ()
+func.func @negative_extract_slice_size_shrink(%tensor: tensor<1000xf32>) {
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c1000 = arith.constant 1000 : index
+  %vscale = vector.vscale
+  %c4_vscale = arith.muli %vscale, %c4 : index
+  %extracted_slice = tensor.extract_slice %tensor[0] [%c4_vscale] [1] : tensor<1000xf32> to tensor<?xf32>
+  %slice = scf.for %i = %c0 to %c1000 step %c4_vscale iter_args(%arg = %extracted_slice) -> tensor<?xf32> {
+    // This mask cannot be eliminated even though looking at the above operations
+    // it appears `tensor.dim` will always be c4_vscale (so the mask all-true).
+    %dim = tensor.dim %arg, %c0 : tensor<?xf32>
+    %mask = vector.create_mask %dim : vector<[4]xi1>
+    "test.some_use"(%mask) : (vector<[4]xi1>) -> ()
+    // !!! Here the size of the mask could shrink in the next iteration.
+    %next_num_els = affine.min  affine_map<(d0)[s0] -> (-d0 + 1000, s0)>(%i)[%c4_vscale]
+    %new_extracted_slice = tensor.extract_slice %tensor[%c4_vscale] [%next_num_els] [1] : tensor<1000xf32> to tensor<?xf32>
+    scf.yield %new_extracted_slice : tensor<?xf32>
+  }
+  "test.some_use"(%slice) : (tensor<?xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @negative_constant_dim_not_all_true
+// CHECK-NOT: arith.constant dense<true> : vector<2x[4]xi1>
+// CHECK: %[[MASK:.*]] = vector.create_mask
+// CHECK: "test.some_use"(%[[MASK]]) : (vector<2x[4]xi1>) -> ()
+func.func @negative_constant_dim_not_all_true()
+{
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %vscale = vector.vscale
+  %c4_vscale = arith.muli %vscale, %c4 : index
+  %mask = vector.create_mask %c1, %c4_vscale : vector<2x[4]xi1>
+  "test.some_use"(%mask) : (vector<2x[4]xi1>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @negative_constant_vscale_multiple_not_all_true
+// CHECK-NOT: arith.constant dense<true> : vector<2x[4]xi1>
+// CHECK: %[[MASK:.*]] = vector.create_mask
+// CHECK: "test.some_use"(%[[MASK]]) : (vector<2x[4]xi1>) -> ()
+func.func @negative_constant_vscale_multiple_not_all_true() {
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %vscale = vector.vscale
+  %c3_vscale = arith.muli %vscale, %c3 : index
+  %mask = vector.create_mask %c2, %c3_vscale : vector<2x[4]xi1>
+  "test.some_use"(%mask) : (vector<2x[4]xi1>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @negative_value_bounds_fixed_dim_not_all_true
+// CHECK-NOT: arith.constant dense<true> : vector<3x[4]xi1>
+// CHECK: %[[MASK:.*]] = vector.create_mask
+// CHECK: "test.some_use"(%[[MASK]]) : (vector<3x[4]xi1>) -> ()
+func.func @negative_value_bounds_fixed_dim_not_all_true(%tensor: tensor<2x?xf32>)
+{
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %vscale = vector.vscale
+  %c4_vscale = arith.muli %vscale, %c4 : index
+  // This is _very_ simple but since addi is not a constant value bounds will
+  // be used to resolve it.
+  %dim = tensor.dim %tensor, %c0 : tensor<2x?xf32>
+  %mask = vector.create_mask %dim, %c4_vscale : vector<3x[4]xi1>
+  "test.some_use"(%mask) : (vector<3x[4]xi1>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @negative_value_bounds_scalable_dim_not_all_true
+// CHECK-NOT: arith.constant dense<true> : vector<3x[4]xi1>
+// CHECK: %[[MASK:.*]] = vector.create_mask
+// CHECK: "test.some_use"(%[[MASK]]) : (vector<3x[4]xi1>) -> ()
+func.func @negative_value_bounds_scalable_dim_not_all_true(%tensor: tensor<2x100xf32>) {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %vscale = vector.vscale
+  %c3_vscale = arith.muli %vscale, %c3 : index
+  %slice = tensor.extract_slice %tensor[0, 0] [2, %c3_vscale] [1, 1] : tensor<2x100xf32> to tensor<2x?xf32>
+  // Another simple example, but value bounds will be used to resolve the tensor.dim.
+  %dim = tensor.dim %slice, %c1 : tensor<2x?xf32>
+  %mask = vector.create_mask %c3, %dim : vector<3x[4]xi1>
+  "test.some_use"(%mask) : (vector<3x[4]xi1>) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 592e24af94d677..e7c05b1b6edeee 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -873,6 +873,38 @@ struct TestVectorLinearize final
       return signalPassFailure();
   }
 };
+
+struct TestEliminateVectorMasks
+    : public PassWrapper<TestEliminateVectorMasks,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEliminateVectorMasks)
+
+  TestEliminateVectorMasks() = default;
+  TestEliminateVectorMasks(const TestEliminateVectorMasks &pass)
+      : PassWrapper(pass) {}
+
+  Option<unsigned> vscaleMin{
+      *this, "vscale-min",
+      llvm::cl::desc(
+          "Minimum value `vector.vscale` can possibly be at runtime."),
+      llvm::cl::init(1)};
+
+  Option<unsigned> vscaleMax{
+      *this, "vscale-max",
+      llvm::cl::desc(
+          "Maximum value `vector.vscale` can possibly be at runtime."),
+      llvm::cl::init(16)};
+
+  StringRef getArgument() const final { return "test-eliminate-vector-masks"; }
+  StringRef getDescription() const final {
+    return "Test eliminating vector masks";
+  }
+  void runOnOperation() override {
+    IRRewriter rewriter(&getContext());
+    eliminateVectorMasks(rewriter, getOperation(),
+                         VscaleRange{vscaleMin, vscaleMax});
+  }
+};
 } // namespace
 
 namespace mlir {
@@ -919,6 +951,8 @@ void registerTestVectorLowerings() {
   PassRegistration<TestVectorEmulateMaskedLoadStore>();
 
   PassRegistration<TestVectorLinearize>();
+
+  PassRegistration<TestEliminateVectorMasks>();
 }
 } // namespace test
 } // namespace mlir

>From 6e1457310325f58276d11e17bac670609d56856d Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 19 Jul 2024 09:59:00 +0000
Subject: [PATCH 2/6] Fixups

---
 .../Dialect/Vector/Transforms/VectorTransforms.h     |  2 +-
 .../Vector/Transforms/VectorMaskElimination.cpp      | 12 ++++++++++++
 mlir/test/Dialect/Vector/eliminate-masks.mlir        | 12 ++++++------
 .../test/lib/Dialect/Vector/TestVectorTransforms.cpp |  9 ++-------
 4 files changed, 21 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
index 847f333d6a9310..e815e026305fab 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
@@ -116,7 +116,7 @@ castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
                                  MaskingOpInterface maskingOp,
                                  RewriterBase &rewriter);
 
-/// Structure to hold the range [vscaleMin, vscaleMax] `vector.vscale` can take.
+/// Structure to hold the range of `vector.vscale`.
 struct VscaleRange {
   unsigned vscaleMin;
   unsigned vscaleMax;
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
index abec8c75b8fc91..486784a9cf102b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
@@ -1,3 +1,11 @@
+//===- VectorMaskElimination.cpp - Eliminate Vector Masks -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
@@ -105,10 +113,19 @@ void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function,
     return;
 
   OpBuilder::InsertionGuard g(rewriter);
+
+  // Build worklist so we can safely insert new ops in
+  // `resolveAllTrueCreateMaskOp()`.
   SmallVector<vector::CreateMaskOp> worklist;
   function.walk([&](vector::CreateMaskOp createMaskOp) {
     worklist.push_back(createMaskOp);
   });
+
+  // Nothing to rewrite. Returning here also avoids calling `function.front()`
+  // on an external function declaration, whose body region has no blocks.
+  if (worklist.empty())
+    return;
+
   rewriter.setInsertionPointToStart(&function.front());
   for (auto mask : worklist)
     (void)resolveAllTrueCreateMaskOp(rewriter, mask, *vscaleRange);
diff --git a/mlir/test/Dialect/Vector/eliminate-masks.mlir b/mlir/test/Dialect/Vector/eliminate-masks.mlir
index 99c9a60a09facb..51564d0ca1d4f1 100644
--- a/mlir/test/Dialect/Vector/eliminate-masks.mlir
+++ b/mlir/test/Dialect/Vector/eliminate-masks.mlir
@@ -16,7 +16,7 @@ func.func @eliminate_redundant_masks_through_insert_and_extracts(%tensor: tensor
   %extracted_slice_0 = tensor.extract_slice %tensor[0, 0] [1, %c4_vscale] [1, 1] : tensor<1x1000xf32> to tensor<1x?xf32>
   %output_tensor = scf.for %i = %c0 to %c1000 step %c4_vscale iter_args(%arg = %extracted_slice_0) -> tensor<1x?xf32> {
     // 1. Extract a slice.
-    %extracted_slice_1 = tensor.extract_slice %arg[0, 0] [1, %c4_vscale] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
+    %extracted_slice_1 = tensor.extract_slice %arg[0, %i] [1, %c4_vscale] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
 
     // 2. Create a mask for the slice.
     %dim_1 = tensor.dim %extracted_slice_1, %c0 : tensor<?xf32>
@@ -30,7 +30,7 @@ func.func @eliminate_redundant_masks_through_insert_and_extracts(%tensor: tensor
     %write = vector.transfer_write %new_vec, %extracted_slice_1[%c0], %mask {in_bounds = [true]} : vector<[4]xf32>, tensor<?xf32>
 
     // 5. Insert and yield the new tensor value.
-    %result = tensor.insert_slice %write into %arg[0, 0] [1, %c4_vscale] [1, 1] : tensor<?xf32> into tensor<1x?xf32>
+    %result = tensor.insert_slice %write into %arg[0, %i] [1, %c4_vscale] [1, 1] : tensor<?xf32> into tensor<1x?xf32>
     scf.yield %result : tensor<1x?xf32>
   }
   "test.some_use"(%output_tensor) : (tensor<1x?xf32>) -> ()
@@ -57,8 +57,8 @@ func.func @negative_extract_slice_size_shrink(%tensor: tensor<1000xf32>) {
     %mask = vector.create_mask %dim : vector<[4]xi1>
     "test.some_use"(%mask) : (vector<[4]xi1>) -> ()
     // !!! Here the size of the mask could shrink in the next iteration.
-    %next_num_els = affine.min  affine_map<(d0)[s0] -> (-d0 + 1000, s0)>(%i)[%c4_vscale]
-    %new_extracted_slice = tensor.extract_slice %tensor[%c4_vscale] [%next_num_els] [1] : tensor<1000xf32> to tensor<?xf32>
+    %next_num_elts = affine.min  affine_map<(d0)[s0] -> (-d0 + 1000, s0)>(%i)[%c4_vscale]
+    %new_extracted_slice = tensor.extract_slice %tensor[%c4_vscale] [%next_num_elts] [1] : tensor<1000xf32> to tensor<?xf32>
     scf.yield %new_extracted_slice : tensor<?xf32>
   }
   "test.some_use"(%slice) : (tensor<?xf32>) -> ()
@@ -110,8 +110,8 @@ func.func @negative_value_bounds_fixed_dim_not_all_true(%tensor: tensor<2x?xf32>
   %c4 = arith.constant 4 : index
   %vscale = vector.vscale
   %c4_vscale = arith.muli %vscale, %c4 : index
-  // This is _very_ simple but since addi is not a constant value bounds will
-  // be used to resolve it.
+  // This is _very_ simple but since tensor.dim is not a constant value bounds
+  // will be used to resolve it.
   %dim = tensor.dim %tensor, %c0 : tensor<2x?xf32>
   %mask = vector.create_mask %dim, %c4_vscale : vector<3x[4]xi1>
   "test.some_use"(%mask) : (vector<3x[4]xi1>) -> ()
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index e7c05b1b6edeee..29c763b622e877 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -884,15 +884,10 @@ struct TestEliminateVectorMasks
       : PassWrapper(pass) {}
 
   Option<unsigned> vscaleMin{
-      *this, "vscale-min",
-      llvm::cl::desc(
-          "Minimum value `vector.vscale` can possibly be at runtime."),
+      *this, "vscale-min", llvm::cl::desc("Minimum possible value of vscale."),
       llvm::cl::init(1)};
-
   Option<unsigned> vscaleMax{
-      *this, "vscale-max",
-      llvm::cl::desc(
-          "Maximum value `vector.vscale` can possibly be at runtime."),
+      *this, "vscale-max", llvm::cl::desc("Maximum possible value of vscale."),
       llvm::cl::init(16)};
 
   StringRef getArgument() const final { return "test-eliminate-vector-masks"; }

>From 8d0945123325012f615e40012df2adb8586ff480 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 22 Jul 2024 12:56:46 +0000
Subject: [PATCH 3/6] Comments

---
 mlir/test/Dialect/Vector/eliminate-masks.mlir | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/Vector/eliminate-masks.mlir b/mlir/test/Dialect/Vector/eliminate-masks.mlir
index 51564d0ca1d4f1..ff859081fe5c80 100644
--- a/mlir/test/Dialect/Vector/eliminate-masks.mlir
+++ b/mlir/test/Dialect/Vector/eliminate-masks.mlir
@@ -51,8 +51,8 @@ func.func @negative_extract_slice_size_shrink(%tensor: tensor<1000xf32>) {
   %c4_vscale = arith.muli %vscale, %c4 : index
   %extracted_slice = tensor.extract_slice %tensor[0] [%c4_vscale] [1] : tensor<1000xf32> to tensor<?xf32>
   %slice = scf.for %i = %c0 to %c1000 step %c4_vscale iter_args(%arg = %extracted_slice) -> tensor<?xf32> {
-    // This mask cannot be eliminated even though looking at the above operations
-    // it appears `tensor.dim` will always be c4_vscale (so the mask all-true).
+    // This mask cannot be eliminated even though looking at the operations above
+    // (this comment) it appears `tensor.dim` will always be c4_vscale (so the mask all-true).
     %dim = tensor.dim %arg, %c0 : tensor<?xf32>
     %mask = vector.create_mask %dim : vector<[4]xi1>
     "test.some_use"(%mask) : (vector<[4]xi1>) -> ()
@@ -77,6 +77,8 @@ func.func @negative_constant_dim_not_all_true()
   %c4 = arith.constant 4 : index
   %vscale = vector.vscale
   %c4_vscale = arith.muli %vscale, %c4 : index
+  // Since %c1 is a constant, this will be found not to be all-true via simple
+  // pattern matching.
   %mask = vector.create_mask %c1, %c4_vscale : vector<2x[4]xi1>
   "test.some_use"(%mask) : (vector<2x[4]xi1>) -> ()
   return
@@ -93,6 +95,8 @@ func.func @negative_constant_vscale_multiple_not_all_true() {
   %c3 = arith.constant 3 : index
   %vscale = vector.vscale
   %c3_vscale = arith.muli %vscale, %c3 : index
+  // Since %c3_vscale is a constant vscale multiple, this will be found not to
+  // be all-true via simple pattern matching.
   %mask = vector.create_mask %c2, %c3_vscale : vector<2x[4]xi1>
   "test.some_use"(%mask) : (vector<2x[4]xi1>) -> ()
   return

>From f6acf2d9b5f0ca5a341fc1b77d01bd7b5cb86133 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 23 Jul 2024 13:51:54 +0000
Subject: [PATCH 4/6] Review fixups

---
 .../Transforms/VectorMaskElimination.cpp      | 24 ++++++++++---------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
index 486784a9cf102b..9ad0de5cadeaee 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
@@ -68,26 +68,28 @@ LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,
   for (auto [i, dimSize] : unknownDims) {
     // Compute the lower bound for the unknown dimension (i.e. the smallest
     // value it could be).
-    auto lowerBound =
+    FailureOr<ConstantOrScalableBound> dimLowerBound =
         vector::ScalableValueBoundsConstraintSet::computeScalableBound(
             dimSize, {}, vscaleRange.vscaleMin, vscaleRange.vscaleMax,
             presburger::BoundType::LB);
-    if (failed(lowerBound))
+    if (failed(dimLowerBound))
       return failure();
-    auto boundSize = lowerBound->getSize();
-    if (failed(boundSize))
+    auto dimLowerBoundSize = dimLowerBound->getSize();
+    if (failed(dimLowerBoundSize))
       return failure();
-    if (boundSize->scalable) {
-      // If the lower bound is scalable and >= to the mask dim size then this
-      // dim is all-true.
-      if (boundSize->baseSize < maskTypeDimSizes[i])
+    if (dimLowerBoundSize->scalable) {
+      // If the lower bound is scalable and < the mask dim size then this dim is
+      // not all-true.
+      if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
         return failure();
     } else {
-      // If the lower bound is a constant and >= to the _fixed-size_ mask dim
-      // size then this dim is all-true.
+      // If the lower bound is a constant:
+      // - If the mask dim size is scalable then this dim is not all-true.
       if (maskTypeDimScalableFlags[i])
         return failure();
-      if (boundSize->baseSize < maskTypeDimSizes[i])
+      // - If the lower bound is < the _fixed-size_ mask dim size then this dim
+      // is not all-true.
+      if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
         return failure();
     }
   }

>From 692ce6e05ed546eaa90fadd10d7a5fd53b3e0eb7 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 5 Aug 2024 13:11:15 +0000
Subject: [PATCH 5/6] Fixups

---
 .../mlir/Dialect/Vector/IR/VectorOps.h        |  3 ++
 .../mlir/Dialect/Vector/IR/VectorOps.td       |  5 +++
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 10 ++++++
 .../Transforms/VectorMaskElimination.cpp      | 17 +++++-----
 mlir/test/Dialect/Vector/eliminate-masks.mlir | 33 ++++++++++++++-----
 5 files changed, 52 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index ac55433fadb2f4..c6e14373aa6223 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -56,6 +56,9 @@ namespace detail {
 struct BitmaskEnumStorage;
 } // namespace detail
 
+/// Predefined constant_mask kinds.
+enum class ConstantMaskKind { AllFalse = 0, AllTrue };
+
 /// Default callback to build a region with a 'vector.yield' terminator with no
 /// arguments.
 void buildTerminatedBody(OpBuilder &builder, Location loc);
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index cd19d356a6739d..80ec996e4b8a69 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2362,6 +2362,11 @@ def Vector_ConstantMaskOp :
     ```
   }];
 
+  let builders = [
+    // Build with a predefined constant_mask kind (all-true or all-false).
+    OpBuilder<(ins "VectorType":$type, "ConstantMaskKind":$kind)>
+  ];
+
   let extraClassDeclaration = [{
     /// Return the result type of this op.
     VectorType getVectorType() {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2a3b9f2091ab39..c7130babd571bb 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5749,6 +5749,16 @@ void vector::TransposeOp::getCanonicalizationPatterns(
 // ConstantMaskOp
 //===----------------------------------------------------------------------===//
 
+void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
+                           VectorType type, ConstantMaskKind kind) {
+  assert(kind == ConstantMaskKind::AllTrue ||
+         kind == ConstantMaskKind::AllFalse);
+  build(builder, result, type,
+        kind == ConstantMaskKind::AllTrue
+            ? type.getShape()
+            : SmallVector<int64_t>(type.getRank(), 0));
+}
+
 LogicalResult ConstantMaskOp::verify() {
   auto resultType = llvm::cast<VectorType>(getResult().getType());
   // Check the corner case of 0-D vectors first.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
index 9ad0de5cadeaee..5ba294f1174aa9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
@@ -17,7 +17,9 @@ using namespace mlir;
 using namespace mlir::vector;
 namespace {
 
-/// If `value` is a constant multiple of `vector.vscale` return the multiplier.
+/// If `value` is a constant multiple of `vector.vscale` (e.g. `%cst *
+/// vector.vscale`), return the multiplier (`%cst`). Otherwise, return
+/// `std::nullopt`.
 std::optional<int64_t> getConstantVscaleMultiplier(Value value) {
   if (value.getDefiningOp<vector::VectorScaleOp>())
     return 1;
@@ -78,17 +80,16 @@ LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,
     if (failed(dimLowerBoundSize))
       return failure();
     if (dimLowerBoundSize->scalable) {
-      // If the lower bound is scalable and < the mask dim size then this dim is
-      // not all-true.
+      // 1. The lower bound, LB, is scalable. If LB is < the mask dim size then
+      // this dim is not all-true.
       if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
         return failure();
     } else {
-      // If the lower bound is a constant:
+      // 2. The lower bound, LB, is a constant.
       // - If the mask dim size is scalable then this dim is not all-true.
       if (maskTypeDimScalableFlags[i])
         return failure();
-      // - If the lower bound is < the _fixed-size_ mask dim size then this dim
-      // is not all-true.
+      // - If LB < the _fixed-size_ mask dim size then this dim is not all-true.
       if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
         return failure();
     }
@@ -97,8 +98,8 @@ LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,
   // Replace createMaskOp with an all-true constant. This should result in the
   // mask being removed in most cases (as xfer ops + vector.mask have folds to
   // remove all-true masks).
-  auto allTrue = rewriter.create<arith::ConstantOp>(
-      createMaskOp.getLoc(), maskType, DenseElementsAttr::get(maskType, true));
+  auto allTrue = rewriter.create<vector::ConstantMaskOp>(
+      createMaskOp.getLoc(), maskType, ConstantMaskKind::AllTrue);
   rewriter.replaceAllUsesWith(createMaskOp, allTrue);
   return success();
 }
diff --git a/mlir/test/Dialect/Vector/eliminate-masks.mlir b/mlir/test/Dialect/Vector/eliminate-masks.mlir
index ff859081fe5c80..0b89b0604faab1 100644
--- a/mlir/test/Dialect/Vector/eliminate-masks.mlir
+++ b/mlir/test/Dialect/Vector/eliminate-masks.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt %s -split-input-file -test-eliminate-vector-masks  | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-eliminate-vector-masks | FileCheck %s
 
 // This tests a general pattern the vectorizer tends to emit.
 
 // CHECK-LABEL: @eliminate_redundant_masks_through_insert_and_extracts
-// CHECK: %[[ALL_TRUE_MASK:.*]] = arith.constant dense<true> : vector<[4]xi1>
+// CHECK: %[[ALL_TRUE_MASK:.*]] = vector.constant_mask [4] : vector<[4]xi1>
 // CHECK: vector.transfer_read {{.*}} %[[ALL_TRUE_MASK]]
 // CHECK: vector.transfer_write {{.*}} %[[ALL_TRUE_MASK]]
 func.func @eliminate_redundant_masks_through_insert_and_extracts(%tensor: tensor<1x1000xf32>) {
@@ -40,7 +40,7 @@ func.func @eliminate_redundant_masks_through_insert_and_extracts(%tensor: tensor
 // -----
 
 // CHECK-LABEL: @negative_extract_slice_size_shrink
-// CHECK-NOT: arith.constant dense<true> : vector<[4]xi1>
+// CHECK-NOT: vector.constant_mask
 // CHECK: %[[MASK:.*]] = vector.create_mask
 // CHECK: "test.some_use"(%[[MASK]]) : (vector<[4]xi1>) -> ()
 func.func @negative_extract_slice_size_shrink(%tensor: tensor<1000xf32>) {
@@ -67,8 +67,25 @@ func.func @negative_extract_slice_size_shrink(%tensor: tensor<1000xf32>) {
 
 // -----
 
+// CHECK-LABEL: @trivially_all_true_case
+// CHECK: %[[ALL_TRUE_MASK:.*]] = vector.constant_mask [2, 4] : vector<2x[4]xi1>
+// CHECK: "test.some_use"(%[[ALL_TRUE_MASK]]) : (vector<2x[4]xi1>) -> ()
+func.func @trivially_all_true_case()
+{
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+  %vscale = vector.vscale
+  %c4_vscale = arith.muli %vscale, %c4 : index
+  // The mask is found to be all-true _without_ value bounds analysis.
+  %mask = vector.create_mask %c2, %c4_vscale : vector<2x[4]xi1>
+  "test.some_use"(%mask) : (vector<2x[4]xi1>) -> ()
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @negative_constant_dim_not_all_true
-// CHECK-NOT: arith.constant dense<true> : vector<2x[4]xi1>
+// CHECK-NOT: vector.constant_mask
 // CHECK: %[[MASK:.*]] = vector.create_mask
 // CHECK: "test.some_use"(%[[MASK]]) : (vector<2x[4]xi1>) -> ()
 func.func @negative_constant_dim_not_all_true()
@@ -87,7 +104,7 @@ func.func @negative_constant_dim_not_all_true()
 // -----
 
 // CHECK-LABEL: @negative_constant_vscale_multiple_not_all_true
-// CHECK-NOT: arith.constant dense<true> : vector<2x[4]xi1>
+// CHECK-NOT: vector.constant_mask
 // CHECK: %[[MASK:.*]] = vector.create_mask
 // CHECK: "test.some_use"(%[[MASK]]) : (vector<2x[4]xi1>) -> ()
 func.func @negative_constant_vscale_multiple_not_all_true() {
@@ -105,7 +122,7 @@ func.func @negative_constant_vscale_multiple_not_all_true() {
 // -----
 
 // CHECK-LABEL: @negative_value_bounds_fixed_dim_not_all_true
-// CHECK-NOT: arith.constant dense<true> : vector<3x[4]xi1>
+// CHECK-NOT: vector.constant_mask
 // CHECK: %[[MASK:.*]] = vector.create_mask
 // CHECK: "test.some_use"(%[[MASK]]) : (vector<3x[4]xi1>) -> ()
 func.func @negative_value_bounds_fixed_dim_not_all_true(%tensor: tensor<2x?xf32>)
@@ -114,7 +131,7 @@ func.func @negative_value_bounds_fixed_dim_not_all_true(%tensor: tensor<2x?xf32>
   %c4 = arith.constant 4 : index
   %vscale = vector.vscale
   %c4_vscale = arith.muli %vscale, %c4 : index
-  // This is _very_ simple but since tensor.dim is not a constant value bounds
+  // This is _very_ simple, but since tensor.dim is not a constant, value bounds
   // will be used to resolve it.
   %dim = tensor.dim %tensor, %c0 : tensor<2x?xf32>
   %mask = vector.create_mask %dim, %c4_vscale : vector<3x[4]xi1>
@@ -125,7 +142,7 @@ func.func @negative_value_bounds_fixed_dim_not_all_true(%tensor: tensor<2x?xf32>
 // -----
 
 // CHECK-LABEL: @negative_value_bounds_scalable_dim_not_all_true
-// CHECK-NOT: arith.constant dense<true> : vector<3x[4]xi1>
+// CHECK-NOT: vector.constant_mask
 // CHECK: %[[MASK:.*]] = vector.create_mask
 // CHECK: "test.some_use"(%[[MASK]]) : (vector<3x[4]xi1>) -> ()
 func.func @negative_value_bounds_scalable_dim_not_all_true(%tensor: tensor<2x100xf32>) {

>From 259d1b039ef3e36bb3e7f9c312576b4949aa9056 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 6 Aug 2024 17:17:58 +0000
Subject: [PATCH 6/6] Share logic with `CreateMaskFolder`

The main thing shared here is the `getConstantVscaleMultiplier()`
matcher. I could not think of a good way to share all of the logic, as
the two implementations are somewhat different.
---
 .../mlir/Dialect/Vector/IR/VectorOps.h        |   5 +
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 110 ++++++++----------
 .../Transforms/VectorMaskElimination.cpp      |  18 ---
 3 files changed, 54 insertions(+), 79 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index c6e14373aa6223..ebe6cd4a62b4c5 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -166,6 +166,11 @@ SmallVector<Value> getAsValues(OpBuilder &builder, Location loc,
 SmallVector<arith::ConstantIndexOp>
 getAsConstantIndexOps(ArrayRef<Value> values);
 
+/// Returns 1 if `value` is `vector.vscale`, or `%cst` if `value` is an
+/// `arith.muli` of `vector.vscale` and a constant `%cst` (in either operand
+/// order). Otherwise, returns `std::nullopt`.
+std::optional<int64_t> getConstantVscaleMultiplier(Value value);
+
 //===----------------------------------------------------------------------===//
 // Vector Masking Utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c7130babd571bb..250038d352ccea 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5841,6 +5841,21 @@ LogicalResult CreateMaskOp::verify() {
   return success();
 }
 
+std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
+  // `vector.vscale` on its own is a multiple of one.
+  if (value.getDefiningOp<vector::VectorScaleOp>())
+    return 1;
+  // Otherwise, match `arith.muli` with `vector.vscale` as either operand.
+  if (auto mulOp = value.getDefiningOp<arith::MulIOp>()) {
+    Value lhs = mulOp.getLhs(), rhs = mulOp.getRhs();
+    if (lhs.getDefiningOp<vector::VectorScaleOp>())
+      return getConstantIntValue(rhs);
+    if (rhs.getDefiningOp<vector::VectorScaleOp>())
+      return getConstantIntValue(lhs);
+  }
+  return std::nullopt;
+}
+
 namespace {
 
 /// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
@@ -5872,73 +5887,46 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
 
   LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                 PatternRewriter &rewriter) const override {
-    VectorType retTy = createMaskOp.getResult().getType();
-    bool isScalable = retTy.isScalable();
-
-    // Check every mask operand
-    for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) {
-      if (auto cst = getConstantIntValue(operand)) {
-        // Most basic case - this operand is a constant value. Note that for
-        // scalable dimensions, CreateMaskOp can be folded only if the
-        // corresponding operand is negative or zero.
-        if (retTy.getScalableDims()[opIdx] && *cst > 0)
-          return failure();
-
-        continue;
-      }
-
-      // Non-constant operands are not allowed for non-scalable vectors.
-      if (!isScalable)
-        return failure();
-
-      // For scalable vectors, "arith.muli %vscale, %dimSize" means an "all
-      // true" mask, so can also be treated as constant.
-      auto mul = operand.getDefiningOp<arith::MulIOp>();
-      if (!mul)
-        return failure();
-      auto mulLHS = mul.getRhs();
-      auto mulRHS = mul.getLhs();
-      bool isOneOpVscale =
-          (isa<vector::VectorScaleOp>(mulLHS.getDefiningOp()) ||
-           isa<vector::VectorScaleOp>(mulRHS.getDefiningOp()));
-
-      auto isConstantValMatchingDim =
-          [=, dim = retTy.getShape()[opIdx]](Value operand) {
-            auto constantVal = getConstantIntValue(operand);
-            return (constantVal.has_value() && constantVal.value() == dim);
-          };
-
-      bool isOneOpConstantMatchingDim =
-          isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS);
-
-      if (!isOneOpVscale || !isOneOpConstantMatchingDim)
-        return failure();
+    VectorType maskType = createMaskOp.getVectorType();
+    ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
+    ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
+
+    // Special case: Rank zero shape.
+    constexpr std::array<int64_t, 1> rankZeroShape{1};
+    constexpr std::array<bool, 1> rankZeroScalableDims{false};
+    if (maskType.getRank() == 0) {
+      maskTypeDimSizes = rankZeroShape;
+      maskTypeDimScalableFlags = rankZeroScalableDims;
     }
 
-    // Gather constant mask dimension sizes.
-    SmallVector<int64_t, 4> maskDimSizes;
-    maskDimSizes.reserve(createMaskOp->getNumOperands());
-    for (auto [operand, maxDimSize] : llvm::zip_equal(
-             createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
-      std::optional dimSize = getConstantIntValue(operand);
-      if (!dimSize) {
-        // Although not a constant, it is safe to assume that `operand` is
-        // "vscale * maxDimSize".
-        maskDimSizes.push_back(maxDimSize);
-        continue;
-      }
-      int64_t dimSizeVal = std::min(dimSize.value(), maxDimSize);
-      // If one of dim sizes is zero, set all dims to zero.
-      if (dimSize <= 0) {
-        maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
-        break;
+    SmallVector<int64_t, 4> constantDims;
+    for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
+      if (auto intSize = getConstantIntValue(dimSize)) {
+        // Fixed dims take any value; scalable dims must be <= 0 (all-false).
+        if (maskTypeDimScalableFlags[i] && *intSize > 0)
+          return failure();
+        constantDims.push_back(*intSize);
+      } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
+        // The vscale multiple must cover the whole dim (i.e. all-true).
+        if (vscaleMultiplier < maskTypeDimSizes[i])
+          return failure();
+        constantDims.push_back(*vscaleMultiplier);
+      } else {
+        return failure();
       }
-      maskDimSizes.push_back(dimSizeVal);
     }
 
+    // Clamp values to constant_mask bounds (`value` aliases `constantDims`).
+    for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
+      value = std::clamp<int64_t>(value, 0, maskDimSize);
+
+    // If one of dim sizes is zero, set all dims to zero.
+    if (llvm::is_contained(constantDims, 0))
+      constantDims.assign(constantDims.size(), 0);
+
     // Replace 'createMaskOp' with ConstantMaskOp.
-    rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, retTy,
-                                                maskDimSizes);
+    rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
+                                                constantDims);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
index 5ba294f1174aa9..61cb8d929c2fc9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMaskElimination.cpp
@@ -17,24 +17,6 @@ using namespace mlir;
 using namespace mlir::vector;
 namespace {
 
-/// If `value` is a constant multiple of `vector.vscale` (e.g. `%cst *
-/// vector.vscale`), return the multiplier (`%cst`). Otherwise, return
-/// `std::nullopt`.
-std::optional<int64_t> getConstantVscaleMultiplier(Value value) {
-  if (value.getDefiningOp<vector::VectorScaleOp>())
-    return 1;
-  auto mul = value.getDefiningOp<arith::MulIOp>();
-  if (!mul)
-    return {};
-  auto lhs = mul.getLhs();
-  auto rhs = mul.getRhs();
-  if (lhs.getDefiningOp<vector::VectorScaleOp>())
-    return getConstantIntValue(rhs);
-  if (rhs.getDefiningOp<vector::VectorScaleOp>())
-    return getConstantIntValue(lhs);
-  return {};
-}
-
 /// Attempts to resolve a (scalable) CreateMaskOp to an all-true constant mask.
 /// All-true masks can then be eliminated by simple folds.
 LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,



More information about the Mlir-commits mailing list