[Mlir-commits] [mlir] [mlir][Vector][NFC] Move canonicalizers for DenseElementsAttr to folders (PR #127995)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 20 04:02:08 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-vector
Author: Kunwar Grover (Groverkss)
<details>
<summary>Changes</summary>
This PR moves the vector.extract canonicalizers for DenseElementsAttr (both the splat and non-splat cases) to folders. Folders are local and cheaper to apply, so it's always preferable to implement a folder rather than a canonicalizer when possible.
This PR is marked NFC because the functionality mostly remains the same, but it now runs as part of a folder. Some tests change as a result, since GreedyPatternRewriter folds by default.
---
Full diff: https://github.com/llvm/llvm-project/pull/127995.diff
5 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+46-76)
- (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+7-19)
- (modified) mlir/test/Dialect/Vector/linearize.mlir (+2-2)
- (modified) mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir (+2-3)
- (modified) mlir/test/Dialect/Vector/vector-gather-lowering.mlir (+5-9)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d5f3634377e4c..96ac7fe2fa9e2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2047,6 +2047,49 @@ static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
return {};
}
+static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
+ Attribute srcAttr) {
+ auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
+ if (!denseAttr) {
+ return {};
+ }
+
+ if (denseAttr.isSplat()) {
+ Attribute newAttr = denseAttr.getSplatValue<Attribute>();
+ if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
+ newAttr = DenseElementsAttr::get(vecDstType, newAttr);
+ return newAttr;
+ }
+
+ auto vecTy = llvm::cast<VectorType>(extractOp.getSourceVectorType());
+ if (vecTy.isScalable())
+ return {};
+
+ if (extractOp.hasDynamicPosition()) {
+ return {};
+ }
+
+ // Calculate the linearized position of the continuous chunk of elements to
+ // extract.
+ llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
+ copy(extractOp.getStaticPosition(), completePositions.begin());
+ int64_t elemBeginPosition =
+ linearize(completePositions, computeStrides(vecTy.getShape()));
+ auto denseValuesBegin =
+ denseAttr.value_begin<TypedAttr>() + elemBeginPosition;
+
+ TypedAttr newAttr;
+ if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
+ SmallVector<Attribute> elementValues(
+ denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
+ newAttr = DenseElementsAttr::get(resVecTy, elementValues);
+ } else {
+ newAttr = *denseValuesBegin;
+ }
+
+ return newAttr;
+}
+
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -2058,6 +2101,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
return res;
if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
return res;
+ if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector()))
+ return res;
if (succeeded(foldExtractOpFromExtractChain(*this)))
return getResult();
if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
@@ -2121,80 +2166,6 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
}
};
-// Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
-class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ExtractOp extractOp,
- PatternRewriter &rewriter) const override {
- // Return if 'ExtractOp' operand is not defined by a splat vector
- // ConstantOp.
- Value sourceVector = extractOp.getVector();
- Attribute vectorCst;
- if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
- return failure();
- auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
- if (!splat)
- return failure();
- TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
- if (auto vecDstType = llvm::dyn_cast<VectorType>(extractOp.getType()))
- newAttr = DenseElementsAttr::get(vecDstType, newAttr);
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
- return success();
- }
-};
-
-// Pattern to rewrite a ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
-class ExtractOpNonSplatConstantFolder final
- : public OpRewritePattern<ExtractOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ExtractOp extractOp,
- PatternRewriter &rewriter) const override {
- // TODO: Canonicalization for dynamic position not implemented yet.
- if (extractOp.hasDynamicPosition())
- return failure();
-
- // Return if 'ExtractOp' operand is not defined by a compatible vector
- // ConstantOp.
- Value sourceVector = extractOp.getVector();
- Attribute vectorCst;
- if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
- return failure();
-
- auto vecTy = llvm::cast<VectorType>(sourceVector.getType());
- if (vecTy.isScalable())
- return failure();
-
- // The splat case is handled by `ExtractOpSplatConstantFolder`.
- auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
- if (!dense || dense.isSplat())
- return failure();
-
- // Calculate the linearized position of the continuous chunk of elements to
- // extract.
- llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
- copy(extractOp.getStaticPosition(), completePositions.begin());
- int64_t elemBeginPosition =
- linearize(completePositions, computeStrides(vecTy.getShape()));
- auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;
-
- TypedAttr newAttr;
- if (auto resVecTy = llvm::dyn_cast<VectorType>(extractOp.getType())) {
- SmallVector<Attribute> elementValues(
- denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
- newAttr = DenseElementsAttr::get(resVecTy, elementValues);
- } else {
- newAttr = *denseValuesBegin;
- }
-
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
- return success();
- }
-};
-
// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
public:
@@ -2332,8 +2303,7 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
- ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+ results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
}
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index e66fbe968d9b0..cd83e1239fdda 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -32,14 +32,8 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[CST_0:.+]] = arith.constant dense<0> : vector<1xindex>
-// CHECK-DAG: %[[CST_1:.+]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
-// CHECK-DAG: %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
-// CHECK-DAG: %[[IDX2:.+]] = vector.extract %[[CST_0]][0] : index from vector<1xindex>
-// CHECK-DAG: %[[IDX3:.+]] = vector.extract %[[CST_1]][0] : index from vector<3xindex>
-
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[IDX3]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
// CHECK: vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
// -----
@@ -175,16 +169,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(%arg0: tensor<80x16
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
-
-// CHECK-DAG: %[[CST_0:.+]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG: %[[CST_1:.+]] = arith.constant dense<16> : vector<4x1xindex>
-// CHECK-DAG: %[[IDX0:.+]] = vector.extract %[[CST_1]][0, 0] : index from vector<4x1xindex>
-// CHECK-DAG: %[[IDX1:.+]] = vector.extract %[[CST_0]][0] : index from vector<4xindex>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[IDX0]], %[[IDX1]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
-// CHECK: %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[VAL_4]], %[[VAL_4]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[C16]], %[[C0]]], %[[CST]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
+// CHECK: %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_9]] : tensor<1x4xf32>
// CHECK: }
@@ -675,9 +665,7 @@ func.func @scalar_read_with_broadcast_from_column_tensor(%init: tensor<1x1x4xi32
// CHECK-DAG: %[[PAD:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[SRC:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
-// CHECK-DAG: %[[IDX_VEC:.*]] = arith.constant dense<0> : vector<1xindex>
-// CHECK: %[[IDX_ELT:.*]] = vector.extract %[[IDX_VEC]][0] : index from vector<1xindex>
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[IDX_ELT]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
// CHECK: %[[READ_BCAST:.*]] = vector.broadcast %[[READ]] : vector<i32> to vector<1x1x4xi32>
// CHECK: %[[RES:.*]] = vector.transfer_write %[[READ_BCAST]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x4xi32>, tensor<1x1x4xi32>
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 99b1bbab1eede..8e5ddbfffcdd9 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -310,12 +310,12 @@ func.func @test_vector_insert_scalable(%arg0: vector<2x8x[4]xf32>, %arg1: vector
// -----
// ALL-LABEL: test_vector_extract_scalar
-func.func @test_vector_extract_scalar() {
+func.func @test_vector_extract_scalar(%idx : index) {
%cst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
// ALL-NOT: vector.shuffle
// ALL: vector.extract
// ALL-NOT: vector.shuffle
- %0 = vector.extract %cst[0] : i32 from vector<4xi32>
+ %0 = vector.extract %cst[%idx] : i32 from vector<4xi32>
return
}
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index c5cb09b9aa9f9..b4ebb14b8829e 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -101,9 +101,8 @@ func.func @transfer_read_2d_extract(%m: memref<?x?x?x?xf32>, %idx: index, %idx2:
// CHECK-LABEL: func @transfer_write_arith_constant(
// CHECK-SAME: %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
-// CHECK: %[[cst:.*]] = arith.constant dense<5.000000e+00> : vector<1x1xf32>
-// CHECK: %[[extract:.*]] = vector.extract %[[cst]][0, 0] : f32 from vector<1x1xf32>
-// CHECK: memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+// CHECK: %[[cst:.*]] = arith.constant 5.000000e+00 : f32
+// CHECK: memref.store %[[cst]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
func.func @transfer_write_arith_constant(%m: memref<?x?x?xf32>, %idx: index) {
%cst = arith.constant dense<5.000000e+00> : vector<1x1xf32>
vector.transfer_write %cst, %m[%idx, %idx, %idx] : vector<1x1xf32>, memref<?x?x?xf32>
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 20e9400ed698d..5be267c1be984 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -242,33 +242,29 @@ func.func @strided_gather(%base : memref<100x3xf32>,
// CHECK-SAME: %[[IDXS:.*]]: vector<4xindex>,
// CHECK-SAME: %[[VAL_4:.*]]: index,
// CHECK-SAME: %[[VAL_5:.*]]: index) -> vector<4xf32> {
+// CHECK: %[[TRUE:.*]] = arith.constant true
// CHECK: %[[CST_3:.*]] = arith.constant dense<3> : vector<4xindex>
-// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<4xi1>
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[base]] {{\[\[}}0, 1]] : memref<100x3xf32> into memref<300xf32>
// CHECK: %[[NEW_IDXS:.*]] = arith.muli %[[IDXS]], %[[CST_3]] : vector<4xindex>
-// CHECK: %[[MASK_0:.*]] = vector.extract %[[MASK]][0] : i1 from vector<4xi1>
// CHECK: %[[IDX_0:.*]] = vector.extract %[[NEW_IDXS]][0] : index from vector<4xindex>
-// CHECK: scf.if %[[MASK_0]] -> (vector<4xf32>)
+// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
// CHECK: %[[M_0:.*]] = vector.load %[[COLLAPSED]][%[[IDX_0]]] : memref<300xf32>, vector<1xf32>
// CHECK: %[[V_0:.*]] = vector.extract %[[M_0]][0] : f32 from vector<1xf32>
-// CHECK: %[[MASK_1:.*]] = vector.extract %[[MASK]][1] : i1 from vector<4xi1>
// CHECK: %[[IDX_1:.*]] = vector.extract %[[NEW_IDXS]][1] : index from vector<4xindex>
-// CHECK: scf.if %[[MASK_1]] -> (vector<4xf32>)
+// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
// CHECK: %[[M_1:.*]] = vector.load %[[COLLAPSED]][%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
// CHECK: %[[V_1:.*]] = vector.extract %[[M_1]][0] : f32 from vector<1xf32>
-// CHECK: %[[MASK_2:.*]] = vector.extract %[[MASK]][2] : i1 from vector<4xi1>
// CHECK: %[[IDX_2:.*]] = vector.extract %[[NEW_IDXS]][2] : index from vector<4xindex>
-// CHECK: scf.if %[[MASK_2]] -> (vector<4xf32>)
+// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
// CHECK: %[[M_2:.*]] = vector.load %[[COLLAPSED]][%[[IDX_2]]] : memref<300xf32>, vector<1xf32>
// CHECK: %[[V_2:.*]] = vector.extract %[[M_2]][0] : f32 from vector<1xf32>
-// CHECK: %[[MASK_3:.*]] = vector.extract %[[MASK]][3] : i1 from vector<4xi1>
// CHECK: %[[IDX_3:.*]] = vector.extract %[[NEW_IDXS]][3] : index from vector<4xindex>
-// CHECK: scf.if %[[MASK_3]] -> (vector<4xf32>)
+// CHECK: scf.if %[[TRUE]] -> (vector<4xf32>)
// CHECK: %[[M_3:.*]] = vector.load %[[COLLAPSED]][%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
// CHECK: %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/127995
More information about the Mlir-commits
mailing list