[Mlir-commits] [mlir] 9fa61cb - [mlir] Insert tensor.cast only when needed when folding tensor.cast into extract_slice.

Alexander Belyaev llvmlistbot at llvm.org
Mon Feb 27 06:18:22 PST 2023


Author: Alexander Belyaev
Date: 2023-02-27T15:18:01+01:00
New Revision: 9fa61cbb2e8dbf00e9320145d38331b7da8d552f

URL: https://github.com/llvm/llvm-project/commit/9fa61cbb2e8dbf00e9320145d38331b7da8d552f
DIFF: https://github.com/llvm/llvm-project/commit/9fa61cbb2e8dbf00e9320145d38331b7da8d552f.diff

LOG: [mlir] Insert tensor.cast only when needed when folding tensor.cast into extract_slice.

Differential Revision: https://reviews.llvm.org/D144868

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index b6359a935782..b31ef3b98ed0 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1829,7 +1829,7 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
         }))
       return failure();
 
-    auto castOp = sliceOp.getSource().getDefiningOp<tensor::CastOp>();
+    auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
     if (!castOp)
       return failure();
 
@@ -1837,17 +1837,20 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
       return failure();
 
     /// Deduce the type of the result to use for the canonicalized operation.
+    Location loc = sliceOp.getLoc();
+    auto sliceOpType = sliceOp.getType();
     RankedTensorType resultType =
         ExtractSliceOp::inferCanonicalRankReducedResultType(
-            sliceOp.getType().getRank(), sliceOp.getSourceType(),
+            sliceOpType.getRank(), sliceOp.getSourceType(),
             sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
             sliceOp.getMixedStrides());
-    Value newSlice = rewriter.create<ExtractSliceOp>(
-        sliceOp.getLoc(), resultType, castOp.getSource(), sliceOp.getOffsets(),
+    Value newResult = rewriter.create<ExtractSliceOp>(
+        loc, resultType, castOp.getSource(), sliceOp.getOffsets(),
         sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
         sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
-                                                newSlice);
+    if (newResult.getType() != sliceOpType)
+      newResult = rewriter.create<CastOp>(loc, sliceOpType, newResult);
+    rewriter.replaceOp(sliceOp, newResult);
     return success();
   }
 };


        


More information about the Mlir-commits mailing list