[Mlir-commits] [mlir] 8c2ea14 - [mlir][vector] Fold scalar vector.extract of non-splat n-D constants

Jakub Kuderski llvmlistbot at llvm.org
Tue Sep 13 17:31:28 PDT 2022


Author: Jakub Kuderski
Date: 2022-09-13T20:30:50-04:00
New Revision: 8c2ea14436c8232fa2e496122cb1d9349b1d8737

URL: https://github.com/llvm/llvm-project/commit/8c2ea14436c8232fa2e496122cb1d9349b1d8737
DIFF: https://github.com/llvm/llvm-project/commit/8c2ea14436c8232fa2e496122cb1d9349b1d8737.diff

LOG: [mlir][vector] Fold scalar vector.extract of non-splat n-D constants

Add a new pattern to fold `vector.extract` over n-D constants that extract scalars.
The previous code handled n-D splat constants only. The new pattern is conservative and does not handle extracting sub-vectors from non-splat constants.

This is to aid the `arith::EmulateWideInt` pass which emits a lot of 2-element vector constants.

Reviewed By: Mogball, dcaballe

Differential Revision: https://reviews.llvm.org/D133742

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5e1b95ee29070..37725cf3c90bf 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1534,21 +1534,22 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
 };
 
 // Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
-class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> {
+class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
-    // Return if 'extractStridedSliceOp' operand is not defined by a
+    // Return if 'ExtractOp' operand is not defined by a splat vector
     // ConstantOp.
-    auto constantOp = extractOp.getVector().getDefiningOp<arith::ConstantOp>();
-    if (!constantOp)
+    Value sourceVector = extractOp.getVector();
+    Attribute vectorCst;
+    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
       return failure();
-    auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
-    if (!dense)
+    auto splat = vectorCst.dyn_cast<SplatElementsAttr>();
+    if (!splat)
       return failure();
-    Attribute newAttr = dense.getSplatValue<Attribute>();
+    Attribute newAttr = splat.getSplatValue<Attribute>();
     if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>())
       newAttr = DenseElementsAttr::get(vecDstType, newAttr);
     rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
@@ -1556,11 +1557,72 @@ class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> {
   }
 };
 
+// Pattern to rewrite an ExtractOp(vector<...xT> ConstantOp)[...] -> ConstantOp,
+// where the position array specifies a scalar element.
+//
+// Only scalar extraction from non-splat, fixed-size dense constants with
+// integer, index, or floating-point elements is folded; splat constants are
+// left to `ExtractOpSplatConstantFolder`.
+class ExtractOpScalarVectorConstantFolder final
+    : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Return if 'ExtractOp' operand is not defined by a compatible vector
+    // ConstantOp.
+    Value sourceVector = extractOp.getVector();
+    Attribute vectorCst;
+    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
+      return failure();
+
+    auto vecTy = sourceVector.getType().cast<VectorType>();
+    Type elemTy = vecTy.getElementType();
+    ArrayAttr positions = extractOp.getPosition();
+    if (vecTy.isScalable())
+      return failure();
+    // Do not allow extracting sub-vectors to limit the size of the generated
+    // constants.
+    if (vecTy.getRank() != static_cast<int64_t>(positions.size()))
+      return failure();
+    // TODO: Handle more element types, e.g., complex values.
+    if (!elemTy.isIntOrIndexOrFloat())
+      return failure();
+
+    // The splat case is handled by `ExtractOpSplatConstantFolder`.
+    auto dense = vectorCst.dyn_cast<DenseElementsAttr>();
+    if (!dense || dense.isSplat())
+      return failure();
+
+    // Calculate the flattened (row-major) position of the element.
+    int64_t elemPosition = 0;
+    int64_t innerElems = 1;
+    for (auto [dimSize, positionInDim] :
+         llvm::reverse(llvm::zip(vecTy.getShape(), positions))) {
+      int64_t positionVal = positionInDim.cast<IntegerAttr>().getInt();
+      elemPosition += positionVal * innerElems;
+      innerElems *= dimSize;
+    }
+
+    // Index directly into the element range instead of copying every element
+    // into a temporary vector: the attribute element iterators are indexed
+    // (random access), so one folded extract no longer costs O(#elements).
+    // For int/index/float element types the element attribute is an
+    // IntegerAttr/FloatAttr of `elemTy`, which is the scalar result type.
+    Attribute newAttr =
+        *std::next(dense.getValues<Attribute>().begin(), elemPosition);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpConstantFolder, ExtractOpFromBroadcast>(context);
+  results.add<ExtractOpSplatConstantFolder, ExtractOpScalarVectorConstantFolder,
+              ExtractOpFromBroadcast>(context);
 }
 
 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 6fe6c2776f563..ac5b857938581 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1367,11 +1367,11 @@ func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
 
 // -----
 
-// CHECK-LABEL: extract_constant
-//       CHECK-DAG: %[[CST1:.*]] = arith.constant 1 : i32
-//       CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<7xf32>
-//       CHECK: return %[[CST0]], %[[CST1]] : vector<7xf32>, i32
-func.func @extract_constant() -> (vector<7xf32>, i32) {
+// CHECK-LABEL: func.func @extract_splat_constant
+//   CHECK-DAG:   %[[CST1:.*]] = arith.constant 1 : i32
+//   CHECK-DAG:   %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<7xf32>
+//  CHECK-NEXT:   return %[[CST0]], %[[CST1]] : vector<7xf32>, i32
+func.func @extract_splat_constant() -> (vector<7xf32>, i32) {
   %cst = arith.constant dense<2.000000e+00> : vector<29x7xf32>
   %cst_1 = arith.constant dense<1> : vector<4x37x9xi32>
   %0 = vector.extract %cst[2] : vector<29x7xf32>
@@ -1381,6 +1381,71 @@ func.func @extract_constant() -> (vector<7xf32>, i32) {
 
 // -----
 
+// CHECK-LABEL: func.func @extract_1d_constant
+//   CHECK-DAG: %[[I32CST:.*]] = arith.constant 3 : i32
+//   CHECK-DAG: %[[IDXCST:.*]] = arith.constant 1 : index
+//   CHECK-DAG: %[[F32CST:.*]] = arith.constant 2.000000e+00 : f32
+//  CHECK-NEXT: return %[[I32CST]], %[[IDXCST]], %[[F32CST]] : i32, index, f32
+func.func @extract_1d_constant() -> (i32, index, f32) {
+  %icst = arith.constant dense<[1, 2, 3, 4]> : vector<4xi32>
+  %e = vector.extract %icst[2] : vector<4xi32>
+  %idx_cst = arith.constant dense<[0, 1, 2]> : vector<3xindex>
+  %f = vector.extract %idx_cst[1] : vector<3xindex>
+  %fcst = arith.constant dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<3xf32>
+  %g = vector.extract %fcst[0] : vector<3xf32>
+  return %e, %f, %g : i32, index, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_2d_constant
+//   CHECK-DAG: %[[ACST:.*]] = arith.constant 0 : i32
+//   CHECK-DAG: %[[BCST:.*]] = arith.constant 2 : i32
+//   CHECK-DAG: %[[CCST:.*]] = arith.constant 3 : i32
+//   CHECK-DAG: %[[DCST:.*]] = arith.constant 5 : i32
+//  CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]] : i32, i32, i32, i32
+func.func @extract_2d_constant() -> (i32, i32, i32, i32) {
+  %cst = arith.constant dense<[[0, 1, 2], [3, 4, 5]]> : vector<2x3xi32>
+  %a = vector.extract %cst[0, 0] : vector<2x3xi32>
+  %b = vector.extract %cst[0, 2] : vector<2x3xi32>
+  %c = vector.extract %cst[1, 0] : vector<2x3xi32>
+  %d = vector.extract %cst[1, 2] : vector<2x3xi32>
+  return %a, %b, %c, %d : i32, i32, i32, i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_3d_constant
+//   CHECK-DAG: %[[ACST:.*]] = arith.constant 0 : i32
+//   CHECK-DAG: %[[BCST:.*]] = arith.constant 1 : i32
+//   CHECK-DAG: %[[CCST:.*]] = arith.constant 9 : i32
+//   CHECK-DAG: %[[DCST:.*]] = arith.constant 10 : i32
+//  CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]] : i32, i32, i32, i32
+func.func @extract_3d_constant() -> (i32, i32, i32, i32) {
+  %cst = arith.constant dense<[[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]]> : vector<2x3x2xi32>
+  %a = vector.extract %cst[0, 0, 0] : vector<2x3x2xi32>
+  %b = vector.extract %cst[0, 0, 1] : vector<2x3x2xi32>
+  %c = vector.extract %cst[1, 1, 1] : vector<2x3x2xi32>
+  %d = vector.extract %cst[1, 2, 0] : vector<2x3x2xi32>
+  return %a, %b, %c, %d : i32, i32, i32, i32
+}
+
+// -----
+
+// Extracting a sub-vector from a non-splat constant is intentionally not
+// folded, to limit the size of the generated constants.
+// CHECK-LABEL: func.func @extract_vector_from_non_splat_constant
+//       CHECK:   %[[CST:.*]] = arith.constant dense<{{.*}}> : vector<2x3xi32>
+//       CHECK:   %[[V:.*]] = vector.extract %[[CST]][1] : vector<2x3xi32>
+//       CHECK:   return %[[V]] : vector<3xi32>
+func.func @extract_vector_from_non_splat_constant() -> vector<3xi32> {
+  %cst = arith.constant dense<[[0, 1, 2], [3, 4, 5]]> : vector<2x3xi32>
+  %0 = vector.extract %cst[1] : vector<2x3xi32>
+  return %0 : vector<3xi32>
+}
+
+// -----
+
 // CHECK-LABEL: extract_extract_strided
 //  CHECK-SAME: %[[A:.*]]: vector<32x16x4xf16>
 //       CHECK: %[[V:.*]] = vector.extract %[[A]][9, 7] : vector<32x16x4xf16>


        


More information about the Mlir-commits mailing list