[Mlir-commits] [mlir] 8c2ea14 - [mlir][vector] Fold scalar vector.extract of non-splat n-D constants
Jakub Kuderski
llvmlistbot at llvm.org
Tue Sep 13 17:31:28 PDT 2022
Author: Jakub Kuderski
Date: 2022-09-13T20:30:50-04:00
New Revision: 8c2ea14436c8232fa2e496122cb1d9349b1d8737
URL: https://github.com/llvm/llvm-project/commit/8c2ea14436c8232fa2e496122cb1d9349b1d8737
DIFF: https://github.com/llvm/llvm-project/commit/8c2ea14436c8232fa2e496122cb1d9349b1d8737.diff
LOG: [mlir][vector] Fold scalar vector.extract of non-splat n-D constants
Add a new pattern to fold `vector.extract` over n-D constants that extract scalars.
The previous code handled only n-D splat constants. The new pattern is conservative and does not handle sub-vector constants.
This is to aid the `arith::EmulateWideInt` pass which emits a lot of 2-element vector constants.
Reviewed By: Mogball, dcaballe
Differential Revision: https://reviews.llvm.org/D133742
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5e1b95ee29070..37725cf3c90bf 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1534,21 +1534,22 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
};
// Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
+// The source is matched with `m_Constant` rather than only `arith.constant`.
+// A scalar result takes the splat value; a vector result is re-splatted.
-class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> {
+class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {
- // Return if 'extractStridedSliceOp' operand is not defined by a
+ // Return if 'ExtractOp' operand is not defined by a splat vector
// ConstantOp.
- auto constantOp = extractOp.getVector().getDefiningOp<arith::ConstantOp>();
- if (!constantOp)
+ Value sourceVector = extractOp.getVector();
+ Attribute vectorCst;
+ if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
return failure();
- auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
- if (!dense)
+ auto splat = vectorCst.dyn_cast<SplatElementsAttr>();
+ if (!splat)
return failure();
- Attribute newAttr = dense.getSplatValue<Attribute>();
+ Attribute newAttr = splat.getSplatValue<Attribute>();
+ // Sub-vector extraction: rebuild a splat of the result vector type.
if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>())
newAttr = DenseElementsAttr::get(vecDstType, newAttr);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
@@ -1556,11 +1557,71 @@ class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> {
}
};
+// Pattern to rewrite an ExtractOp(vector<...xT> ConstantOp)[...] -> ConstantOp,
+// where the position array specifies a scalar element. Splat sources are left
+// to `ExtractOpSplatConstantFolder`; sub-vector extractions are not folded.
+class ExtractOpScalarVectorConstantFolder final
+    : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Return if 'ExtractOp' operand is not defined by a compatible vector
+    // ConstantOp.
+    Value sourceVector = extractOp.getVector();
+    Attribute vectorCst;
+    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
+      return failure();
+
+    auto vecTy = sourceVector.getType().cast<VectorType>();
+    if (vecTy.isScalable())
+      return failure();
+
+    // Do not allow extracting sub-vectors to limit the size of the generated
+    // constants.
+    ArrayAttr positions = extractOp.getPosition();
+    if (vecTy.getRank() != static_cast<int64_t>(positions.size()))
+      return failure();
+
+    // TODO: Handle more element types, e.g., complex values.
+    if (!vecTy.getElementType().isIntOrIndexOrFloat())
+      return failure();
+
+    // The splat case is handled by `ExtractOpSplatConstantFolder`.
+    auto dense = vectorCst.dyn_cast<DenseElementsAttr>();
+    if (!dense || dense.isSplat())
+      return failure();
+
+    // Linearize the position in row-major order: walk dimensions from the
+    // innermost outwards, scaling each index by the number of elements in
+    // the dimensions inside it.
+    int64_t elemPosition = 0;
+    int64_t innerElems = 1;
+    for (auto [dimSize, positionInDim] :
+         llvm::reverse(llvm::zip(vecTy.getShape(), positions))) {
+      int64_t positionVal = positionInDim.cast<IntegerAttr>().getInt();
+      elemPosition += positionVal * innerElems;
+      innerElems *= dimSize;
+    }
+
+    // Index straight into the dense values instead of copying all of them
+    // into a temporary vector. The element iterator is random-access, so this
+    // is O(1) per folded extract rather than O(#elements). For int/index/float
+    // element types the resulting IntegerAttr/FloatAttr already carries the
+    // element type, which is this extract's scalar result type.
+    auto values = dense.getValues<Attribute>();
+    Attribute newAttr = *std::next(values.begin(), elemPosition);
+
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
+    return success();
+  }
+};
+
} // namespace
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ExtractOpConstantFolder, ExtractOpFromBroadcast>(context);
+ // The two constant folders are disjoint: the scalar-element folder bails
+ // out on splat constants, which the splat folder handles.
+ results.add<ExtractOpSplatConstantFolder, ExtractOpScalarVectorConstantFolder,
+ ExtractOpFromBroadcast>(context);
}
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 6fe6c2776f563..ac5b857938581 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1367,11 +1367,11 @@ func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
// -----
-// CHECK-LABEL: extract_constant
-// CHECK-DAG: %[[CST1:.*]] = arith.constant 1 : i32
-// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<7xf32>
-// CHECK: return %[[CST0]], %[[CST1]] : vector<7xf32>, i32
-func.func @extract_constant() -> (vector<7xf32>, i32) {
+// CHECK-LABEL: func.func @extract_splat_constant
+// CHECK-DAG: %[[CST1:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<7xf32>
+// CHECK-NEXT: return %[[CST0]], %[[CST1]] : vector<7xf32>, i32
+func.func @extract_splat_constant() -> (vector<7xf32>, i32) {
%cst = arith.constant dense<2.000000e+00> : vector<29x7xf32>
%cst_1 = arith.constant dense<1> : vector<4x37x9xi32>
%0 = vector.extract %cst[2] : vector<29x7xf32>
@@ -1381,6 +1381,57 @@ func.func @extract_constant() -> (vector<7xf32>, i32) {
// -----
+// Scalar extraction from non-splat 1-D constants folds for i32, index, and
+// f32 element types.
+// CHECK-LABEL: func.func @extract_1d_constant
+// CHECK-DAG: %[[I32CST:.*]] = arith.constant 3 : i32
+// CHECK-DAG: %[[IDXCST:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[F32CST:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-NEXT: return %[[I32CST]], %[[IDXCST]], %[[F32CST]] : i32, index, f32
+func.func @extract_1d_constant() -> (i32, index, f32) {
+ %icst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
+ %e = vector.extract %icst[2] : vector<4xi32>
+ %idx_cst = arith.constant dense<[0, 1, 2]> : vector<3xindex>
+ %f = vector.extract %idx_cst[1] : vector<3xindex>
+ %fcst = arith.constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<3xf32>
+ %g = vector.extract %fcst[0] : vector<3xf32>
+ return %e, %f, %g : i32, index, f32
+}
+
+// -----
+
+// Scalar extraction from a non-splat 2-D constant, covering the first and
+// last element of each row.
+// CHECK-LABEL: func.func @extract_2d_constant
+// CHECK-DAG: %[[ACST:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[BCST:.*]] = arith.constant 2 : i32
+// CHECK-DAG: %[[CCST:.*]] = arith.constant 3 : i32
+// CHECK-DAG: %[[DCST:.*]] = arith.constant 5 : i32
+// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]] : i32, i32, i32, i32
+func.func @extract_2d_constant() -> (i32, i32, i32, i32) {
+ %cst = arith.constant dense<[[0, 1, 2], [3, 4, 5]]> : vector<2x3xi32>
+ %a = vector.extract %cst[0, 0] : vector<2x3xi32>
+ %b = vector.extract %cst[0, 2] : vector<2x3xi32>
+ %c = vector.extract %cst[1, 0] : vector<2x3xi32>
+ %d = vector.extract %cst[1, 2] : vector<2x3xi32>
+ return %a, %b, %c, %d : i32, i32, i32, i32
+}
+
+// -----
+
+// Scalar extraction from a non-splat 3-D constant; each folded value is the
+// element of the constant at the given [i, j, k] position.
+// CHECK-LABEL: func.func @extract_3d_constant
+// CHECK-DAG: %[[ACST:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[BCST:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[CCST:.*]] = arith.constant 9 : i32
+// CHECK-DAG: %[[DCST:.*]] = arith.constant 10 : i32
+// CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]] : i32, i32, i32, i32
+func.func @extract_3d_constant() -> (i32, i32, i32, i32) {
+ %cst = arith.constant dense<[[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]]> : vector<2x3x2xi32>
+ %a = vector.extract %cst[0, 0, 0] : vector<2x3x2xi32>
+ %b = vector.extract %cst[0, 0, 1] : vector<2x3x2xi32>
+ %c = vector.extract %cst[1, 1, 1] : vector<2x3x2xi32>
+ %d = vector.extract %cst[1, 2, 0] : vector<2x3x2xi32>
+ return %a, %b, %c, %d : i32, i32, i32, i32
+}
+
+// -----
+
// CHECK-LABEL: extract_extract_strided
// CHECK-SAME: %[[A:.*]]: vector<32x16x4xf16>
// CHECK: %[[V:.*]] = vector.extract %[[A]][9, 7] : vector<32x16x4xf16>
More information about the Mlir-commits
mailing list