[Mlir-commits] [mlir] [mlir][tensor] Rewrite tensor.pack as a constant (PR #93954)

Han-Chung Wang llvmlistbot at llvm.org
Fri May 31 10:54:21 PDT 2024


================
@@ -45,9 +48,159 @@ struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
   }
 };
 
+/// Rewrite tensor.pack with arith.constant if the pack is writing
+/// to an empty tensor and the destination shape is static.
+struct PackToConstant : OpRewritePattern<tensor::PackOp> {
+  using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto constOp = packOp.getSource().getDefiningOp<arith::ConstantOp>();
+    if (!constOp)
+      return failure();
+    // Must be a dense constant.
+    auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
+    if (!denseAttr)
+      return failure();
+
+    // Bail out if the pack is used as a writing operation i.e.,
+    // the destination is not a tensor.empty.
+    if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
+      return rewriter.notifyMatchFailure(packOp,
+                                         "expects empty tensor destination");
+    // Pack destination must have static shape.
+    if (!packOp.getDestType().hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          packOp, "expects destination with static shape");
+
+    // Pack with padding is not supported currently.
+    // TODO: Insert padding values as a part of rewrite.
+    if (packOp.getPaddingValue())
+      return rewriter.notifyMatchFailure(packOp, "expects no padding value");
+
+    OpBuilder::InsertionGuard guard(rewriter);
+
+    // If it is a splat constant, rewrite the pack directly.
+    if (denseAttr.isSplat()) {
+      DenseElementsAttr packedDenseShape =
+          denseAttr.reshape(packOp.getDestType());
+      rewriter.setInsertionPoint(constOp);
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(packOp, packedDenseShape);
+
+      return success();
+    }
+
+    // Constant contains non-splat dense values.
+    // Move the data into a new packed buffer. Each value is placed into its new
+    // position as defined by the pack operation.
+    ArrayRef<char> srcRawData = denseAttr.getRawData();
+    SmallVector<char> destRawData(srcRawData.size());
+
+    int64_t numberOfElements = denseAttr.getNumElements();
+    SmallVector<int64_t> strides =
+        computeStrides(packOp.getDestType().getShape());
+
+    // Parallelize raw data movement to speedup large constant packing.
+    parallelFor(
+        packOp.getContext(), 0, numberOfElements,
+        [&](size_t destLinearizedIdx) {
+          // Step 1: De-linearize destination index.
+          // f(lin) = tmp[A][B][C]
+          SmallVector<int64_t> destIndices =
+              delinearize(destLinearizedIdx, strides);
+
+          // Step 2: Arrange the indexes based on the packing information.
+          // Compute inverse of outerDimsPerm to bring the loops into the
+          // canonical form tmp[A][B][a][b].
+          if (!packOp.getOuterDimsPerm().empty()) {
+            SmallVector<int64_t> inversePermutation =
+                invertPermutationVector(packOp.getOuterDimsPerm());
+            SmallVector<int64_t> tileLoops;
+            for (int64_t i = 0; i < packOp.getSourceType().getRank(); i++)
+              tileLoops.push_back(destIndices[i]);
+            applyPermutationToVector(tileLoops, inversePermutation);
+
+            SmallVector<int64_t> pointLoops;
+            for (size_t i = packOp.getSourceType().getRank();
+                 i < destIndices.size(); i++) {
+              pointLoops.push_back(destIndices[i]);
+            }
+
+            destIndices = tileLoops;
+            destIndices.append(pointLoops.begin(), pointLoops.end());
+          }
+          assert(destIndices.size() ==
+                 static_cast<size_t>(packOp.getDestType().getRank()));
+
+          // After interchanging the outermost tiled loop we end up in the
+          // canonical form tmp[A][B][a][b]. Squash the point loops with the
+          // tiled ones.
+          llvm::DenseSet<int64_t> tiledLoops(packOp.getInnerDimsPos().begin(),
+                                             packOp.getInnerDimsPos().end());
+          llvm::DenseMap<int64_t, int64_t> mappingTileToPointLoops;
+          // Map the position of the tiled loops with the point one.
+          // For example:
+          // [A][B] -> [A][B][a][b]
+          // entry: [A : 0] [a : 2]
+          // entry: [B : 1] [b : 3]
+          // [A][B] -> [A][B][b]
+          // entry: [B : 1] [b : 2]
+          for (auto [idx, tileLoop] : llvm::enumerate(packOp.getInnerDimsPos()))
+            mappingTileToPointLoops[tileLoop] = idx;
----------------
hanhanW wrote:

I think the comment is off — do we have entries for inner dims? All the values in `getInnerDimsPos` are less than the rank of the source, so we won't have `[a : 2]`, `[b : 3]`, or `[b : 2]` entries. Am I misunderstanding something?

https://github.com/llvm/llvm-project/pull/93954


More information about the Mlir-commits mailing list