[Mlir-commits] [mlir] 5d05d29 - [mlir][vector] Add fold pattern for InsertOp(Constant into Constant)
Jakub Kuderski
llvmlistbot at llvm.org
Fri Nov 25 20:03:05 PST 2022
Author: Jakub Kuderski
Date: 2022-11-25T23:01:29-05:00
New Revision: 5d05d2966f4394e7f3fd5708ecc6f1b1c1156145
URL: https://github.com/llvm/llvm-project/commit/5d05d2966f4394e7f3fd5708ecc6f1b1c1156145
DIFF: https://github.com/llvm/llvm-project/commit/5d05d2966f4394e7f3fd5708ecc6f1b1c1156145.diff
LOG: [mlir][vector] Add fold pattern for InsertOp(Constant into Constant)
This pattern comes with a vector size threshold to make sure we do not
introduce too many large constants.
This helps clean up code created by the Wide Integer Emulation pass.
Reviewed By: dcaballe
Differential Revision: https://reviews.llvm.org/D138733
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index eebf590d6586a..2f9bca6a0564e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1635,11 +1635,10 @@ class ExtractOpNonSplatConstantFolder final
if (!dense || dense.isSplat())
return failure();
- // Calculate the linearized position of the continous chunk of elements to
+ // Calculate the linearized position of the continuous chunk of elements to
// extract.
llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
- llvm::copy(getI64SubArray(extractOp.getPosition()),
- completePositions.begin());
+ copy(getI64SubArray(extractOp.getPosition()), completePositions.begin());
int64_t elemBeginPosition =
linearize(completePositions, computeStrides(vecTy.getShape()));
auto denseValuesBegin = dense.value_begin<Attribute>() + elemBeginPosition;
@@ -2084,11 +2083,68 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
}
};
+// Pattern to rewrite an InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
+//
+// The destination vector must be defined by a constant-like op folding to a
+// DenseElementsAttr. The inserted value must also be a constant: a vector
+// source must fold to a DenseElementsAttr, while a scalar source is inserted
+// as a single element attribute.
+class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
+  // unless the destination vector constant has a single use (this insert), in
+  // which case the original constant has no remaining users after the fold.
+  static constexpr int64_t vectorSizeFoldThreshold = 256;
+
+  LogicalResult matchAndRewrite(InsertOp op,
+                                PatternRewriter &rewriter) const override {
+    // Return if the destination is not defined by a vector constant whose
+    // value is a DenseElementsAttr. Binding straight to DenseElementsAttr
+    // (instead of binding to Attribute and casting later) makes the match fail
+    // for other ElementsAttr kinds such as `dense_resource` rather than assert.
+    TypedValue<VectorType> destVector = op.getDest();
+    DenseElementsAttr denseDest;
+    if (!matchPattern(destVector, m_Constant(&denseDest)))
+      return failure();
+
+    VectorType destTy = destVector.getType();
+    if (destTy.isScalable())
+      return failure();
+
+    // Make sure we do not create too many large constants.
+    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
+        !destVector.hasOneUse())
+      return failure();
+
+    // Return if the inserted value is not a constant.
+    Value sourceValue = op.getSource();
+    Attribute sourceCst;
+    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
+      return failure();
+
+    // Collect the element attributes to insert. A vector source is only
+    // supported when its constant is a DenseElementsAttr: any other elements
+    // attribute cannot be enumerated here and must not be inserted as if it
+    // were a single scalar element.
+    SmallVector<Attribute> insertedValues;
+    if (auto denseSource = sourceCst.dyn_cast<DenseElementsAttr>())
+      llvm::append_range(insertedValues, denseSource.getValues<Attribute>());
+    else if (sourceValue.getType().isa<VectorType>())
+      return failure();
+    else
+      insertedValues.push_back(sourceCst);
+
+    // Calculate the linearized position of the continuous chunk of elements to
+    // insert.
+    llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
+    copy(getI64SubArray(op.getPosition()), completePositions.begin());
+    int64_t insertBeginPosition =
+        linearize(completePositions, computeStrides(destTy.getShape()));
+
+    // Overwrite that chunk in a copy of the destination values and
+    // materialize the result as a new constant.
+    auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
+    copy(insertedValues, allValues.begin() + insertBeginPosition);
+    auto newAttr = DenseElementsAttr::get(destTy, allValues);
+
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
+    return success();
+  }
+};
+
} // namespace
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
+ // Register InsertOp canonicalization patterns; InsertOpConstantFolder folds
+ // inserts of a constant into a constant vector into a single constant.
+ results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
+ InsertOpConstantFolder>(context);
}
// Eliminates insert operations that produce values identical to their source
@@ -2744,13 +2800,12 @@ class StridedSliceNonSplatConstantFolder final
// Expand offsets and sizes to match the vector rank.
SmallVector<int64_t, 4> offsets(sliceRank, 0);
- llvm::copy(getI64SubArray(extractStridedSliceOp.getOffsets()),
- offsets.begin());
+ copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
SmallVector<int64_t, 4> sizes(sourceShape.begin(), sourceShape.end());
- llvm::copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
+ copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
- // Calcualte the slice elements by enumerating all slice positions and
+ // Calculate the slice elements by enumerating all slice positions and
// linearizing them. The enumeration order is lexicographic which yields a
// sequence of monotonically increasing linearized position indices.
auto denseValuesBegin = dense.value_begin<Attribute>();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 19a06af5c9f8b..7aabcec231976 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1795,6 +1795,64 @@ func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
// -----
+// Inserting a scalar constant at each position of a 1-D constant vector
+// folds every insert into its own new constant.
+// CHECK-LABEL: func.func @insert_1d_constant
+// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[9, 1, 2]> : vector<3xi32>
+// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[0, 9, 2]> : vector<3xi32>
+// CHECK-DAG: %[[CCST:.*]] = arith.constant dense<[0, 1, 9]> : vector<3xi32>
+// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]] : vector<3xi32>, vector<3xi32>, vector<3xi32>
+func.func @insert_1d_constant() -> (vector<3xi32>, vector<3xi32>, vector<3xi32>) {
+ %vcst = arith.constant dense<[0, 1, 2]> : vector<3xi32>
+ %icst = arith.constant 9 : i32
+ %a = vector.insert %icst, %vcst[0] : i32 into vector<3xi32>
+ %b = vector.insert %icst, %vcst[1] : i32 into vector<3xi32>
+ %c = vector.insert %icst, %vcst[2] : i32 into vector<3xi32>
+ return %a, %b, %c : vector<3xi32>, vector<3xi32>, vector<3xi32>
+}
+
+// -----
+
+// Both scalar inserts at a full 2-D position and 1-D row inserts into a
+// non-splat 2-D constant fold into new constants.
+// CHECK-LABEL: func.func @insert_2d_constant
+// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<{{\[\[99, 1, 2\], \[3, 4, 5\]\]}}> : vector<2x3xi32>
+// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[0, 1, 2\], \[3, 4, 99\]\]}}> : vector<2x3xi32>
+// CHECK-DAG: %[[CCST:.*]] = arith.constant dense<{{\[\[90, 91, 92\], \[3, 4, 5\]\]}}> : vector<2x3xi32>
+// CHECK-DAG: %[[DCST:.*]] = arith.constant dense<{{\[\[0, 1, 2\], \[90, 91, 92\]\]}}> : vector<2x3xi32>
+// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]]
+func.func @insert_2d_constant() -> (vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>) {
+ %vcst = arith.constant dense<[[0, 1, 2], [3, 4, 5]]> : vector<2x3xi32>
+ %cst_scalar = arith.constant 99 : i32
+ %cst_1d = arith.constant dense<[90, 91, 92]> : vector<3xi32>
+ %a = vector.insert %cst_scalar, %vcst[0, 0] : i32 into vector<2x3xi32>
+ %b = vector.insert %cst_scalar, %vcst[1, 2] : i32 into vector<2x3xi32>
+ %c = vector.insert %cst_1d, %vcst[0] : vector<3xi32> into vector<2x3xi32>
+ %d = vector.insert %cst_1d, %vcst[1] : vector<3xi32> into vector<2x3xi32>
+ return %a, %b, %c, %d : vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>
+}
+
+// -----
+
+// Inserts into a splat destination: re-inserting the splat value keeps the
+// result a splat, while any other scalar or splat-row insert produces a
+// non-splat dense constant.
+// CHECK-LABEL: func.func @insert_2d_splat_constant
+// CHECK-DAG: %[[ACST:.*]] = arith.constant dense<0> : vector<2x3xi32>
+// CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[99, 0, 0\], \[0, 0, 0\]\]}}> : vector<2x3xi32>
+// CHECK-DAG: %[[CCST:.*]] = arith.constant dense<{{\[\[0, 0, 0\], \[0, 99, 0\]\]}}> : vector<2x3xi32>
+// CHECK-DAG: %[[DCST:.*]] = arith.constant dense<{{\[\[33, 33, 33\], \[0, 0, 0\]\]}}> : vector<2x3xi32>
+// CHECK-DAG: %[[ECST:.*]] = arith.constant dense<{{\[\[0, 0, 0\], \[33, 33, 33\]\]}}> : vector<2x3xi32>
+// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]], %[[ECST]]
+func.func @insert_2d_splat_constant()
+ -> (vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>) {
+ %vcst = arith.constant dense<0> : vector<2x3xi32>
+ %cst_zero = arith.constant 0 : i32
+ %cst_scalar = arith.constant 99 : i32
+ %cst_1d = arith.constant dense<33> : vector<3xi32>
+ %a = vector.insert %cst_zero, %vcst[0, 0] : i32 into vector<2x3xi32>
+ %b = vector.insert %cst_scalar, %vcst[0, 0] : i32 into vector<2x3xi32>
+ %c = vector.insert %cst_scalar, %vcst[1, 1] : i32 into vector<2x3xi32>
+ %d = vector.insert %cst_1d, %vcst[0] : vector<3xi32> into vector<2x3xi32>
+ %e = vector.insert %cst_1d, %vcst[1] : vector<3xi32> into vector<2x3xi32>
+ return %a, %b, %c, %d, %e : vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>, vector<2x3xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @insert_element_fold
// CHECK: %[[V:.+]] = arith.constant dense<[0, 1, 7, 3]> : vector<4xi32>
// CHECK: return %[[V]]
More information about the Mlir-commits
mailing list