[Mlir-commits] [mlir] d10dca6 - [mlir][Vector] Move vector.insert canonicalizers for DenseElementsAttr to folders (#128040)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 6 10:24:41 PST 2025


Author: Kunwar Grover
Date: 2025-03-06T18:24:38Z
New Revision: d10dca6ba73c057e7c24624cd62ada8d75ecfd46

URL: https://github.com/llvm/llvm-project/commit/d10dca6ba73c057e7c24624cd62ada8d75ecfd46
DIFF: https://github.com/llvm/llvm-project/commit/d10dca6ba73c057e7c24624cd62ada8d75ecfd46.diff

LOG: [mlir][Vector] Move vector.insert canonicalizers for DenseElementsAttr to folders (#128040)

This PR moves vector.insert canonicalizers for DenseElementsAttr (splat
and non-splat cases) to folders. Folders are local, and it's always
better to implement a folder than a canonicalizer.

This PR is mostly NFC-ish, because the functionality mostly remains the
same, but it is now run as part of a folder. This is why some tests have
changed: GreedyPatternRewriter tries to fold by default.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ff323983a17c0..8e0e723cf4ed3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3019,94 +3019,78 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
   }
 };
 
-// Pattern to rewrite a InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
-class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  // Do not create constants with more than `vectorSizeFoldThreashold` elements,
-  // unless the source vector constant has a single use.
-  static constexpr int64_t vectorSizeFoldThreshold = 256;
-
-  LogicalResult matchAndRewrite(InsertOp op,
-                                PatternRewriter &rewriter) const override {
-    // TODO: Canonicalization for dynamic position not implemented yet.
-    if (op.hasDynamicPosition())
-      return failure();
+} // namespace
 
-    // Return if 'InsertOp' operand is not defined by a compatible vector
-    // ConstantOp.
-    TypedValue<VectorType> destVector = op.getDest();
-    Attribute vectorDestCst;
-    if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
-      return failure();
-    auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
-    if (!denseDest)
-      return failure();
+static Attribute
+foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
+                                  Attribute dstAttr,
+                                  int64_t maxVectorSizeFoldThreshold) {
+  if (insertOp.hasDynamicPosition())
+    return {};
 
-    VectorType destTy = destVector.getType();
-    if (destTy.isScalable())
-      return failure();
+  auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
+  if (!denseDst)
+    return {};
 
-    // Make sure we do not create too many large constants.
-    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
-        !destVector.hasOneUse())
-      return failure();
+  if (!srcAttr) {
+    return {};
+  }
 
-    Value sourceValue = op.getSource();
-    Attribute sourceCst;
-    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
-      return failure();
+  VectorType destTy = insertOp.getDestVectorType();
+  if (destTy.isScalable())
+    return {};
 
-    // Calculate the linearized position of the continuous chunk of elements to
-    // insert.
-    llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
-    copy(op.getStaticPosition(), completePositions.begin());
-    int64_t insertBeginPosition =
-        linearize(completePositions, computeStrides(destTy.getShape()));
-
-    SmallVector<Attribute> insertedValues;
-    Type destEltType = destTy.getElementType();
-
-    // The `convertIntegerAttr` method specifically handles the case
-    // for `llvm.mlir.constant` which can hold an attribute with a
-    // different type than the return type.
-    if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
-      for (auto value : denseSource.getValues<Attribute>())
-        insertedValues.push_back(convertIntegerAttr(value, destEltType));
-    } else {
-      insertedValues.push_back(convertIntegerAttr(sourceCst, destEltType));
-    }
+  // Make sure we do not create too many large constants.
+  if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
+      !insertOp->hasOneUse())
+    return {};
 
-    auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
-    copy(insertedValues, allValues.begin() + insertBeginPosition);
-    auto newAttr = DenseElementsAttr::get(destTy, allValues);
+  // Calculate the linearized position of the continuous chunk of elements to
+  // insert.
+  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
+  copy(insertOp.getStaticPosition(), completePositions.begin());
+  int64_t insertBeginPosition =
+      linearize(completePositions, computeStrides(destTy.getShape()));
 
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
-    return success();
-  }
+  SmallVector<Attribute> insertedValues;
+  Type destEltType = destTy.getElementType();
 
-private:
   /// Converts the expected type to an IntegerAttr if there's
   /// a mismatch.
-  Attribute convertIntegerAttr(Attribute attr, Type expectedType) const {
+  auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
     if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
       if (intAttr.getType() != expectedType)
         return IntegerAttr::get(expectedType, intAttr.getInt());
     }
     return attr;
+  };
+
+  // The `convertIntegerAttr` method specifically handles the case
+  // for `llvm.mlir.constant` which can hold an attribute with a
+  // different type than the return type.
+  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
+    for (auto value : denseSource.getValues<Attribute>())
+      insertedValues.push_back(convertIntegerAttr(value, destEltType));
+  } else {
+    insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
   }
-};
 
-} // namespace
+  auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
+  copy(insertedValues, allValues.begin() + insertBeginPosition);
+  auto newAttr = DenseElementsAttr::get(destTy, allValues);
+
+  return newAttr;
+}
 
 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
-              InsertOpConstantFolder>(context);
+  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
 }
 
 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
+  // Do not create constants with more than `vectorSizeFoldThreashold` elements,
+  // unless the source vector constant has a single use.
+  constexpr int64_t vectorSizeFoldThreshold = 256;
   // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
   // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
   // (type mismatch).
@@ -3118,6 +3102,11 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
   if (auto res = foldPoisonIndexInsertExtractOp(
           getContext(), adaptor.getStaticPosition(), kPoisonIndex))
     return res;
+  if (auto res = foldDenseElementsAttrDestInsertOp(*this, adaptor.getSource(),
+                                                   adaptor.getDest(),
+                                                   vectorSizeFoldThreshold)) {
+    return res;
+  }
 
   return {};
 }

diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 36b37a137ac1e..1ab28b9df2d19 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1517,13 +1517,9 @@ func.func @constant_mask_2d() -> vector<4x4xi1> {
 }
 
 // CHECK-LABEL: func @constant_mask_2d
-// CHECK: %[[VAL_0:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
-// CHECK: %[[VAL_1:.*]] = arith.constant dense<false> : vector<4x4xi1>
-// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<4x4xi1> to !llvm.array<4 x vector<4xi1>>
-// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][0] : !llvm.array<4 x vector<4xi1>>
-// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1] : !llvm.array<4 x vector<4xi1>>
-// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !llvm.array<4 x vector<4xi1>> to vector<4x4xi1>
-// CHECK: return %[[VAL_5]] : vector<4x4xi1>
+// CHECK: %[[VAL_0:.*]] = arith.constant 
+// CHECK-SAME{LITERAL}: dense<[[true, true, false, false], [true, true, false, false], [false, false, false, false], [false, false, false, false]]> : vector<4x4xi1>
+// CHECK: return %[[VAL_0]] : vector<4x4xi1>
 
 // -----
 

diff --git a/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir
index 7838543e151be..b5eb6e63f5a8d 100644
--- a/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir
@@ -10,11 +10,9 @@ func.func @genbool_1d() -> vector<8xi1> {
 }
 
 // CHECK-LABEL: func @genbool_2d
-// CHECK: %[[C1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
-// CHECK: %[[C2:.*]] = arith.constant dense<false> : vector<4x4xi1>
-// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<4x4xi1>
-// CHECK: %[[T1:.*]] = vector.insert %[[C1]], %[[T0]] [1] : vector<4xi1> into vector<4x4xi1>
-// CHECK: return %[[T1]] : vector<4x4xi1>
+// CHECK: %[[C0:.*]] = arith.constant
+// CHECK-SAME{LITERAL}: dense<[[true, true, false, false], [true, true, false, false], [false, false, false, false], [false, false, false, false]]> : vector<4x4xi1>
+// CHECK: return %[[C0]] : vector<4x4xi1>
 
 func.func @genbool_2d() -> vector<4x4xi1> {
   %v = vector.constant_mask [2, 2] : vector<4x4xi1>
@@ -22,12 +20,9 @@ func.func @genbool_2d() -> vector<4x4xi1> {
 }
 
 // CHECK-LABEL: func @genbool_3d
-// CHECK-DAG: %[[C1:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
-// CHECK-DAG: %[[C2:.*]] = arith.constant dense<false> : vector<3x4xi1>
-// CHECK-DAG: %[[C3:.*]] = arith.constant dense<false> : vector<2x3x4xi1>
-// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<3x4xi1>
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C3]] [0] : vector<3x4xi1> into vector<2x3x4xi1>
-// CHECK: return %[[T1]] : vector<2x3x4xi1>
+// CHECK: %[[C0:.*]] = arith.constant
+// CHECK-SAME{LITERAL}: dense<[[[true, true, true, false], [false, false, false, false], [false, false, false, false]], [[false, false, false, false], [false, false, false, false], [false, false, false, false]]]> : vector<2x3x4xi1>
+// CHECK: return %[[C0]] : vector<2x3x4xi1>
 
 func.func @genbool_3d() -> vector<2x3x4xi1> {
   %v = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1>


        


More information about the Mlir-commits mailing list