[Mlir-commits] [mlir] 933fde3 - [mlir][tensor][NFC] Simplify extract_slice(cast) folder

Matthias Springer llvmlistbot at llvm.org
Mon Jul 31 06:13:41 PDT 2023


Author: Matthias Springer
Date: 2023-07-31T15:07:49+02:00
New Revision: 933fde3d1c6217ad6f40f8668832585ceba1929c

URL: https://github.com/llvm/llvm-project/commit/933fde3d1c6217ad6f40f8668832585ceba1929c
DIFF: https://github.com/llvm/llvm-project/commit/933fde3d1c6217ad6f40f8668832585ceba1929c.diff

LOG: [mlir][tensor][NFC] Simplify extract_slice(cast) folder

The result type computation is not needed: the folded extract_slice can be created directly with the original op's result type.

Differential Revision: https://reviews.llvm.org/D156652

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index acd6a7271bce41..614f39b5d29190 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1920,20 +1920,14 @@ class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
     if (!canFoldIntoConsumerOp(castOp))
       return failure();
 
-    /// Deduce the type of the result to use for the canonicalized operation.
+    // Create folded extract.
     Location loc = sliceOp.getLoc();
-    auto sliceOpType = sliceOp.getType();
-    RankedTensorType resultType =
-        ExtractSliceOp::inferCanonicalRankReducedResultType(
-            sliceOpType.getRank(), sliceOp.getSourceType(),
-            sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
-            sliceOp.getMixedStrides());
     Value newResult = rewriter.create<ExtractSliceOp>(
-        loc, resultType, castOp.getSource(), sliceOp.getOffsets(),
+        loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
         sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
         sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
-    if (newResult.getType() != sliceOpType)
-      newResult = rewriter.create<CastOp>(loc, sliceOpType, newResult);
+    if (newResult.getType() != sliceOp.getType())
+      newResult = rewriter.create<CastOp>(loc, sliceOp.getType(), newResult);
     rewriter.replaceOp(sliceOp, newResult);
     return success();
   }


        


More information about the Mlir-commits mailing list