[Mlir-commits] [mlir] cb40d52 - [mlir][Linalg] Avoid using `tensor.cast` by default while folding `fill` with `pad`.

Mahesh Ravishankar llvmlistbot at llvm.org
Fri Nov 11 15:17:19 PST 2022


Author: Mahesh Ravishankar
Date: 2022-11-11T23:17:07Z
New Revision: cb40d5291e6113fddcfa5fc3861b7320f78dbfe8

URL: https://github.com/llvm/llvm-project/commit/cb40d5291e6113fddcfa5fc3861b7320f78dbfe8
DIFF: https://github.com/llvm/llvm-project/commit/cb40d5291e6113fddcfa5fc3861b7320f78dbfe8.diff

LOG: [mlir][Linalg] Avoid using `tensor.cast` by default while folding `fill` with `pad`.

This cast is unnecessary if the type of the generated operation already
matches the type of the replaced value. Also use `OpFoldResult` to reduce
the number of cases in which a cast is needed.

Reviewed By: springerm, hanchung, antiagainst

Differential Revision: https://reviews.llvm.org/D137479

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5ebb81bed132d..32c8dd65678bf 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -496,17 +496,20 @@ struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
       return rewriter.notifyMatchFailure(
           padOp, "failed to reify tensor.pad op result shape");
 
-    auto oldResultType = padOp.getResultType();
-    SmallVector<int64_t, 4> staticShape(oldResultType.getRank(),
-                                        ShapedType::kDynamicSize);
+    SmallVector<OpFoldResult> newShape =
+        getAsOpFoldResult(reifiedShape.front());
     auto emptyTensor = rewriter.create<tensor::EmptyOp>(
-        padOp.getLoc(), staticShape, oldResultType.getElementType(),
-        reifiedShape.front());
-    auto newFillOp = rewriter.create<FillOp>(
-        fillOp.getLoc(), ValueRange{padValue}, ValueRange{emptyTensor});
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldResultType,
-                                                newFillOp.result());
-
+        padOp.getLoc(), newShape, padOp.getResultType().getElementType());
+    Value replacement =
+        rewriter
+            .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
+                            ValueRange{emptyTensor})
+            .getResult(0);
+    if (replacement.getType() != padOp.getResultType()) {
+      replacement = rewriter.create<tensor::CastOp>(
+          fillOp.getLoc(), padOp.getResultType(), replacement);
+    }
+    rewriter.replaceOp(padOp, replacement);
     return success();
   }
 };
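
The sketch below models, outside MLIR, why building the empty tensor from `OpFoldResult` sizes makes the trailing `tensor.cast` unnecessary in more cases. It assumes that constant reified sizes fold to static integers and that the empty tensor's static shape is then read off those sizes. `kDynamic`, `SizeValue`, `FoldResult`, `asFoldResults`, `staticShapeOf`, `allDynamicShape` and `needsCast` are all hypothetical stand-ins for illustration. They are not MLIR APIs, and this is not the pattern's actual implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for ShapedType's "dynamic size" sentinel.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical model of an SSA `index` value from shape reification;
// `constant` is set when the value is produced by a constant op.
struct SizeValue {
  std::string name;
  std::optional<int64_t> constant;
};

// Hypothetical stand-in for mlir::OpFoldResult: either a static integer
// (the attribute case) or a reference to a dynamic SSA value.
using FoldResult = std::variant<int64_t, SizeValue>;

// Models the role of getAsOpFoldResult: constants become static integers,
// everything else stays a dynamic value.
std::vector<FoldResult> asFoldResults(const std::vector<SizeValue>& sizes) {
  std::vector<FoldResult> out;
  for (const SizeValue& s : sizes) {
    if (s.constant)
      out.emplace_back(*s.constant);
    else
      out.emplace_back(s);
  }
  return out;
}

// New behavior (modeled): static entries give static dims, dynamic entries
// give kDynamic.
std::vector<int64_t> staticShapeOf(const std::vector<FoldResult>& sizes) {
  std::vector<int64_t> shape;
  for (const FoldResult& r : sizes)
    shape.push_back(std::holds_alternative<int64_t>(r) ? std::get<int64_t>(r)
                                                       : kDynamic);
  return shape;
}

// Old behavior (modeled): every dim is dynamic, whatever the sizes are.
std::vector<int64_t> allDynamicShape(std::size_t rank) {
  return std::vector<int64_t>(rank, kDynamic);
}

// Mirrors the new guard: cast only when the produced type differs from
// the tensor.pad result type.
bool needsCast(const std::vector<int64_t>& produced,
               const std::vector<int64_t>& padResultShape) {
  return produced != padResultShape;
}
```

Consider a pad result of type `tensor<8x?xf32>` whose reified sizes are a constant `8` and an unknown value. The old all-dynamic shape differs from it, so a cast is always emitted. The shape derived from the folded sizes matches, so no cast is needed. A cast remains necessary only when a size that is static in the pad result type does not fold to a constant.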