[Mlir-commits] [mlir] bc38673 - [mlir] Make sure linearizeCollapsedDims doesn't drop input map dims
llvmlistbot at llvm.org
Tue Nov 30 22:38:05 PST 2021
Author: MaheshRavishankar
Date: 2021-11-30T22:37:53-08:00
New Revision: bc38673e4de50b995f4bc46d1a4b0ad95bef2356
URL: https://github.com/llvm/llvm-project/commit/bc38673e4de50b995f4bc46d1a4b0ad95bef2356
DIFF: https://github.com/llvm/llvm-project/commit/bc38673e4de50b995f4bc46d1a4b0ad95bef2356.diff
LOG: [mlir] Make sure linearizeCollapsedDims doesn't drop input map dims
The new affine map generated by linearizeCollapsedDims should not drop
dimensions. We need to make sure we create a map with at least as many
dimensions as the source map. This prevents
FoldProducerReshapeOpByLinearization from generating invalid IR.
This fixes a regression in IREE caused by https://github.com/llvm/llvm-project/commit/e4e4da86aff5606ef792d987a3ec85639219228c
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D114838
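For illustration, the following is a minimal standalone sketch (not part of the commit, assumed to be built against MLIR of roughly this revision; variable names are hypothetical) of the failure mode: when the linearized result expressions no longer reference every dimension of the source map, AffineMap::inferFromExprList infers a map with fewer dimensions, which then disagrees with the other indexing maps of the linalg op. Pinning the dimension/symbol counts to the source map, as the patch does, keeps the map consistent.

    // Sketch only: shows why the inferred map can drop dimensions.
    #include "mlir/IR/AffineExpr.h"
    #include "mlir/IR/AffineMap.h"
    #include "mlir/IR/MLIRContext.h"
    #include <cassert>

    using namespace mlir;

    int main() {
      MLIRContext context;

      // Pretend the source indexing map is (d0, d1) -> (d0, d1) but, after
      // linearizing the collapsed dims, the result expressions only use d0.
      SmallVector<AffineExpr, 4> resultExprs = {getAffineDimExpr(0, &context)};

      // Old code path: inferFromExprList only sees d0, so it produces the
      // 1-dimensional map (d0) -> (d0), silently dropping d1.
      AffineMap inferred = AffineMap::inferFromExprList({resultExprs}).front();
      assert(inferred.getNumDims() == 1 && "d1 was dropped");

      // Fixed code path: keep at least as many dims/symbols as the source map
      // (2 dims, 0 symbols here), yielding (d0, d1) -> (d0).
      AffineMap fixed =
          AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, resultExprs, &context);
      assert(fixed.getNumDims() == 2 && "dimension count preserved");
      (void)inferred;
      (void)fixed;
      return 0;
    }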
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 6faf6857c2aaa..959dbb4d5b415 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -385,7 +385,15 @@ static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
resultExprs.push_back(linearizedExpr);
}
- return AffineMap::inferFromExprList({resultExprs}).front();
+ // The new affine map cannot drop unused dimensions, but some new symbols may
+ // have been added. Create a map with at least as many dimensions/symbols as
+ // the original affine map.
+ int64_t maxDim = -1;
+ int64_t maxSym = -1;
+ getMaxDimAndSymbol<SmallVector<AffineExpr>>({resultExprs}, maxDim, maxSym);
+ unsigned numDims = std::max(unsigned(maxDim + 1), sourceMap.getNumDims());
+ unsigned numSyms = std::max(unsigned(maxSym + 1), sourceMap.getNumSymbols());
+ return AffineMap::get(numDims, numSyms, resultExprs, context);
}
// TensorExpandShapeOp is fusable with its consumer (i.e. reshape as a
diff --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
index f563fe79474fd..512f7545a5eae 100644
--- a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
@@ -199,3 +199,31 @@ func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
// CHECK: %[[RESULT:.+]] = linalg.tensor_collapse_shape %[[NOFUSE]]
// CHECK: return %[[RESULT]]
+
+
+// -----
+
+func @generic_op_permultation_reshape_consumer_fusion_unused_dim(%arg0 : tensor<6x1xf32>) -> tensor<6xi32> {
+ %0 = linalg.init_tensor [6, 1] : tensor<6x1xi32>
+ %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<6x1xf32>) outs(%0 : tensor<6x1xi32>) {
+ ^bb0(%arg3: f32, %arg4: i32): // no predecessors
+ %5 = arith.fptosi %arg3 : f32 to i32
+ linalg.yield %5 : i32
+ } -> tensor<6x1xi32>
+ %6 = linalg.tensor_collapse_shape %1 [[0, 1]] : tensor<6x1xi32> into tensor<6xi32>
+ return %6 : tensor<6xi32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK: func @generic_op_permultation_reshape_consumer_fusion_unused_dim
+// CHECK-SAME: %[[ARG0:.+]]: tensor<6x1xf32>
+// CHECK: %[[T0:.+]] = linalg.init_tensor [6, 1]
+// CHECK: %[[T1:.+]] = linalg.tensor_collapse_shape %[[T0]]
+// CHECK-SAME: [0, 1]
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: ins(%[[ARG0]] : tensor<6x1xf32>)
+// CHECK-SAME: outs(%[[T1]] : tensor<6xi32>)