[Mlir-commits] [mlir] 8ed66cb - [mlir][memref] Fix collapsed shape ops memref.cast folding with changed type

Nicolas Vasilache llvmlistbot at llvm.org
Wed Jul 28 03:20:05 PDT 2021


Author: Yi Zhang
Date: 2021-07-28T10:19:20Z
New Revision: 8ed66cb88b7b00d7e9a96f2030e7ec343cfe2c6a

URL: https://github.com/llvm/llvm-project/commit/8ed66cb88b7b00d7e9a96f2030e7ec343cfe2c6a
DIFF: https://github.com/llvm/llvm-project/commit/8ed66cb88b7b00d7e9a96f2030e7ec343cfe2c6a.diff

LOG: [mlir][memref] Fix collapsed shape ops memref.cast folding with changed type

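`memref.collapse_shape` has verification logic ensuring that a result dim
is static whenever all of the source dims collapsed into it are static.
Folding a `memref.cast` into the source operand of `memref.collapse_shape`
can add static information to that operand, which can turn a valid
collapsing operation into an invalid one. Add a
`CollapseShapeOpMemRefCastFolder` canonicalization pattern that folds the
cast and, when the inferred result type changes, inserts a `memref.cast`
back to the original result type; drop the direct `foldMemRefCast` call
from `CollapseShapeOp::fold`.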

Minor change: `convertReassociationIndicesToExprs` now takes an
`MLIRContext *` instead of an `OpBuilder &`, so callers no longer have to
construct temporary builders.

Reviewed By: nicolasvasilache, mravishankar

Differential Revision: https://reviews.llvm.org/D106670

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 8cf9dc36e2220..0c79d26a65cf2 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -225,8 +225,8 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
       return getResult().getType().cast<RankedTensorType>();
     }
 
-    // Infer the shape of the result tensor given the static shapes
-    // and element type of the result tensor.
+    // Infer the shape of the result tensor given the type of the source tensor
+    // and paddings.
     static RankedTensorType inferResultType(RankedTensorType sourceType,
                                 ArrayRef<int64_t> staticLow,
                                 ArrayRef<int64_t> staticHigh);

diff  --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 57350c5b14066..55ca2b8814099 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -47,7 +47,7 @@ Optional<SmallVector<ReassociationIndices>> composeReassociationIndices(
 
 /// Convert reassociation indices to affine expressions.
 SmallVector<SmallVector<AffineExpr, 2>, 2> convertReassociationIndicesToExprs(
-    OpBuilder &b, ArrayRef<ReassociationIndices> reassociationIndices);
+    MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices);
 
 /// Constructs affine maps out of Array<Array<AffineExpr>>.
 SmallVector<AffineMap, 4>

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9ef9a05c1284e..9096e37de082b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1147,16 +1147,16 @@ SmallVector<AffineMap, 4> TensorCollapseShapeOp::getReassociationMaps() {
 }
 SmallVector<ReassociationExprs, 4>
 TensorCollapseShapeOp::getReassociationExprs() {
-  OpBuilder b(this->getContext());
-  return convertReassociationIndicesToExprs(b, getReassociationIndices());
+  return convertReassociationIndicesToExprs(getContext(),
+                                            getReassociationIndices());
 }
 SmallVector<AffineMap, 4> TensorExpandShapeOp::getReassociationMaps() {
   return getSymbolLessAffineMaps(getReassociationExprs());
 }
 SmallVector<ReassociationExprs, 4>
 TensorExpandShapeOp::getReassociationExprs() {
-  OpBuilder b(this->getContext());
-  return convertReassociationIndicesToExprs(b, getReassociationIndices());
+  return convertReassociationIndicesToExprs(getContext(),
+                                            getReassociationIndices());
 }
 
 /// For reshape op compute the shape at dimension `dimIndex` of the output in
@@ -1317,7 +1317,7 @@ void mlir::linalg::TensorCollapseShapeOp::build(
   auto resultType = computeTensorReshapeCollapsedType(
       src.getType().cast<RankedTensorType>(),
       getSymbolLessAffineMaps(
-          convertReassociationIndicesToExprs(b, reassociation)));
+          convertReassociationIndicesToExprs(b.getContext(), reassociation)));
   build(b, result, resultType, src, attrs);
   result.addAttribute(getReassociationAttrName(),
                       getReassociationIndicesAttribute(b, reassociation));
@@ -1330,7 +1330,7 @@ void mlir::linalg::TensorExpandShapeOp::build(
   auto resultType = computeTensorReshapeCollapsedType(
       src.getType().cast<RankedTensorType>(),
       getSymbolLessAffineMaps(
-          convertReassociationIndicesToExprs(b, reassociation)));
+          convertReassociationIndicesToExprs(b.getContext(), reassociation)));
   build(b, result, resultType, src, attrs);
   result.addAttribute(getReassociationAttrName(),
                       getReassociationIndicesAttribute(b, reassociation));

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b65a6833f1fd9..2d2d9ae222e34 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1316,16 +1316,16 @@ SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
   return getSymbolLessAffineMaps(getReassociationExprs());
 }
 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
-  OpBuilder b(this->getContext());
-  return convertReassociationIndicesToExprs(b, getReassociationIndices());
+  return convertReassociationIndicesToExprs(getContext(),
+                                            getReassociationIndices());
 }
 
 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
   return getSymbolLessAffineMaps(getReassociationExprs());
 }
 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
-  OpBuilder b(this->getContext());
-  return convertReassociationIndicesToExprs(b, getReassociationIndices());
+  return convertReassociationIndicesToExprs(getContext(),
+                                            getReassociationIndices());
 }
 
 static void print(OpAsmPrinter &p, ExpandShapeOp op) {
@@ -1427,8 +1427,8 @@ void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                           ArrayRef<NamedAttribute> attrs) {
   auto memRefType = src.getType().cast<MemRefType>();
   auto resultType = computeReshapeCollapsedType(
-      memRefType, getSymbolLessAffineMaps(
-                      convertReassociationIndicesToExprs(b, reassociation)));
+      memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+                      b.getContext(), reassociation)));
   build(b, result, resultType, src, attrs);
   result.addAttribute(getReassociationAttrName(),
                       getReassociationIndicesAttribute(b, reassociation));
@@ -1439,8 +1439,8 @@ void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                             ArrayRef<NamedAttribute> attrs) {
   auto memRefType = src.getType().cast<MemRefType>();
   auto resultType = computeReshapeCollapsedType(
-      memRefType, getSymbolLessAffineMaps(
-                      convertReassociationIndicesToExprs(b, reassociation)));
+      memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+                      b.getContext(), reassociation)));
   build(b, result, resultType, src, attrs);
   result.addAttribute(getReassociationAttrName(),
                       getReassociationIndicesAttribute(b, reassociation));
@@ -1475,10 +1475,41 @@ static LogicalResult verify(CollapseShapeOp op) {
   return verifyReshapeOp(op, op.getSrcType(), op.getResultType());
 }
 
+struct CollapseShapeOpMemRefCastFolder
+    : public OpRewritePattern<CollapseShapeOp> {
+public:
+  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
+  /// Fold a `memref.cast` feeding the source operand into this collapse.
+  LogicalResult matchAndRewrite(CollapseShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto cast = op.getOperand().getDefiningOp<CastOp>();
+    if (!cast)
+      return failure();
+    // Let CastOp decide whether this cast may fold into its consumer.
+    if (!CastOp::canFoldIntoConsumerOp(cast))
+      return failure();
+    // Re-infer the collapsed type from the cast's pre-cast source type.
+    Type newResultType = computeReshapeCollapsedType(
+        cast.getOperand().getType().cast<MemRefType>(),
+        op.getReassociationMaps());
+    // Unchanged type: rewire in place; else cast the new result back.
+    if (newResultType == op.getResultType()) {
+      rewriter.updateRootInPlace(
+          op, [&]() { op.srcMutable().assign(cast.source()); });
+    } else {
+      Value newOp = rewriter.create<CollapseShapeOp>(
+          op->getLoc(), cast.source(), op.getReassociationIndices());
+      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
+    }
+    return success();
+  }
+};
+
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
   results.add<CollapseReshapeOps<CollapseShapeOp>,
-              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>>(context);
+              CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
+              CollapseShapeOpMemRefCastFolder>(context);
 }
 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
   if (succeeded(foldMemRefCast(*this)))
@@ -1486,8 +1517,6 @@ OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
 }
 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
-  if (succeeded(foldMemRefCast(*this)))
-    return getResult();
   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
 }
 

diff  --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 0c7aa52848b3b..919415a3f13ed 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -183,13 +183,13 @@ Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices(
 
 SmallVector<SmallVector<AffineExpr, 2>, 2>
 mlir::convertReassociationIndicesToExprs(
-    OpBuilder &b, ArrayRef<ReassociationIndices> reassociationIndices) {
+    MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) {
   SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
   for (const auto &indices : reassociationIndices) {
     SmallVector<AffineExpr, 2> reassociationMap;
     reassociationMap.reserve(indices.size());
     for (int64_t index : indices)
-      reassociationMap.push_back(b.getAffineDimExpr(index));
+      reassociationMap.push_back(mlir::getAffineDimExpr(index, context));
     reassociationMaps.push_back(std::move(reassociationMap));
   }
   return reassociationMaps;

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index c63994b480dbe..886f21d22021c 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -511,3 +511,31 @@ func @fold_memref_reshape_dynamic(%arg0 : memref<?x?xf32>) -> memref<?x?xf32> {
 }
 // CHECK-LABEL: @fold_memref_reshape_dynamic
 //   CHECK-NOT:   linalg.{{.*}}_shape
+
+// -----
+
+// CHECK-LABEL:   func @collapse_after_memref_cast_type_change(
+// CHECK-SAME:      %[[INPUT:.*]]: memref<?x512x1x1xf32>) -> memref<?x?xf32> {
+// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]
+// CHECK-SAME:         {{\[\[}}0], [1, 2, 3]] : memref<?x512x1x1xf32> into memref<?x512xf32>
+// CHECK:           %[[DYNAMIC:.*]] = memref.cast %[[COLLAPSED]] :
+// CHECK-SAME:         memref<?x512xf32> to memref<?x?xf32>
+// CHECK:           return %[[DYNAMIC]] : memref<?x?xf32>
+// CHECK:         }
+func @collapse_after_memref_cast_type_change(%arg0 : memref<?x512x1x1xf32>) -> memref<?x?xf32> {
+  %dynamic = memref.cast %arg0: memref<?x512x1x1xf32> to memref<?x?x?x?xf32>
+  %collapsed = memref.collapse_shape %dynamic [[0], [1, 2, 3]] : memref<?x?x?x?xf32> into memref<?x?xf32>
+  return %collapsed : memref<?x?xf32>
+}
+
+// CHECK-LABEL:   func @collapse_after_memref_cast(
+// CHECK-SAME:      %[[INPUT:.*]]: memref<?x512x1x?xf32>) -> memref<?x?xf32> {
+// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]
+// CHECK-SAME:        {{\[\[}}0], [1, 2, 3]] : memref<?x512x1x?xf32> into memref<?x?xf32>
+// CHECK:           return %[[COLLAPSED]] : memref<?x?xf32>
+func @collapse_after_memref_cast(%arg0 : memref<?x512x1x?xf32>) -> memref<?x?xf32> {
+  %dynamic = memref.cast %arg0: memref<?x512x1x?xf32> to memref<?x?x?x?xf32>
+  %collapsed = memref.collapse_shape %dynamic [[0], [1, 2, 3]] : memref<?x?x?x?xf32> into memref<?x?xf32>
+  return %collapsed : memref<?x?xf32>
+}
+


        


More information about the Mlir-commits mailing list