[Mlir-commits] [mlir] 44485fc - [mlir] Prevent assertion failure in DropUnitDims
Tres Popp
llvmlistbot at llvm.org
Tue Aug 31 03:16:49 PDT 2021
Author: Tres Popp
Date: 2021-08-31T12:15:13+02:00
New Revision: 44485fcd97490db57df49796d0566a3ce5e23f4c
URL: https://github.com/llvm/llvm-project/commit/44485fcd97490db57df49796d0566a3ce5e23f4c
DIFF: https://github.com/llvm/llvm-project/commit/44485fcd97490db57df49796d0566a3ce5e23f4c.diff
LOG: [mlir] Prevent assertion failure in DropUnitDims
Don't assert fail on strided memrefs when dropping unit dims.
Instead just leave them unchanged.
Differential Revision: https://reviews.llvm.org/D108205
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index de847d1f0fe7..e23a58e50cf1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -267,9 +267,9 @@ struct UnitExtentReplacementInfo {
/// - modified index map that can be used to access the replaced result/operand
/// - the reassociation that converts from the original tensor type to the
/// modified tensor type.
-static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
- OpOperand *opOperand,
- MLIRContext *context) {
+static llvm::Optional<UnitExtentReplacementInfo>
+replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
+ MLIRContext *context) {
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
ArrayRef<AffineExpr> exprs = indexingMap.getResults();
@@ -284,6 +284,14 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
return shape[dim] == 1 && exprs[dim] == zeroExpr;
};
+ // Early return for memrefs with affine maps to represent that we will always
+ // leave them unchanged.
+ Type actualType = opOperand->get().getType();
+ if (auto memref = actualType.dyn_cast<MemRefType>()) {
+ if (!memref.getAffineMaps().empty())
+ return llvm::None;
+ }
+
int64_t dim = 0;
// Fold dimensions that are unit-extent at the beginning of the tensor.
while (dim < origRank && isUnitExtent(dim))
@@ -302,8 +310,8 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
reassociations.clear();
++dim;
}
+
// Compute the tensor or scalar replacement type.
- Type actualType = opOperand->get().getType();
Type elementType = getElementTypeOrSelf(opOperand->get());
Type replacementType;
if (elementType == opOperand->get().getType()) {
@@ -311,8 +319,6 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
} else if (actualType.isa<RankedTensorType>()) {
replacementType = RankedTensorType::get(newShape, elementType);
} else if (actualType.isa<MemRefType>()) {
- assert(actualType.cast<MemRefType>().getAffineMaps().empty() &&
- "unsupported strided memrefs");
replacementType = MemRefType::get(newShape, elementType);
}
assert(replacementType && "unsupported shaped type");
@@ -390,12 +396,28 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
SmallVector<Type> newInputOutputTypes;
bool doCanonicalization = false;
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
- UnitExtentReplacementInfo replacementInfo =
- replaceUnitExtents(genericOp, opOperand, context);
- reassociationMaps.push_back(replacementInfo.reassociation);
- newIndexingMaps.push_back(replacementInfo.indexMap);
- newInputOutputTypes.push_back(replacementInfo.type);
- doCanonicalization |= replacementInfo.type != opOperand->get().getType();
+ auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
+ if (replacementInfo) {
+ reassociationMaps.push_back(replacementInfo->reassociation);
+ newIndexingMaps.push_back(replacementInfo->indexMap);
+ newInputOutputTypes.push_back(replacementInfo->type);
+ doCanonicalization |=
+ replacementInfo->type != opOperand->get().getType();
+ } else {
+ // If replaceUnitExtents cannot handle this case, maintain the same
+ // type, indexing map, and create a set of mappings representing an
+ // identity matrix.
+ newInputOutputTypes.push_back(opOperand->get().getType());
+ newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
+ int64_t origRank = genericOp.getRank(opOperand);
+ auto maps = llvm::to_vector<8>(llvm::map_range(
+ llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
+ return AffineMapAttr::get(
+ AffineMap::get(origRank, /*symbolCount = */ 0,
+ getAffineDimExpr(dim, context), context));
+ }));
+ reassociationMaps.push_back(ArrayAttr::get(context, maps));
+ }
}
// If the indexing maps of the result operation are not invertible (i.e. not
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index a4357b6e4cd1..5385083f470f 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -750,4 +750,50 @@ func @fold_unit_dim_for_init_memref(%input: memref<1x1000xf32>) -> memref<1xf32>
// CHECK: return %[[INIT:.+]] : memref<1xf32>
+// -----
+// Test that memrefs with non-identity affine layout maps are left unchanged and
+// trigger no assertion, while the other operands are still folded.
+
+#map0 = affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 + d2)>
+
+#accesses = [
+ affine_map<(i, j, k, l, m) -> (i, k, m)>,
+ affine_map<(i, j, k, l, m) -> ()>,
+ affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+ indexing_maps = #accesses,
+ library_call = "some_external_func"
+}
+
+func @input_stays_same(%arg0 : memref<?x1x?xf32, #map0>, %arg1 : f32, %shape: memref<?x1x?x1x?xf32>) -> memref<?x1x?x1x?xf32> {
+ linalg.generic #trait
+ ins(%arg0, %arg1 : memref<?x1x?xf32, #map0>, f32)
+ outs(%shape : memref<?x1x?x1x?xf32>) {
+ ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) :
+ linalg.yield %arg3 : f32
+ }
+ return %shape : memref<?x1x?x1x?xf32>
+}
+// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 + d2)>
+// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)>
+// CHECK: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
+// CHECK: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: builtin.func @input_stays_same(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x1x?xf32, #[[MAP0]]>,
+// CHECK-SAME: %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<?x1x?x1x?xf32>)
+// CHECK-SAME: -> memref<?x1x?x1x?xf32> {
+// CHECK: %[[OUT:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1], [2, 3], [4]]
+// CHECK-SAME: : memref<?x1x?x1x?xf32> into memref<?x?x?xf32>
+// CHECK: linalg.generic
+// CHECK-SAME: {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : memref<?x1x?xf32, #[[MAP0]]>, f32)
+// CHECK-SAME: outs(%[[OUT]] : memref<?x?x?xf32>) {
+// CHECK: ^bb0(%{{.*}}: f32, %[[ARG:.*]]: f32, %{{.*}}: f32): // no predecessors
+// CHECK: linalg.yield %[[ARG]] : f32
+// CHECK: }
+// CHECK: return %[[ARG2]] : memref<?x1x?x1x?xf32>
More information about the Mlir-commits
mailing list