[Mlir-commits] [mlir] [mlir][Vector] Move vector.insert canonicalizers for DenseElementsAttr to folders (PR #128040)

Kunwar Grover llvmlistbot at llvm.org
Thu Feb 20 10:00:38 PST 2025


https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/128040

This PR moves vector.insert canonicalizers for DenseElementsAttr (splat and non-splat cases) to folders. Folders are local, and it's always better to implement a folder than a canonicalizer.

This PR is mostly NFC-ish: the functionality remains largely the same, but it now runs as part of a folder. Some tests change as a result, because GreedyPatternRewriter tries to fold by default.

>From 2f57f48eee6c727c15c4dc757418d568004ec6e8 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Thu, 20 Feb 2025 11:58:27 +0000
Subject: [PATCH 1/2] [mlir][Vector][NFC] Move canonicalizers for
 DenseElementsAttr to folders

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 122 +++++++-----------
 .../Linalg/vectorize-tensor-extract.mlir      |  26 +---
 mlir/test/Dialect/Vector/linearize.mlir       |   4 +-
 .../scalar-vector-transfer-to-memref.mlir     |   5 +-
 .../Vector/vector-gather-lowering.mlir        |  14 +-
 5 files changed, 62 insertions(+), 109 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d5f3634377e4c..96ac7fe2fa9e2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2047,6 +2047,49 @@ static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
   return {};
 }
 
+static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
+                                                   Attribute srcAttr) {
+  auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
+  if (!denseAttr) {
+    return {};
+  }
+
+  if (denseAttr.isSplat()) {
+    Attribute newAttr = denseAttr.getSplatValue<Attribute>();
+    if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
+      newAttr = DenseElementsAttr::get(vecDstType, newAttr);
+    return newAttr;
+  }
+
+  auto vecTy = llvm::cast<VectorType>(extractOp.getSourceVectorType());
+  if (vecTy.isScalable())
+    return {};
+
+  if (extractOp.hasDynamicPosition()) {
+    return {};
+  }
+
+  // Calculate the linearized position of the continuous chunk of elements to
+  // extract.
+  llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
+  copy(extractOp.getStaticPosition(), completePositions.begin());
+  int64_t elemBeginPosition =
+      linearize(completePositions, computeStrides(vecTy.getShape()));
+  auto denseValuesBegin =
+      denseAttr.value_begin<TypedAttr>() + elemBeginPosition;
+
+  TypedAttr newAttr;
+  if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
+    SmallVector<Attribute> elementValues(
+        denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
+    newAttr = DenseElementsAttr::get(resVecTy, elementValues);
+  } else {
+    newAttr = *denseValuesBegin;
+  }
+
+  return newAttr;
+}
+
 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
   // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
   // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -2058,6 +2101,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
     return res;
   if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
     return res;
+  if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector()))
+    return res;
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
   if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
@@ -2121,80 +2166,6 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
   }
 };
 
-// Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
-class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractOp extractOp,
-                                PatternRewriter &rewriter) const override {
-    // Return if 'ExtractOp' operand is not defined by a splat vector
-    // ConstantOp.
-    Value sourceVector = extractOp.getVector();
-    Attribute vectorCst;
-    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
-      return failure();
-    auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
-    if (!splat)
-      return failure();
-    TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
-    if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
-      newAttr = DenseElementsAttr::get(vecDstType, newAttr);
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
-    return success();
-  }
-};
-
-// Pattern to rewrite a ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
-class ExtractOpNonSplatConstantFolder final
-    : public OpRewritePattern<ExtractOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractOp extractOp,
-                                PatternRewriter &rewriter) const override {
-    // TODO: Canonicalization for dynamic position not implemented yet.
-    if (extractOp.hasDynamicPosition())
-      return failure();
-
-    // Return if 'ExtractOp' operand is not defined by a compatible vector
-    // ConstantOp.
-    Value sourceVector = extractOp.getVector();
-    Attribute vectorCst;
-    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
-      return failure();
-
-    auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
-    if (vecTy.isScalable())
-      return failure();
-
-    // The splat case is handled by `ExtractOpSplatConstantFolder`.
-    auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
-    if (!dense || dense.isSplat())
-      return failure();
-
-    // Calculate the linearized position of the continuous chunk of elements to
-    // extract.
-    llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
-    copy(extractOp.getStaticPosition(), completePositions.begin());
-    int64_t elemBeginPosition =
-        linearize(completePositions, computeStrides(vecTy.getShape()));
-    auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
-
-    TypedAttr newAttr;
-    if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
-      SmallVector<Attribute> elementValues(
-          denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
-      newAttr = DenseElementsAttr::get(resVecTy, elementValues);
-    } else {
-      newAttr = *denseValuesBegin;
-    }
-
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
-    return success();
-  }
-};
-
 // Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
 class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
 public:
@@ -2332,8 +2303,7 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
-              ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index e66fbe968d9b0..cd83e1239fdda 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -32,14 +32,8 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(
 
 // CHECK-DAG:  %[[C0:.+]] = arith.constant 0 : index
 // CHECK-DAG:  %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG:  %[[CST_0:.+]] = arith.constant dense<0> : vector<1xindex>
-// CHECK-DAG:  %[[CST_1:.+]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
 
-// CHECK-DAG: %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
-// CHECK-DAG: %[[IDX2:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
-// CHECK-DAG: %[[IDX3:.+]] = vector.extract %[[CST_1]][0] : index from vector<3xindex>
-
-// CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[IDX3]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
+// CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
 // CHECK:   vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
 
  // -----
@@ -175,16 +169,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(%arg0: tensor<80x16
 // CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(
 // CHECK-SAME:                                                                 %[[VAL_0:.*]]: tensor<80x16xf32>,
 // CHECK-SAME:                                                                 %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
-
-// CHECK-DAG:       %[[CST_0:.+]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG:       %[[CST_1:.+]] = arith.constant dense<16> : vector<4x1xindex>
-// CHECK-DAG:       %[[IDX0:.+]] = vector.extract %[[CST_1]][0, 0] : index from vector<4x1xindex>
-// CHECK-DAG:       %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<4xindex>
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG:       %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 
-// CHECK:           %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[IDX0]], %[[IDX1]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
-// CHECK:           %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[VAL_4]], %[[VAL_4]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+// CHECK:           %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[C16]], %[[C0]]], %[[CST]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
+// CHECK:           %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
 // CHECK:           return %[[VAL_9]] : tensor<1x4xf32>
 // CHECK:         }
 
@@ -675,9 +665,7 @@ func.func @scalar_read_with_broadcast_from_column_tensor(%init: tensor<1x1x4xi32
 // CHECK-DAG:       %[[PAD:.*]] = arith.constant 0 : i32
 // CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[SRC:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
-// CHECK-DAG:       %[[IDX_VEC:.*]] = arith.constant dense<0> : vector<1xindex>
-// CHECK:           %[[IDX_ELT:.*]] = vector.extract %[[IDX_VEC]][0] : index from vector<1xindex>
-// CHECK:           %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[IDX_ELT]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
 // CHECK:           %[[READ_BCAST:.*]] = vector.broadcast %[[READ]] : vector<i32> to vector<1x1x4xi32>
 // CHECK:           %[[RES:.*]] = vector.transfer_write %[[READ_BCAST]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x4xi32>, tensor<1x1x4xi32>
 
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 99b1bbab1eede..8e5ddbfffcdd9 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -310,12 +310,12 @@ func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector
 // -----
 
 // ALL-LABEL: test_vector_extract_scalar
-func.func @test_vector_extract_scalar() {
+func.func @test_vector_extract_scalar(%idx : index) {
   %cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
   // ALL-NOT: vector.shuffle
   // ALL:     vector.extract
   // ALL-NOT: vector.shuffle
-  %0 = vector.extract %cst[0] : i32 from vector<4xi32>
+  %0 = vector.extract %cst[%idx] : i32 from vector<4xi32>
   return
 }
 
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index c5cb09b9aa9f9..b4ebb14b8829e 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -101,9 +101,8 @@ func.func @transfer_read_2d_extract(%m: memref<?x?x?x?xf32>, %idx: index, %idx2:
 
 // CHECK-LABEL: func @transfer_write_arith_constant(
 //  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
-//       CHECK:   %[[cst:.*]] = arith.constant dense<5.000000e+00> : vector<1x1xf32>
-//       CHECK:   %[[extract:.*]] = vector.extract %[[cst]][0, 0] : f32 from vector<1x1xf32>
-//       CHECK:   memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+//       CHECK:   %[[cst:.*]] = arith.constant 5.000000e+00 : f32
+//       CHECK:   memref.store %[[cst]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
 func.func @transfer_write_arith_constant(%m: memref<?x?x?xf32>, %idx: index) {
   %cst = arith.constant dense<5.000000e+00> : vector<1x1xf32>
   vector.transfer_write %cst, %m[%idx, %idx, %idx] : vector<1x1xf32>, memref<?x?x?xf32>
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 20e9400ed698d..5be267c1be984 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -242,33 +242,29 @@ func.func @strided_gather(%base : memref<100x3xf32>,
 // CHECK-SAME:                         %[[IDXS:.*]]: vector<4xindex>,
 // CHECK-SAME:                         %[[VAL_4:.*]]: index,
 // CHECK-SAME:                         %[[VAL_5:.*]]: index) -> vector<4xf32> {
+// CHECK:           %[[TRUE:.*]] = arith.constant true
 // CHECK:           %[[CST_3:.*]] = arith.constant dense<3> : vector<4xindex>
-// CHECK:           %[[MASK:.*]] = arith.constant dense<true> : vector<4xi1>
 
 // CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[base]] {{\[\[}}0, 1]] : memref<100x3xf32> into memref<300xf32>
 // CHECK:           %[[NEW_IDXS:.*]] = arith.muli %[[IDXS]], %[[CST_3]] : vector<4xindex>
 
-// CHECK:           %[[MASK_0:.*]] = vector.extract %[[MASK]][0] : i1 from vector<4xi1>
 // CHECK:           %[[IDX_0:.*]] = vector.extract %[[NEW_IDXS]][0] : index from vector<4xindex>
-// CHECK:           scf.if %[[MASK_0]] -> (vector<4xf32>)
+// CHECK:           scf.if %[[TRUE]] -> (vector<4xf32>)
 // CHECK:             %[[M_0:.*]] = vector.load %[[COLLAPSED]][%[[IDX_0]]] : memref<300xf32>, vector<1xf32>
 // CHECK:             %[[V_0:.*]] = vector.extract %[[M_0]][0] : f32 from vector<1xf32>
 
-// CHECK:           %[[MASK_1:.*]] = vector.extract %[[MASK]][1] : i1 from vector<4xi1>
 // CHECK:           %[[IDX_1:.*]] = vector.extract %[[NEW_IDXS]][1] : index from vector<4xindex>
-// CHECK:           scf.if %[[MASK_1]] -> (vector<4xf32>)
+// CHECK:           scf.if %[[TRUE]] -> (vector<4xf32>)
 // CHECK:             %[[M_1:.*]] = vector.load %[[COLLAPSED]][%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
 // CHECK:             %[[V_1:.*]] = vector.extract %[[M_1]][0] : f32 from vector<1xf32>
 
-// CHECK:           %[[MASK_2:.*]] = vector.extract %[[MASK]][2] : i1 from vector<4xi1>
 // CHECK:           %[[IDX_2:.*]] = vector.extract %[[NEW_IDXS]][2] : index from vector<4xindex>
-// CHECK:           scf.if %[[MASK_2]] -> (vector<4xf32>)
+// CHECK:           scf.if %[[TRUE]] -> (vector<4xf32>)
 // CHECK:             %[[M_2:.*]] = vector.load %[[COLLAPSED]][%[[IDX_2]]] : memref<300xf32>, vector<1xf32>
 // CHECK:             %[[V_2:.*]] = vector.extract %[[M_2]][0] : f32 from vector<1xf32>
 
-// CHECK:           %[[MASK_3:.*]] = vector.extract %[[MASK]][3] : i1 from vector<4xi1>
 // CHECK:           %[[IDX_3:.*]] = vector.extract %[[NEW_IDXS]][3] : index from vector<4xindex>
-// CHECK:           scf.if %[[MASK_3]] -> (vector<4xf32>)
+// CHECK:           scf.if %[[TRUE]] -> (vector<4xf32>)
 // CHECK:             %[[M_3:.*]] = vector.load %[[COLLAPSED]][%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
 // CHECK:             %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32>
 

>From 0d08560ec896d5e1256549a10e2af33e87f56b6b Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Thu, 20 Feb 2025 17:55:41 +0000
Subject: [PATCH 2/2] [mlir][Vector] Move vector.insert canonicalizers for
 DenseElementsAttr to folders

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 119 ++++++++----------
 .../VectorToLLVM/vector-to-llvm.mlir          |  10 +-
 .../vector-mask-lowering-transforms.mlir      |  17 +--
 3 files changed, 63 insertions(+), 83 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 96ac7fe2fa9e2..f21ad23a03c6e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3013,94 +3013,78 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
   }
 };
 
-// Pattern to rewrite a InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
-class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  // Do not create constants with more than `vectorSizeFoldThreashold` elements,
-  // unless the source vector constant has a single use.
-  static constexpr int64_t vectorSizeFoldThreshold = 256;
-
-  LogicalResult matchAndRewrite(InsertOp op,
-                                PatternRewriter &rewriter) const override {
-    // TODO: Canonicalization for dynamic position not implemented yet.
-    if (op.hasDynamicPosition())
-      return failure();
+} // namespace
 
-    // Return if 'InsertOp' operand is not defined by a compatible vector
-    // ConstantOp.
-    TypedValue<VectorType> destVector = op.getDest();
-    Attribute vectorDestCst;
-    if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
-      return failure();
-    auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
-    if (!denseDest)
-      return failure();
+static Attribute
+foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
+                                  Attribute dstAttr,
+                                  int64_t maxVectorSizeFoldThreshold) {
+  if (insertOp.hasDynamicPosition())
+    return {};
 
-    VectorType destTy = destVector.getType();
-    if (destTy.isScalable())
-      return failure();
+  auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
+  if (!denseDst)
+    return {};
 
-    // Make sure we do not create too many large constants.
-    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
-        !destVector.hasOneUse())
-      return failure();
+  if (!srcAttr) {
+    return {};
+  }
 
-    Value sourceValue = op.getSource();
-    Attribute sourceCst;
-    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
-      return failure();
+  VectorType destTy = insertOp.getDestVectorType();
+  if (destTy.isScalable())
+    return {};
 
-    // Calculate the linearized position of the continuous chunk of elements to
-    // insert.
-    llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
-    copy(op.getStaticPosition(), completePositions.begin());
-    int64_t insertBeginPosition =
-        linearize(completePositions, computeStrides(destTy.getShape()));
-
-    SmallVector<Attribute> insertedValues;
-    Type destEltType = destTy.getElementType();
-
-    // The `convertIntegerAttr` method specifically handles the case
-    // for `llvm.mlir.constant` which can hold an attribute with a
-    // different type than the return type.
-    if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
-      for (auto value : denseSource.getValues<Attribute>())
-        insertedValues.push_back(convertIntegerAttr(value, destEltType));
-    } else {
-      insertedValues.push_back(convertIntegerAttr(sourceCst, destEltType));
-    }
+  // Make sure we do not create too many large constants.
+  if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
+      !insertOp->hasOneUse())
+    return {};
 
-    auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
-    copy(insertedValues, allValues.begin() + insertBeginPosition);
-    auto newAttr = DenseElementsAttr::get(destTy, allValues);
+  // Calculate the linearized position of the continuous chunk of elements to
+  // insert.
+  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
+  copy(insertOp.getStaticPosition(), completePositions.begin());
+  int64_t insertBeginPosition =
+      linearize(completePositions, computeStrides(destTy.getShape()));
 
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
-    return success();
-  }
+  SmallVector<Attribute> insertedValues;
+  Type destEltType = destTy.getElementType();
 
-private:
   /// Converts the expected type to an IntegerAttr if there's
   /// a mismatch.
-  Attribute convertIntegerAttr(Attribute attr, Type expectedType) const {
+  auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
     if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
       if (intAttr.getType() != expectedType)
         return IntegerAttr::get(expectedType, intAttr.getInt());
     }
     return attr;
+  };
+
+  // The `convertIntegerAttr` method specifically handles the case
+  // for `llvm.mlir.constant` which can hold an attribute with a
+  // different type than the return type.
+  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
+    for (auto value : denseSource.getValues<Attribute>())
+      insertedValues.push_back(convertIntegerAttr(value, destEltType));
+  } else {
+    insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
   }
-};
 
-} // namespace
+  auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
+  copy(insertedValues, allValues.begin() + insertBeginPosition);
+  auto newAttr = DenseElementsAttr::get(destTy, allValues);
+
+  return newAttr;
+}
 
 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
-              InsertOpConstantFolder>(context);
+  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
 }
 
 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
+  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
+  // unless the source vector constant has a single use.
+  constexpr int64_t vectorSizeFoldThreshold = 256;
   // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
   // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
   // (type mismatch).
@@ -3112,6 +3096,11 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   if (auto res = foldPoisonIndexInsertExtractOp(
           getContext(), adaptor.getStaticPosition(), kPoisonIndex))
     return res;
+  if (auto res = foldDenseElementsAttrDestInsertOp(*this, adaptor.getSource(),
+                                                   adaptor.getDest(),
+                                                   vectorSizeFoldThreshold)) {
+    return res;
+  }
 
   return {};
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 36b37a137ac1e..1ab28b9df2d19 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1517,13 +1517,9 @@ func.func @constant_mask_2d() -> vector<4x4xi1> {
 }
 
 // CHECK-LABEL: func @constant_mask_2d
-// CHECK: %[[VAL_0:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
-// CHECK: %[[VAL_1:.*]] = arith.constant dense<false> : vector<4x4xi1>
-// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<4x4xi1> to !llvm.array<4 x vector<4xi1>>
-// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][0] : !llvm.array<4 x vector<4xi1>>
-// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1] : !llvm.array<4 x vector<4xi1>>
-// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !llvm.array<4 x vector<4xi1>> to vector<4x4xi1>
-// CHECK: return %[[VAL_5]] : vector<4x4xi1>
+// CHECK: %[[VAL_0:.*]] = arith.constant 
+// CHECK-SAME{LITERAL}: dense<[[true, true, false, false], [true, true, false, false], [false, false, false, false], [false, false, false, false]]> : vector<4x4xi1>
+// CHECK: return %[[VAL_0]] : vector<4x4xi1>
 
 // -----
 
diff --git a/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir
index 7838543e151be..b5eb6e63f5a8d 100644
--- a/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir
@@ -10,11 +10,9 @@ func.func @genbool_1d() -> vector<8xi1> {
 }
 
 // CHECK-LABEL: func @genbool_2d
-// CHECK: %[[C1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
-// CHECK: %[[C2:.*]] = arith.constant dense<false> : vector<4x4xi1>
-// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<4x4xi1>
-// CHECK: %[[T1:.*]] = vector.insert %[[C1]], %[[T0]] [1] : vector<4xi1> into vector<4x4xi1>
-// CHECK: return %[[T1]] : vector<4x4xi1>
+// CHECK: %[[C0:.*]] = arith.constant
+// CHECK-SAME{LITERAL}: dense<[[true, true, false, false], [true, true, false, false], [false, false, false, false], [false, false, false, false]]> : vector<4x4xi1>
+// CHECK: return %[[C0]] : vector<4x4xi1>
 
 func.func @genbool_2d() -> vector<4x4xi1> {
   %v = vector.constant_mask [2, 2] : vector<4x4xi1>
@@ -22,12 +20,9 @@ func.func @genbool_2d() -> vector<4x4xi1> {
 }
 
 // CHECK-LABEL: func @genbool_3d
-// CHECK-DAG: %[[C1:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
-// CHECK-DAG: %[[C2:.*]] = arith.constant dense<false> : vector<3x4xi1>
-// CHECK-DAG: %[[C3:.*]] = arith.constant dense<false> : vector<2x3x4xi1>
-// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<3x4xi1>
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C3]] [0] : vector<3x4xi1> into vector<2x3x4xi1>
-// CHECK: return %[[T1]] : vector<2x3x4xi1>
+// CHECK: %[[C0:.*]] = arith.constant
+// CHECK-SAME{LITERAL}: dense<[[[true, true, true, false], [false, false, false, false], [false, false, false, false]], [[false, false, false, false], [false, false, false, false], [false, false, false, false]]]> : vector<2x3x4xi1>
+// CHECK: return %[[C0]] : vector<2x3x4xi1>
 
 func.func @genbool_3d() -> vector<2x3x4xi1> {
   %v = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1>



More information about the Mlir-commits mailing list