[Mlir-commits] [mlir] adfd3c7 - [mlir] Fix memref_cast + subview folder when reducing rank

Thomas Raoux llvmlistbot at llvm.org
Tue Feb 16 12:01:37 PST 2021


Author: Thomas Raoux
Date: 2021-02-16T12:00:59-08:00
New Revision: adfd3c7083f9808d145239153c10f72eece485d8

URL: https://github.com/llvm/llvm-project/commit/adfd3c7083f9808d145239153c10f72eece485d8
DIFF: https://github.com/llvm/llvm-project/commit/adfd3c7083f9808d145239153c10f72eece485d8.diff

LOG: [mlir] Fix memref_cast + subview folder when reducing rank

When the destination of the subview has a lower rank than its source we need to
fix the result type of the new subview op.

Differential Revision: https://reviews.llvm.org/D96804

Added: 
    

Modified: 
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 49082912b803d..5582c0bde555c 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3058,7 +3058,23 @@ isRankReducedType(Type originalType, Type candidateReducedType,
     candidateLayout = getStridedLinearLayoutMap(candidateReduced);
   else
     candidateLayout = candidateReduced.getAffineMaps().front();
-  if (inferredType != candidateLayout) {
+  assert(inferredType.getNumResults() == 1 &&
+         candidateLayout.getNumResults() == 1);
+  if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
+      inferredType.getNumDims() != candidateLayout.getNumDims()) {
+    if (errMsg) {
+      llvm::raw_string_ostream os(*errMsg);
+      os << "inferred type: " << inferredType;
+    }
+    return SubViewVerificationResult::AffineMapMismatch;
+  }
+  // Check that the difference of the affine maps simplifies to 0.
+  AffineExpr diffExpr =
+      inferredType.getResult(0) - candidateLayout.getResult(0);
+  diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
+                                inferredType.getNumSymbols());
+  auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
+  if (!(cst && cst.getValue() == 0)) {
     if (errMsg) {
       llvm::raw_string_ostream os(*errMsg);
       os << "inferred type: " << inferredType;
@@ -3344,11 +3360,29 @@ class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
     /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
     /// the cast source operand type and the SubViewOp static information. This
     /// is the resulting type if the MemRefCastOp were folded.
-    Type resultType = SubViewOp::inferResultType(
-        castOp.source().getType().cast<MemRefType>(),
-        extractFromI64ArrayAttr(subViewOp.static_offsets()),
-        extractFromI64ArrayAttr(subViewOp.static_sizes()),
-        extractFromI64ArrayAttr(subViewOp.static_strides()));
+    auto resultType = SubViewOp::inferResultType(
+                          castOp.source().getType().cast<MemRefType>(),
+                          extractFromI64ArrayAttr(subViewOp.static_offsets()),
+                          extractFromI64ArrayAttr(subViewOp.static_sizes()),
+                          extractFromI64ArrayAttr(subViewOp.static_strides()))
+                          .cast<MemRefType>();
+    uint32_t rankDiff =
+        subViewOp.getSourceType().getRank() - subViewOp.getType().getRank();
+    if (rankDiff > 0) {
+      auto shape = resultType.getShape();
+      auto projectedShape = shape.drop_front(rankDiff);
+      AffineMap map;
+      auto maps = resultType.getAffineMaps();
+      if (!maps.empty() && maps.front()) {
+        auto optionalUnusedDimsMask =
+            computeRankReductionMask(shape, projectedShape);
+        llvm::SmallDenseSet<unsigned> dimsToProject =
+            optionalUnusedDimsMask.getValue();
+        map = getProjectedMap(maps.front(), dimsToProject);
+      }
+      resultType = MemRefType::get(projectedShape, resultType.getElementType(),
+                                   map, resultType.getMemorySpace());
+    }
     Value newSubView = rewriter.create<SubViewOp>(
         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 7b54938b0c488..c864af8f5747d 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -143,3 +143,17 @@ func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
   %1 = tensor_to_memref %0 : memref<?x?x16x32xi8>
   return %1 : memref<?x?x16x32xi8>
 }
+
+// CHECK-LABEL: func @subview_of_memcast
+//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
+//       CHECK:   %[[S:.+]] = subview %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
+//       CHECK:   %[[M:.+]] = memref_cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
+//       CHECK:   return %[[M]] : memref<16x32xi8, #{{.*}}>
+func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
+  memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
+  %0 = memref_cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
+  %1 = subview %0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] :
+    memref<?x?x16x32xi8> to
+    memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+  return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+}


        


More information about the Mlir-commits mailing list