[Mlir-commits] [mlir] 44485fc - [mlir] Prevent assertion failure in DropUnitDims

Tres Popp llvmlistbot at llvm.org
Tue Aug 31 03:16:49 PDT 2021


Author: Tres Popp
Date: 2021-08-31T12:15:13+02:00
New Revision: 44485fcd97490db57df49796d0566a3ce5e23f4c

URL: https://github.com/llvm/llvm-project/commit/44485fcd97490db57df49796d0566a3ce5e23f4c
DIFF: https://github.com/llvm/llvm-project/commit/44485fcd97490db57df49796d0566a3ce5e23f4c.diff

LOG: [mlir] Prevent assertion failure in DropUnitDims

Don't trigger an assertion failure on strided memrefs when dropping unit dims.
Instead, just leave them unchanged.

Differential Revision: https://reviews.llvm.org/D108205

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index de847d1f0fe7..e23a58e50cf1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -267,9 +267,9 @@ struct UnitExtentReplacementInfo {
 /// - modified index map that can be used to access the replaced result/operand
 /// - the reassociation that converts from the original tensor type to the
 ///   modified tensor type.
-static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
-                                                    OpOperand *opOperand,
-                                                    MLIRContext *context) {
+static llvm::Optional<UnitExtentReplacementInfo>
+replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
+                   MLIRContext *context) {
   AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
   ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
   ArrayRef<AffineExpr> exprs = indexingMap.getResults();
@@ -284,6 +284,14 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
     return shape[dim] == 1 && exprs[dim] == zeroExpr;
   };
 
+  // Return early for memrefs with affine maps: returning None signals to the
+  // caller that such operands are always left unchanged.
+  Type actualType = opOperand->get().getType();
+  if (auto memref = actualType.dyn_cast<MemRefType>()) {
+    if (!memref.getAffineMaps().empty())
+      return llvm::None;
+  }
+
   int64_t dim = 0;
   // Fold dimensions that are unit-extent at the beginning of the tensor.
   while (dim < origRank && isUnitExtent(dim))
@@ -302,8 +310,8 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
     reassociations.clear();
     ++dim;
   }
+
   // Compute the tensor or scalar replacement type.
-  Type actualType = opOperand->get().getType();
   Type elementType = getElementTypeOrSelf(opOperand->get());
   Type replacementType;
   if (elementType == opOperand->get().getType()) {
@@ -311,8 +319,6 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
   } else if (actualType.isa<RankedTensorType>()) {
     replacementType = RankedTensorType::get(newShape, elementType);
   } else if (actualType.isa<MemRefType>()) {
-    assert(actualType.cast<MemRefType>().getAffineMaps().empty() &&
-           "unsupported strided memrefs");
     replacementType = MemRefType::get(newShape, elementType);
   }
   assert(replacementType && "unsupported shaped type");
@@ -390,12 +396,28 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
     SmallVector<Type> newInputOutputTypes;
     bool doCanonicalization = false;
     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
-      UnitExtentReplacementInfo replacementInfo =
-          replaceUnitExtents(genericOp, opOperand, context);
-      reassociationMaps.push_back(replacementInfo.reassociation);
-      newIndexingMaps.push_back(replacementInfo.indexMap);
-      newInputOutputTypes.push_back(replacementInfo.type);
-      doCanonicalization |= replacementInfo.type != opOperand->get().getType();
+      auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
+      if (replacementInfo) {
+        reassociationMaps.push_back(replacementInfo->reassociation);
+        newIndexingMaps.push_back(replacementInfo->indexMap);
+        newInputOutputTypes.push_back(replacementInfo->type);
+        doCanonicalization |=
+            replacementInfo->type != opOperand->get().getType();
+      } else {
+        // If replaceUnitExtents cannot handle this operand, keep its type and
+        // indexing map unchanged, and create a set of mappings representing an
+        // identity matrix.
+        newInputOutputTypes.push_back(opOperand->get().getType());
+        newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
+        int64_t origRank = genericOp.getRank(opOperand);
+        auto maps = llvm::to_vector<8>(llvm::map_range(
+            llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
+              return AffineMapAttr::get(
+                  AffineMap::get(origRank, /*symbolCount = */ 0,
+                                 getAffineDimExpr(dim, context), context));
+            }));
+        reassociationMaps.push_back(ArrayAttr::get(context, maps));
+      }
     }
 
     // If the indexing maps of the result operation are not invertible (i.e. not

diff  --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index a4357b6e4cd1..5385083f470f 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -750,4 +750,50 @@ func @fold_unit_dim_for_init_memref(%input: memref<1x1000xf32>) -> memref<1xf32>
 //       CHECK: return %[[INIT:.+]] : memref<1xf32>
 
 
+// -----
+// Test that memrefs with affine maps are left unchanged and no assertions are
+// fired, while the other operands are still rewritten.
+
+#map0 = affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 + d2)>
+
+#accesses = [
+  affine_map<(i, j, k, l, m) -> (i, k, m)>,
+  affine_map<(i, j, k, l, m) -> ()>,
+  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+  indexing_maps = #accesses,
+  library_call = "some_external_func"
+}
+
+func @input_stays_same(%arg0 : memref<?x1x?xf32, #map0>, %arg1 : f32, %shape: memref<?x1x?x1x?xf32>) -> memref<?x1x?x1x?xf32> {
+  linalg.generic #trait
+     ins(%arg0, %arg1 : memref<?x1x?xf32, #map0>, f32)
+    outs(%shape : memref<?x1x?x1x?xf32>) {
+       ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) :
+         linalg.yield %arg3 : f32
+       }
+  return %shape : memref<?x1x?x1x?xf32>
+}
 
+// CHECK:     #[[MAP0:.*]] = affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 + d2)>
+// CHECK:     #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)>
+// CHECK:     #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
+// CHECK:     #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK:     builtin.func @input_stays_same(
+// CHECK-SAME:  %[[ARG0:.*]]: memref<?x1x?xf32, #[[MAP0]]>,
+// CHECK-SAME:  %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<?x1x?x1x?xf32>)
+// CHECK-SAME:  -> memref<?x1x?x1x?xf32> {
+// CHECK:      %[[OUT:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1], [2, 3], [4]]
+// CHECK-SAME:   : memref<?x1x?x1x?xf32> into memref<?x?x?xf32>
+// CHECK:      linalg.generic
+// CHECK-SAME:   {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]],
+// CHECK-SAME:   iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:   ins(%[[ARG0]], %[[ARG1]] : memref<?x1x?xf32, #[[MAP0]]>, f32)
+// CHECK-SAME:   outs(%[[OUT]] : memref<?x?x?xf32>) {
+// CHECK:      ^bb0(%{{.*}}: f32, %[[ARG:.*]]: f32, %{{.*}}: f32): // no predecessors
+// CHECK:       linalg.yield %[[ARG]] : f32
+// CHECK:      }
+// CHECK:      return %[[ARG2]] : memref<?x1x?x1x?xf32>


        


More information about the Mlir-commits mailing list