[Mlir-commits] [mlir] [MLIR][NFC] Return MemRefType in memref.subview return type inference functions (PR #120024)

Tomás Longeri llvmlistbot at llvm.org
Sun Dec 15 15:28:01 PST 2024


https://github.com/tlongeri created https://github.com/llvm/llvm-project/pull/120024

(No PR description was provided.)

From 556c406e32ecc8fbf43305f89fe36aed383ec801 Mon Sep 17 00:00:00 2001
From: Tomás Longeri <tlongeri at google.com>
Date: Sun, 15 Dec 2024 22:55:05 +0000
Subject: [PATCH] [MLIR][NFC] Return MemRefType in memref.subview return type
 inference functions

---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       | 36 +++++------
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 61 +++++++++----------
 .../Transforms/IndependenceTransforms.cpp     |  4 +-
 .../Dialect/MemRef/Transforms/MultiBuffer.cpp | 12 ++--
 .../BufferizableOpInterfaceImpl.cpp           | 16 ++---
 .../Transforms/VectorTransferOpTransforms.cpp |  4 +-
 .../Vector/Transforms/VectorTransforms.cpp    | 14 ++---
 7 files changed, 70 insertions(+), 77 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index a0d8d34f38237a..4e31bb153c5e7e 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -2079,14 +2079,14 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
     /// A subview result type can be fully inferred from the source type and the
     /// static representation of offsets, sizes and strides. Special sentinels
     /// encode the dynamic case.
-    static Type inferResultType(MemRefType sourceMemRefType,
-                                ArrayRef<int64_t> staticOffsets,
-                                ArrayRef<int64_t> staticSizes,
-                                ArrayRef<int64_t> staticStrides);
-    static Type inferResultType(MemRefType sourceMemRefType,
-                                ArrayRef<OpFoldResult> staticOffsets,
-                                ArrayRef<OpFoldResult> staticSizes,
-                                ArrayRef<OpFoldResult> staticStrides);
+    static MemRefType inferResultType(MemRefType sourceMemRefType,
+                                      ArrayRef<int64_t> staticOffsets,
+                                      ArrayRef<int64_t> staticSizes,
+                                      ArrayRef<int64_t> staticStrides);
+    static MemRefType inferResultType(MemRefType sourceMemRefType,
+                                      ArrayRef<OpFoldResult> staticOffsets,
+                                      ArrayRef<OpFoldResult> staticSizes,
+                                      ArrayRef<OpFoldResult> staticStrides);
 
     /// A rank-reducing result type can be inferred from the desired result
     /// shape. Only the layout map is inferred.
@@ -2095,16 +2095,16 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
     /// and the desired sizes. In case there are more "ones" among the sizes
     /// than the difference in source/result rank, it is not clear which dims of
     /// size one should be dropped.
-    static Type inferRankReducedResultType(ArrayRef<int64_t> resultShape,
-                                           MemRefType sourceMemRefType,
-                                           ArrayRef<int64_t> staticOffsets,
-                                           ArrayRef<int64_t> staticSizes,
-                                           ArrayRef<int64_t> staticStrides);
-    static Type inferRankReducedResultType(ArrayRef<int64_t> resultShape,
-                                           MemRefType sourceMemRefType,
-                                           ArrayRef<OpFoldResult> staticOffsets,
-                                           ArrayRef<OpFoldResult> staticSizes,
-                                           ArrayRef<OpFoldResult> staticStrides);
+    static MemRefType inferRankReducedResultType(
+        ArrayRef<int64_t> resultShape, MemRefType sourceMemRefType,
+        ArrayRef<int64_t> staticOffsets,
+        ArrayRef<int64_t> staticSizes,
+        ArrayRef<int64_t> staticStrides);
+    static MemRefType inferRankReducedResultType(
+        ArrayRef<int64_t> resultShape, MemRefType sourceMemRefType,
+        ArrayRef<OpFoldResult> staticOffsets,
+        ArrayRef<OpFoldResult> staticSizes,
+        ArrayRef<OpFoldResult> staticStrides);
 
     /// Return the expected rank of each of the`static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 2219505c9b802f..12768f06fb1b0e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2697,10 +2697,10 @@ void SubViewOp::getAsmResultNames(
 /// A subview result type can be fully inferred from the source type and the
 /// static representation of offsets, sizes and strides. Special sentinels
 /// encode the dynamic case.
-Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
-                                ArrayRef<int64_t> staticOffsets,
-                                ArrayRef<int64_t> staticSizes,
-                                ArrayRef<int64_t> staticStrides) {
+MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
+                                      ArrayRef<int64_t> staticOffsets,
+                                      ArrayRef<int64_t> staticSizes,
+                                      ArrayRef<int64_t> staticStrides) {
   unsigned rank = sourceMemRefType.getRank();
   (void)rank;
   assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
@@ -2739,10 +2739,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                          sourceMemRefType.getMemorySpace());
 }
 
-Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
-                                ArrayRef<OpFoldResult> offsets,
-                                ArrayRef<OpFoldResult> sizes,
-                                ArrayRef<OpFoldResult> strides) {
+MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
+                                      ArrayRef<OpFoldResult> offsets,
+                                      ArrayRef<OpFoldResult> sizes,
+                                      ArrayRef<OpFoldResult> strides) {
   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -2758,13 +2758,12 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
                                     staticSizes, staticStrides);
 }
 
-Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
-                                           MemRefType sourceRankedTensorType,
-                                           ArrayRef<int64_t> offsets,
-                                           ArrayRef<int64_t> sizes,
-                                           ArrayRef<int64_t> strides) {
-  auto inferredType = llvm::cast<MemRefType>(
-      inferResultType(sourceRankedTensorType, offsets, sizes, strides));
+MemRefType SubViewOp::inferRankReducedResultType(
+    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
+    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
+    ArrayRef<int64_t> strides) {
+  MemRefType inferredType =
+      inferResultType(sourceRankedTensorType, offsets, sizes, strides);
   assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
          "expected ");
   if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
@@ -2790,11 +2789,10 @@ Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
                          inferredType.getMemorySpace());
 }
 
-Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
-                                           MemRefType sourceRankedTensorType,
-                                           ArrayRef<OpFoldResult> offsets,
-                                           ArrayRef<OpFoldResult> sizes,
-                                           ArrayRef<OpFoldResult> strides) {
+MemRefType SubViewOp::inferRankReducedResultType(
+    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
+    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+    ArrayRef<OpFoldResult> strides) {
   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -2821,8 +2819,8 @@ void SubViewOp::build(OpBuilder &b, OperationState &result,
   auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
   // Structuring implementation this way avoids duplication between builders.
   if (!resultType) {
-    resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
-        sourceMemRefType, staticOffsets, staticSizes, staticStrides));
+    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
+                                            staticSizes, staticStrides);
   }
   result.addAttributes(attrs);
   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
@@ -2987,8 +2985,8 @@ LogicalResult SubViewOp::verify() {
 
   // Compute the expected result type, assuming that there are no rank
   // reductions.
-  auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
-      baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));
+  MemRefType expectedType = SubViewOp::inferResultType(
+      baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());
 
   // Verify all properties of a shaped type: rank, element type and dimension
   // sizes. This takes into account potential rank reductions.
@@ -3070,8 +3068,8 @@ static MemRefType getCanonicalSubViewResultType(
     MemRefType currentResultType, MemRefType currentSourceType,
     MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
     ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
-  auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
-      sourceType, mixedOffsets, mixedSizes, mixedStrides));
+  MemRefType nonRankReducedType = SubViewOp::inferResultType(
+      sourceType, mixedOffsets, mixedSizes, mixedStrides);
   FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
       currentSourceType, currentResultType, mixedSizes);
   if (failed(unusedDims))
@@ -3105,9 +3103,8 @@ Value mlir::memref::createCanonicalRankReducingSubViewOp(
   SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
   SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
   SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
-  auto targetType =
-      llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
-          targetShape, memrefType, offsets, sizes, strides));
+  MemRefType targetType = SubViewOp::inferRankReducedResultType(
+      targetShape, memrefType, offsets, sizes, strides);
   return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
                                            sizes, strides);
 }
@@ -3251,11 +3248,11 @@ struct SubViewReturnTypeCanonicalizer {
                         ArrayRef<OpFoldResult> mixedSizes,
                         ArrayRef<OpFoldResult> mixedStrides) {
     // Infer a memref type without taking into account any rank reductions.
-    auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets,
-                                            mixedSizes, mixedStrides);
+    MemRefType resTy = SubViewOp::inferResultType(
+        op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
     if (!resTy)
       return {};
-    MemRefType nonReducedType = cast<MemRefType>(resTy);
+    MemRefType nonReducedType = resTy;
 
     // Directly return the non-rank reduced type if there are no dropped dims.
     llvm::SmallBitVector droppedDims = op.getDroppedDims();
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index 1f06318cbd60e0..8ffea5a7839980 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -70,9 +70,9 @@ propagateSubViewOp(RewriterBase &rewriter,
                    UnrealizedConversionCastOp conversionOp, SubViewOp op) {
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(op);
-  auto newResultType = cast<MemRefType>(SubViewOp::inferRankReducedResultType(
+  MemRefType newResultType = SubViewOp::inferRankReducedResultType(
       op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
-      op.getMixedSizes(), op.getMixedStrides()));
+      op.getMixedSizes(), op.getMixedStrides());
   Value newSubview = rewriter.create<SubViewOp>(
       op.getLoc(), newResultType, conversionOp.getOperand(0),
       op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index bc0dd034f63851..c475d92e0658e5 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -60,14 +60,13 @@ static void replaceUsesAndPropagateType(RewriterBase &rewriter,
     // `subview(old_op)` is replaced by a new `subview(val)`.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPoint(subviewUse);
-    Type newType = memref::SubViewOp::inferRankReducedResultType(
+    MemRefType newType = memref::SubViewOp::inferRankReducedResultType(
         subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
         subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
         subviewUse.getStaticStrides());
     Value newSubview = rewriter.create<memref::SubViewOp>(
-        subviewUse->getLoc(), cast<MemRefType>(newType), val,
-        subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
-        subviewUse.getMixedStrides());
+        subviewUse->getLoc(), newType, val, subviewUse.getMixedOffsets(),
+        subviewUse.getMixedSizes(), subviewUse.getMixedStrides());
 
     // Ouch recursion ... is this really necessary?
     replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
@@ -211,9 +210,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
   for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
     sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
   // Strides is [1, 1 ... 1 ].
-  auto dstMemref =
-      cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
-          originalShape, mbMemRefType, offsets, sizes, strides));
+  MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType(
+      originalShape, mbMemRefType, offsets, sizes, strides);
   Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
                                                      offsets, sizes, strides);
   LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 9797b73f534a96..35862c74c57552 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -407,10 +407,10 @@ struct ExtractSliceOpInterface
     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
-    return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
+    return memref::SubViewOp::inferRankReducedResultType(
         extractSliceOp.getType().getShape(),
         llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
-        mixedStrides));
+        mixedStrides);
   }
 };
 
@@ -692,10 +692,10 @@ struct InsertSliceOpInterface
 
     // Take a subview of the destination buffer.
     auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
-    auto subviewMemRefType =
-        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
+    MemRefType subviewMemRefType =
+        memref::SubViewOp::inferRankReducedResultType(
             insertSliceOp.getSourceType().getShape(), dstMemrefType,
-            mixedOffsets, mixedSizes, mixedStrides));
+            mixedOffsets, mixedSizes, mixedStrides);
     Value subView = rewriter.create<memref::SubViewOp>(
         loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
         mixedStrides);
@@ -960,12 +960,12 @@ struct ParallelInsertSliceOpInterface
 
     // Take a subview of the destination buffer.
     auto destBufferType = cast<MemRefType>(destBuffer->getType());
-    auto subviewMemRefType =
-        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
+    MemRefType subviewMemRefType =
+        memref::SubViewOp::inferRankReducedResultType(
             parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
             parallelInsertSliceOp.getMixedOffsets(),
             parallelInsertSliceOp.getMixedSizes(),
-            parallelInsertSliceOp.getMixedStrides()));
+            parallelInsertSliceOp.getMixedStrides());
     Value subview = rewriter.create<memref::SubViewOp>(
         parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
         parallelInsertSliceOp.getMixedOffsets(),
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index bd5f06a3b46d42..b124ea32af1b32 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -265,9 +265,9 @@ static MemRefType dropUnitDims(MemRefType inputType,
                                ArrayRef<OpFoldResult> sizes,
                                ArrayRef<OpFoldResult> strides) {
   auto targetShape = getReducedShape(sizes);
-  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
+  MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType(
       targetShape, inputType, offsets, sizes, strides);
-  return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
+  return canonicalizeStridedLayout(rankReducedType);
 }
 
 /// Creates a rank-reducing memref.subview op that drops unit dims from its
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 20cd9cba6909a6..3f3e9ae9df2865 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1319,10 +1319,9 @@ class DropInnerMostUnitDimsTransferRead
                                       rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> strides(srcType.getRank(),
                                       rewriter.getIndexAttr(1));
-    auto resultMemrefType =
-        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
-            srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
-            strides));
+    MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
+        srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
+        strides);
     ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
         readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
@@ -1410,10 +1409,9 @@ class DropInnerMostUnitDimsTransferWrite
                                       rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> strides(srcType.getRank(),
                                       rewriter.getIndexAttr(1));
-    auto resultMemrefType =
-        cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
-            srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
-            strides));
+    MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
+        srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
+        strides);
     ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
         writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
 



More information about the Mlir-commits mailing list