[Mlir-commits] [mlir] 0e72d00 - [mlir][vector] Constant fold sub-vector extraction

Jakub Kuderski llvmlistbot at llvm.org
Fri Nov 25 10:40:52 PST 2022


Author: Jakub Kuderski
Date: 2022-11-25T13:39:45-05:00
New Revision: 0e72d00d1942a6aebf67efef47f0fda2437ce7ae

URL: https://github.com/llvm/llvm-project/commit/0e72d00d1942a6aebf67efef47f0fda2437ce7ae
DIFF: https://github.com/llvm/llvm-project/commit/0e72d00d1942a6aebf67efef47f0fda2437ce7ae.diff

LOG: [mlir][vector] Constant fold sub-vector extraction

This generalizes the existing fold for `ExtractOp(non-splat constant)`
to work with vector results. The vector case is handled by extracting
a subrange of the attribute array.

My main use is to clean up code generated by the Wide Integer Emulation
pass.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D138690

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 22d7bdc3542e4..b71c2a0f06112 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -23,6 +23,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -1623,24 +1624,33 @@ class ExtractOpScalarVectorConstantFolder final
       return failure();
 
     auto vecTy = sourceVector.getType().cast<VectorType>();
-    ArrayAttr positions = extractOp.getPosition();
     if (vecTy.isScalable())
       return failure();
-    // Do not allow extracting sub-vectors to limit the size of the generated
-    // constants.
-    if (vecTy.getRank() != static_cast<int64_t>(positions.size()))
-      return failure();
 
     // The splat case is handled by `ExtractOpSplatConstantFolder`.
     auto dense = vectorCst.dyn_cast<DenseElementsAttr>();
     if (!dense || dense.isSplat())
       return failure();
 
-    // Calculate the linearized position.
-    int64_t elemPosition =
-        linearize(getI64SubArray(positions), computeStrides(vecTy.getShape()));
-    Attribute elementValue = *(dense.value_begin<Attribute>() + elemPosition);
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, elementValue);
+    // Calculate the linearized position of the contiguous chunk of elements to
+    // extract.
+    llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
+    llvm::copy(getI64SubArray(extractOp.getPosition()),
+               completePositions.begin());
+    int64_t elemBeginPosition =
+        linearize(completePositions, computeStrides(vecTy.getShape()));
+    auto denseValuesBegin = dense.value_begin<Attribute>() + elemBeginPosition;
+
+    Attribute newAttr;
+    if (auto resVecTy = extractOp.getType().dyn_cast<VectorType>()) {
+      SmallVector<Attribute> elementValues(
+          denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
+      newAttr = DenseElementsAttr::get(resVecTy, elementValues);
+    } else {
+      newAttr = *denseValuesBegin;
+    }
+
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 3f3a35eb52b0d..eb1fb247bb269 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1471,6 +1471,19 @@ func.func @extract_2d_constant() -> (i32, i32, i32, i32) {
 
 // -----
 
+// CHECK-LABEL: func.func @extract_vector_2d_constant
+//   CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xi32>
+//   CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[3, 4, 5]> : vector<3xi32>
+//  CHECK-NEXT: return %[[ACST]], %[[BCST]] : vector<3xi32>, vector<3xi32>
+func.func @extract_vector_2d_constant() -> (vector<3xi32>, vector<3xi32>) {
+  %cst = arith.constant dense<[[0, 1, 2], [3, 4, 5]]> : vector<2x3xi32>
+  %a = vector.extract %cst[0] : vector<2x3xi32>
+  %b = vector.extract %cst[1] : vector<2x3xi32>
+  return %a, %b : vector<3xi32>, vector<3xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @extract_3d_constant
 //   CHECK-DAG: %[[ACST:.*]] = arith.constant 0 : i32
 //   CHECK-DAG: %[[BCST:.*]] = arith.constant 1 : i32
@@ -1488,6 +1501,38 @@ func.func @extract_3d_constant() -> (i32, i32, i32, i32) {
 
 // -----
 
+// CHECK-LABEL: func.func @extract_vector_3d_constant
+//   CHECK-DAG: %[[ACST:.*]] = arith.constant dense<{{\[\[0, 1\], \[2, 3\], \[4, 5\]\]}}> : vector<3x2xi32>
+//   CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[6, 7\], \[8, 9\], \[10, 11\]\]}}> : vector<3x2xi32>
+//   CHECK-DAG: %[[CCST:.*]] = arith.constant dense<[8, 9]> : vector<2xi32>
+//   CHECK-DAG: %[[DCST:.*]] = arith.constant dense<[10, 11]> : vector<2xi32>
+//  CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]] : vector<3x2xi32>, vector<3x2xi32>, vector<2xi32>, vector<2xi32>
+func.func @extract_vector_3d_constant() -> (vector<3x2xi32>, vector<3x2xi32>, vector<2xi32>, vector<2xi32>) {
+  %cst = arith.constant dense<[[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]]> : vector<2x3x2xi32>
+  %a = vector.extract %cst[0] : vector<2x3x2xi32>
+  %b = vector.extract %cst[1] : vector<2x3x2xi32>
+  %c = vector.extract %cst[1, 1] : vector<2x3x2xi32>
+  %d = vector.extract %cst[1, 2] : vector<2x3x2xi32>
+  return %a, %b, %c, %d : vector<3x2xi32>, vector<3x2xi32>, vector<2xi32>, vector<2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_splat_vector_3d_constant
+//   CHECK-DAG: %[[ACST:.*]] = arith.constant dense<0> : vector<2xi32>
+//   CHECK-DAG: %[[BCST:.*]] = arith.constant dense<4> : vector<2xi32>
+//   CHECK-DAG: %[[CCST:.*]] = arith.constant dense<5> : vector<2xi32>
+//  CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]] : vector<2xi32>, vector<2xi32>, vector<2xi32>
+func.func @extract_splat_vector_3d_constant() -> (vector<2xi32>, vector<2xi32>, vector<2xi32>) {
+  %cst = arith.constant dense<[[[0, 0], [1, 1], [2, 2]], [[3, 3], [4, 4], [5, 5]]]> : vector<2x3x2xi32>
+  %a = vector.extract %cst[0, 0] : vector<2x3x2xi32>
+  %b = vector.extract %cst[1, 1] : vector<2x3x2xi32>
+  %c = vector.extract %cst[1, 2] : vector<2x3x2xi32>
+  return %a, %b, %c : vector<2xi32>, vector<2xi32>, vector<2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: extract_extract_strided
 //  CHECK-SAME: %[[A:.*]]: vector<32x16x4xf16>
 //       CHECK: %[[V:.*]] = vector.extract %[[A]][9, 7] : vector<32x16x4xf16>


        


More information about the Mlir-commits mailing list