[Mlir-commits] [mlir] e3373c6 - [mlir][memref] Fix crash in SubViewReturnTypeCanonicalizer

Matthias Springer llvmlistbot at llvm.org
Fri Aug 25 07:02:05 PDT 2023


Author: Matthias Springer
Date: 2023-08-25T16:01:49+02:00
New Revision: e3373c6c83d3855adb78f1952a3bf0398baf359e

URL: https://github.com/llvm/llvm-project/commit/e3373c6c83d3855adb78f1952a3bf0398baf359e
DIFF: https://github.com/llvm/llvm-project/commit/e3373c6c83d3855adb78f1952a3bf0398baf359e.diff

LOG: [mlir][memref] Fix crash in SubViewReturnTypeCanonicalizer

`SubViewReturnTypeCanonicalizer` is used by `OpWithOffsetSizesAndStridesConstantArgumentFolder`, which folds (dynamic) sizes that are constant SSA values into static sizes. The previous implementation crashed when a dynamic size was folded into a static `1` dimension, which was then mistaken for a rank reduction.

Differential Revision: https://reviews.llvm.org/D158721

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b2909962027de2..e1b8dd62450a77 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -31,23 +31,17 @@ namespace {
 namespace saturated_arith {
 struct Wrapper {
   static Wrapper stride(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0}
-                                                    : Wrapper{false, v};
+    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
   }
   static Wrapper offset(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0}
-                                                    : Wrapper{false, v};
+    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
   }
   static Wrapper size(int64_t v) {
     return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
   }
-  int64_t asOffset() {
-    return saturated ? ShapedType::kDynamic : v;
-  }
+  int64_t asOffset() { return saturated ? ShapedType::kDynamic : v; }
   int64_t asSize() { return saturated ? ShapedType::kDynamic : v; }
-  int64_t asStride() {
-    return saturated ? ShapedType::kDynamic : v;
-  }
+  int64_t asStride() { return saturated ? ShapedType::kDynamic : v; }
   bool operator==(Wrapper other) {
     return (saturated && other.saturated) ||
            (!saturated && !other.saturated && v == other.v);
@@ -731,8 +725,7 @@ bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
     auto ss = std::get<0>(it), st = std::get<1>(it);
     if (ss != st)
-      if (ShapedType::isDynamic(ss) &&
-          !ShapedType::isDynamic(st))
+      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
         return false;
   }
 
@@ -765,8 +758,7 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
       // same. They are also compatible if either one is dynamic (see
       // description of MemRefCastOp for details).
       auto checkCompatible = [](int64_t a, int64_t b) {
-        return (ShapedType::isDynamic(a) ||
-                ShapedType::isDynamic(b) || a == b);
+        return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
       };
       if (!checkCompatible(aOffset, bOffset))
         return false;
@@ -1889,8 +1881,7 @@ LogicalResult ReinterpretCastOp::verify() {
   // Match offset in result memref type and in static_offsets attribute.
   int64_t expectedOffset = getStaticOffsets().front();
   if (!ShapedType::isDynamic(resultOffset) &&
-      !ShapedType::isDynamic(expectedOffset) &&
-      resultOffset != expectedOffset)
+      !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
     return emitError("expected result type with offset = ")
            << expectedOffset << " instead of " << resultOffset;
 
@@ -2944,18 +2935,6 @@ static MemRefType getCanonicalSubViewResultType(
                          nonRankReducedType.getMemorySpace());
 }
 
-/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
-/// to deduce the result type. Additionally, reduce the rank of the inferred
-/// result type if `currentResultType` is lower rank than `sourceType`.
-static MemRefType getCanonicalSubViewResultType(
-    MemRefType currentResultType, MemRefType sourceType,
-    ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
-    ArrayRef<OpFoldResult> mixedStrides) {
-  return getCanonicalSubViewResultType(currentResultType, sourceType,
-                                       sourceType, mixedOffsets, mixedSizes,
-                                       mixedStrides);
-}
-
 Value mlir::memref::createCanonicalRankReducingSubViewOp(
     OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
   auto memrefType = llvm::cast<MemRefType>(memref.getType());
@@ -3108,9 +3087,32 @@ struct SubViewReturnTypeCanonicalizer {
   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                         ArrayRef<OpFoldResult> mixedSizes,
                         ArrayRef<OpFoldResult> mixedStrides) {
-    return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
-                                         mixedOffsets, mixedSizes,
-                                         mixedStrides);
+    // Infer a memref type without taking into account any rank reductions.
+    MemRefType nonReducedType = cast<MemRefType>(SubViewOp::inferResultType(
+        op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides));
+
+    // Directly return the non-rank reduced type if there are no dropped dims.
+    llvm::SmallBitVector droppedDims = op.getDroppedDims();
+    if (droppedDims.empty())
+      return nonReducedType;
+
+    // Take the strides and offset from the non-rank reduced type.
+    auto [nonReducedStrides, offset] = getStridesAndOffset(nonReducedType);
+
+    // Drop dims from shape and strides.
+    SmallVector<int64_t> targetShape;
+    SmallVector<int64_t> targetStrides;
+    for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
+      if (droppedDims.test(i))
+        continue;
+      targetStrides.push_back(nonReducedStrides[i]);
+      targetShape.push_back(nonReducedType.getDimSize(i));
+    }
+
+    return MemRefType::get(targetShape, nonReducedType.getElementType(),
+                           StridedLayoutAttr::get(nonReducedType.getContext(),
+                                                  offset, targetStrides),
+                           nonReducedType.getMemorySpace());
   }
 };
 

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index b65426cad30b6d..df66705e83e0e2 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -931,7 +931,7 @@ func.func @fold_multiple_memory_space_cast(%arg : memref<?xf32>) -> memref<?xf32
 
 // -----
 
-// CHECK-lABEL: func @ub_negative_alloc_size
+// CHECK-LABEL: func private @ub_negative_alloc_size
 func.func private @ub_negative_alloc_size() -> memref<?x?x?xi1> {
   %idx1 = index.constant 1
   %c-2 = arith.constant -2 : index
@@ -940,3 +940,18 @@ func.func private @ub_negative_alloc_size() -> memref<?x?x?xi1> {
   %alloc = memref.alloc(%c15, %c-2, %idx1) : memref<?x?x?xi1>
   return %alloc : memref<?x?x?xi1>
 }
+
+// -----
+
+// CHECK-LABEL: func @subview_rank_reduction(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<1x384x384xf32>, %[[arg1:.*]]: index
+func.func @subview_rank_reduction(%arg0: memref<1x384x384xf32>, %idx: index)
+    -> memref<?x?xf32, strided<[384, 1], offset: ?>> {
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[subview:.*]] = memref.subview %[[arg0]][0, %[[arg1]], %[[arg1]]] [1, 1, %[[arg1]]] [1, 1, 1] : memref<1x384x384xf32> to memref<1x?xf32, strided<[384, 1], offset: ?>>
+  // CHECK: %[[cast:.*]] = memref.cast %[[subview]] : memref<1x?xf32, strided<[384, 1], offset: ?>> to memref<?x?xf32, strided<[384, 1], offset: ?>>
+  %0 = memref.subview %arg0[0, %idx, %idx] [1, %c1, %idx] [1, 1, 1]
+      : memref<1x384x384xf32> to memref<?x?xf32, strided<[384, 1], offset: ?>>
+  // CHECK: return %[[cast]]
+  return %0 : memref<?x?xf32, strided<[384, 1], offset: ?>>
+}


        


More information about the Mlir-commits mailing list