[Mlir-commits] [mlir] 5767e4d - [MLIR][NFC] Return MemRefType in memref.subview return type inference functions (#120024)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 14 04:58:24 PST 2025


Author: Tomás Longeri
Date: 2025-02-14T12:58:20Z
New Revision: 5767e4d4ca55cbd21dfbb13825993295146b6302

URL: https://github.com/llvm/llvm-project/commit/5767e4d4ca55cbd21dfbb13825993295146b6302
DIFF: https://github.com/llvm/llvm-project/commit/5767e4d4ca55cbd21dfbb13825993295146b6302.diff

LOG: [MLIR][NFC] Return MemRefType in memref.subview return type inference functions (#120024)

Avoids the need for a cast, and matches the extra build functions, which
take a `MemRefType`.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
    mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c3ee3968abc16..4c8a214049ea9 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -2081,14 +2081,14 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
     /// A subview result type can be fully inferred from the source type and the
     /// static representation of offsets, sizes and strides. Special sentinels
     /// encode the dynamic case.
-    static Type inferResultType(MemRefType sourceMemRefType,
-                                ArrayRef<int64_t> staticOffsets,
-                                ArrayRef<int64_t> staticSizes,
-                                ArrayRef<int64_t> staticStrides);
-    static Type inferResultType(MemRefType sourceMemRefType,
-                                ArrayRef<OpFoldResult> staticOffsets,
-                                ArrayRef<OpFoldResult> staticSizes,
-                                ArrayRef<OpFoldResult> staticStrides);
+    static MemRefType inferResultType(MemRefType sourceMemRefType,
+                                      ArrayRef<int64_t> staticOffsets,
+                                      ArrayRef<int64_t> staticSizes,
+                                      ArrayRef<int64_t> staticStrides);
+    static MemRefType inferResultType(MemRefType sourceMemRefType,
+                                      ArrayRef<OpFoldResult> staticOffsets,
+                                      ArrayRef<OpFoldResult> staticSizes,
+                                      ArrayRef<OpFoldResult> staticStrides);
 
     /// A rank-reducing result type can be inferred from the desired result
     /// shape. Only the layout map is inferred.
@@ -2097,16 +2097,16 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
     /// and the desired sizes. In case there are more "ones" among the sizes
     /// than the difference in source/result rank, it is not clear which dims of
     /// size one should be dropped.
-    static Type inferRankReducedResultType(ArrayRef<int64_t> resultShape,
-                                           MemRefType sourceMemRefType,
-                                           ArrayRef<int64_t> staticOffsets,
-                                           ArrayRef<int64_t> staticSizes,
-                                           ArrayRef<int64_t> staticStrides);
-    static Type inferRankReducedResultType(ArrayRef<int64_t> resultShape,
-                                           MemRefType sourceMemRefType,
-                                           ArrayRef<OpFoldResult> staticOffsets,
-                                           ArrayRef<OpFoldResult> staticSizes,
-                                           ArrayRef<OpFoldResult> staticStrides);
+    static MemRefType inferRankReducedResultType(
+        ArrayRef<int64_t> resultShape, MemRefType sourceMemRefType,
+        ArrayRef<int64_t> staticOffsets,
+        ArrayRef<int64_t> staticSizes,
+        ArrayRef<int64_t> staticStrides);
+    static MemRefType inferRankReducedResultType(
+        ArrayRef<int64_t> resultShape, MemRefType sourceMemRefType,
+        ArrayRef<OpFoldResult> staticOffsets,
+        ArrayRef<OpFoldResult> staticSizes,
+        ArrayRef<OpFoldResult> staticStrides);
 
     /// Return the expected rank of each of the`static_offsets`, `static_sizes`
     /// and `static_strides` attributes.

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e0930abc1887d..11597505e7888 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2702,10 +2702,10 @@ void SubViewOp::getAsmResultNames(
 /// A subview result type can be fully inferred from the source type and the
 /// static representation of offsets, sizes and strides. Special sentinels
 /// encode the dynamic case.
-Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
-                                ArrayRef<int64_t> staticOffsets,
-                                ArrayRef<int64_t> staticSizes,
-                                ArrayRef<int64_t> staticStrides) {
+MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
+                                      ArrayRef<int64_t> staticOffsets,
+                                      ArrayRef<int64_t> staticSizes,
+                                      ArrayRef<int64_t> staticStrides) {
   unsigned rank = sourceMemRefType.getRank();
   (void)rank;
   assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
@@ -2744,10 +2744,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                          sourceMemRefType.getMemorySpace());
 }
 
-Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
-                                ArrayRef<OpFoldResult> offsets,
-                                ArrayRef<OpFoldResult> sizes,
-                                ArrayRef<OpFoldResult> strides) {
+MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
+                                      ArrayRef<OpFoldResult> offsets,
+                                      ArrayRef<OpFoldResult> sizes,
+                                      ArrayRef<OpFoldResult> strides) {
   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -2763,13 +2763,12 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                     staticSizes, staticStrides);
 }
 
-Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
-                                           MemRefType sourceRankedTensorType,
-                                           ArrayRef<int64_t> offsets,
-                                           ArrayRef<int64_t> sizes,
-                                           ArrayRef<int64_t> strides) {
-  auto inferredType = llvm::cast<MemRefType>(
-      inferResultType(sourceRankedTensorType, offsets, sizes, strides));
+MemRefType SubViewOp::inferRankReducedResultType(
+    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
+    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
+    ArrayRef<int64_t> strides) {
+  MemRefType inferredType =
+      inferResultType(sourceRankedTensorType, offsets, sizes, strides);
   assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
          "expected ");
   if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
@@ -2795,11 +2794,10 @@ Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
                          inferredType.getMemorySpace());
 }
 
-Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
-                                           MemRefType sourceRankedTensorType,
-                                           ArrayRef<OpFoldResult> offsets,
-                                           ArrayRef<OpFoldResult> sizes,
-                                           ArrayRef<OpFoldResult> strides) {
+MemRefType SubViewOp::inferRankReducedResultType(
+    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
+    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+    ArrayRef<OpFoldResult> strides) {
   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -2826,8 +2824,8 @@ void SubViewOp::build(OpBuilder &b, OperationState &result,
   auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
   // Structuring implementation this way avoids duplication between builders.
   if (!resultType) {
-    resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
-        sourceMemRefType, staticOffsets, staticSizes, staticStrides));
+    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
+                                            staticSizes, staticStrides);
   }
   result.addAttributes(attrs);
   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
@@ -2992,8 +2990,8 @@ LogicalResult SubViewOp::verify() {
 
   // Compute the expected result type, assuming that there are no rank
   // reductions.
-  auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
-      baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));
+  MemRefType expectedType = SubViewOp::inferResultType(
+      baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());
 
   // Verify all properties of a shaped type: rank, element type and dimension
   // sizes. This takes into account potential rank reductions.
@@ -3075,8 +3073,8 @@ static MemRefType getCanonicalSubViewResultType(
     MemRefType currentResultType, MemRefType currentSourceType,
     MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
     ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
-  auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
-      sourceType, mixedOffsets, mixedSizes, mixedStrides));
+  MemRefType nonRankReducedType = SubViewOp::inferResultType(
+      sourceType, mixedOffsets, mixedSizes, mixedStrides);
   FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
       currentSourceType, currentResultType, mixedSizes);
   if (failed(unusedDims))
@@ -3110,9 +3108,8 @@ Value mlir::memref::createCanonicalRankReducingSubViewOp(
   SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
   SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
   SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
-  auto targetType =
-      llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
-          targetShape, memrefType, offsets, sizes, strides));
+  MemRefType targetType = SubViewOp::inferRankReducedResultType(
+      targetShape, memrefType, offsets, sizes, strides);
   return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
                                            sizes, strides);
 }
@@ -3256,11 +3253,11 @@ struct SubViewReturnTypeCanonicalizer {
                         ArrayRef<OpFoldResult> mixedSizes,
                         ArrayRef<OpFoldResult> mixedStrides) {
     // Infer a memref type without taking into account any rank reductions.
-    auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets,
-                                            mixedSizes, mixedStrides);
+    MemRefType resTy = SubViewOp::inferResultType(
+        op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
     if (!resTy)
       return {};
-    MemRefType nonReducedType = cast<MemRefType>(resTy);
+    MemRefType nonReducedType = resTy;
 
     // Directly return the non-rank reduced type if there are no dropped dims.
     llvm::SmallBitVector droppedDims = op.getDroppedDims();

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index e237858d208a0..21361d2e9a2d7 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -70,9 +70,9 @@ propagateSubViewOp(RewriterBase &rewriter,
                    UnrealizedConversionCastOp conversionOp, SubViewOp op) {
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(op);
-  auto newResultType = cast<MemRefType>(SubViewOp::inferRankReducedResultType(
+  MemRefType newResultType = SubViewOp::inferRankReducedResultType(
       op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
-      op.getMixedSizes(), op.getMixedStrides()));
+      op.getMixedSizes(), op.getMixedStrides());
   Value newSubview = rewriter.create<SubViewOp>(
       op.getLoc(), newResultType, conversionOp.getOperand(0),
       op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index bc0dd034f6385..c475d92e0658e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -60,14 +60,13 @@ static void replaceUsesAndPropagateType(RewriterBase &rewriter,
     // `subview(old_op)` is replaced by a new `subview(val)`.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPoint(subviewUse);
-    Type newType = memref::SubViewOp::inferRankReducedResultType(
+    MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
         subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
         subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
         subviewUse.getStaticStrides());
     Value newSubview = rewriter.create<memref::SubViewOp>(
-        subviewUse->getLoc(), cast<MemRefType>(newType), val,
-        subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
-        subviewUse.getMixedStrides());
+        subviewUse->getLoc(), newType, val, subviewUse.getMixedOffsets(),
+        subviewUse.getMixedSizes(), subviewUse.getMixedStrides());
 
     // Ouch recursion ... is this really necessary?
     replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
@@ -211,9 +210,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
   for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
     sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
   // Strides is [1, 1 ... 1 ].
-  auto dstMemref =
-      cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
-          originalShape, mbMemRefType, offsets, sizes, strides));
+  MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType(
+      originalShape, mbMemRefType, offsets, sizes, strides);
   Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
                                                      offsets, sizes, strides);
   LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index ed3ba321b37ab..81404fa664cd4 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -407,10 +407,10 @@ struct ExtractSliceOpInterface
     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
-    return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
+    return memref::SubViewOp::inferRankReducedResultType(
         extractSliceOp.getType().getShape(),
         llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
-        mixedStrides));
+        mixedStrides);
   }
 };
 
@@ -692,10 +692,10 @@ struct InsertSliceOpInterface
 
     // Take a subview of the destination buffer.
     auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
-    auto subviewMemRefType =
-        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
+    MemRefType subviewMemRefType =
+        memref::SubViewOp::inferRankReducedResultType(
             insertSliceOp.getSourceType().getShape(), dstMemrefType,
-            mixedOffsets, mixedSizes, mixedStrides));
+            mixedOffsets, mixedSizes, mixedStrides);
     Value subView = rewriter.create<memref::SubViewOp>(
         loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
         mixedStrides);
@@ -960,12 +960,12 @@ struct ParallelInsertSliceOpInterface
 
     // Take a subview of the destination buffer.
     auto destBufferType = cast<MemRefType>(destBuffer->getType());
-    auto subviewMemRefType =
-        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
+    MemRefType subviewMemRefType =
+        memref::SubViewOp::inferRankReducedResultType(
             parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
             parallelInsertSliceOp.getMixedOffsets(),
             parallelInsertSliceOp.getMixedSizes(),
-            parallelInsertSliceOp.getMixedStrides()));
+            parallelInsertSliceOp.getMixedStrides());
     Value subview = rewriter.create<memref::SubViewOp>(
         parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
         parallelInsertSliceOp.getMixedOffsets(),

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 5871d6dd5b3e6..f13e54901f690 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -265,9 +265,9 @@ static MemRefType dropUnitDims(MemRefType inputType,
                                ArrayRef<OpFoldResult> sizes,
                                ArrayRef<OpFoldResult> strides) {
   auto targetShape = getReducedShape(sizes);
-  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
+  MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType(
       targetShape, inputType, offsets, sizes, strides);
-  return cast<MemRefType>(rankReducedType).canonicalizeStridedLayout();
+  return rankReducedType.canonicalizeStridedLayout();
 }
 
 /// Creates a rank-reducing memref.subview op that drops unit dims from its

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 47fca8e72b573..dc46ed17a374d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1326,10 +1326,9 @@ class DropInnerMostUnitDimsTransferRead
                                       rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> strides(srcType.getRank(),
                                       rewriter.getIndexAttr(1));
-    auto resultMemrefType =
-        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
-            srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
-            strides));
+    MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
+        srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
+        strides);
     ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
         readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
@@ -1417,10 +1416,9 @@ class DropInnerMostUnitDimsTransferWrite
                                       rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> strides(srcType.getRank(),
                                       rewriter.getIndexAttr(1));
-    auto resultMemrefType =
-        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
-            srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
-            strides));
+    MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
+        srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
+        strides);
     ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
         writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
 


        


More information about the Mlir-commits mailing list