[Mlir-commits] [mlir] [mlir][tensor] Rewrite tensor.pack as a constant (PR #93954)

Han-Chung Wang llvmlistbot at llvm.org
Fri May 31 10:54:21 PDT 2024


================
@@ -45,9 +48,159 @@ struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
   }
 };
 
+/// Rewrite tensor.pack with arith.constant if the pack is writing
+/// to an empty tensor and the destination shape is static.
+struct PackToConstant : OpRewritePattern<tensor::PackOp> {
+  using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto constOp = packOp.getSource().getDefiningOp<arith::ConstantOp>();
+    if (!constOp)
+      return failure();
+    // Must be a dense constant.
+    auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
+    if (!denseAttr)
+      return failure();
+
+    // Bail out if the pack is used as a writing operation i.e.,
+    // the destination is not a tensor.empty.
+    if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
+      return rewriter.notifyMatchFailure(packOp,
+                                         "expects empty tensor destination");
+    // Pack destination must have static shape.
+    if (!packOp.getDestType().hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          packOp, "expects destination with static shape");
+
+    // Pack with padding is not supported currently.
+    // TODO: Insert padding values as a part of rewrite.
+    if (packOp.getPaddingValue())
+      return rewriter.notifyMatchFailure(packOp, "expects no padding value");
+
+    OpBuilder::InsertionGuard guard(rewriter);
+
+    // If it is a splat constant, rewrite the pack directly.
+    if (denseAttr.isSplat()) {
+      DenseElementsAttr packedDenseShape =
+          denseAttr.reshape(packOp.getDestType());
+      rewriter.setInsertionPoint(constOp);
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(packOp, packedDenseShape);
+
+      return success();
+    }
----------------
hanhanW wrote:

This splat-constant case is already covered by the existing folders, so it does not need to be handled in this pattern; see:

https://github.com/llvm/llvm-project/blob/07bd43945789e3fc8f57d21484a7f683d17166f3/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp#L4278-L4287

https://github.com/llvm/llvm-project/pull/93954


More information about the Mlir-commits mailing list