[Mlir-commits] [mlir] adfd3c7 - [mlir] Fix memref_cast + subview folder when reducing rank
Thomas Raoux
llvmlistbot at llvm.org
Tue Feb 16 12:01:37 PST 2021
Author: Thomas Raoux
Date: 2021-02-16T12:00:59-08:00
New Revision: adfd3c7083f9808d145239153c10f72eece485d8
URL: https://github.com/llvm/llvm-project/commit/adfd3c7083f9808d145239153c10f72eece485d8
DIFF: https://github.com/llvm/llvm-project/commit/adfd3c7083f9808d145239153c10f72eece485d8.diff
LOG: [mlir] Fix memref_cast + subview folder when reducing rank
When the destination of the subview has a lower rank than its source, we need
to fix the result type of the new subview op.
Differential Revision: https://reviews.llvm.org/D96804
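For illustration, the pattern being fixed is a rank-reducing subview fed by a
memref_cast (a sketch only; it mirrors the test added below):

    %0 = memref_cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
    %1 = subview %0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] :
           memref<?x?x16x32xi8> to
           memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>

When the cast is folded, the result type inferred from the static source keeps
the source's full rank, so the folder now projects away the dropped unit
dimensions before creating the new subview and casts the result back to the
original (rank-reduced) result type.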
Added:
Modified:
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 49082912b803d..5582c0bde555c 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3058,7 +3058,23 @@ isRankReducedType(Type originalType, Type candidateReducedType,
candidateLayout = getStridedLinearLayoutMap(candidateReduced);
else
candidateLayout = candidateReduced.getAffineMaps().front();
- if (inferredType != candidateLayout) {
+ assert(inferredType.getNumResults() == 1 &&
+ candidateLayout.getNumResults() == 1);
+ if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
+ inferredType.getNumDims() != candidateLayout.getNumDims()) {
+ if (errMsg) {
+ llvm::raw_string_ostream os(*errMsg);
+ os << "inferred type: " << inferredType;
+ }
+ return SubViewVerificationResult::AffineMapMismatch;
+ }
+  // Check that the difference of the affine maps simplifies to 0.
+  AffineExpr diffExpr =
+      inferredType.getResult(0) - candidateLayout.getResult(0);
+  diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
+                                inferredType.getNumSymbols());
+  auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
+ if (!(cst && cst.getValue() == 0)) {
if (errMsg) {
llvm::raw_string_ostream os(*errMsg);
os << "inferred type: " << inferredType;
@@ -3344,11 +3360,29 @@ class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
/// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
/// the cast source operand type and the SubViewOp static information. This
/// is the resulting type if the MemRefCastOp were folded.
- Type resultType = SubViewOp::inferResultType(
- castOp.source().getType().cast<MemRefType>(),
- extractFromI64ArrayAttr(subViewOp.static_offsets()),
- extractFromI64ArrayAttr(subViewOp.static_sizes()),
- extractFromI64ArrayAttr(subViewOp.static_strides()));
+ auto resultType = SubViewOp::inferResultType(
+ castOp.source().getType().cast<MemRefType>(),
+ extractFromI64ArrayAttr(subViewOp.static_offsets()),
+ extractFromI64ArrayAttr(subViewOp.static_sizes()),
+ extractFromI64ArrayAttr(subViewOp.static_strides()))
+ .cast<MemRefType>();
+ uint32_t rankDiff =
+ subViewOp.getSourceType().getRank() - subViewOp.getType().getRank();
+ if (rankDiff > 0) {
+ auto shape = resultType.getShape();
+ auto projectedShape = shape.drop_front(rankDiff);
+ AffineMap map;
+ auto maps = resultType.getAffineMaps();
+ if (!maps.empty() && maps.front()) {
+ auto optionalUnusedDimsMask =
+ computeRankReductionMask(shape, projectedShape);
+ llvm::SmallDenseSet<unsigned> dimsToProject =
+ optionalUnusedDimsMask.getValue();
+ map = getProjectedMap(maps.front(), dimsToProject);
+ }
+ resultType = MemRefType::get(projectedShape, resultType.getElementType(),
+ map, resultType.getMemorySpace());
+ }
Value newSubView = rewriter.create<SubViewOp>(
subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
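As a worked sketch of the rank-reduction branch above (strides and offset are
computed by hand and may differ slightly in exact form): for the test added
below, SubViewOp::inferResultType on the static source memref<4x6x16x32xi8>
with offsets [0, 1, 0, 0], sizes [1, 1, 16, 32] and strides [1, 1, 1, 1] gives
the full-rank type

    memref<1x1x16x32xi8,
           affine_map<(d0, d1, d2, d3) -> (d0 * 3072 + d1 * 512 + d2 * 32 + d3 + 512)>>

With rankDiff = 4 - 2 = 2, the two leading unit dimensions are dropped from the
shape and projected out of the layout map, giving

    memref<16x32xi8, affine_map<(d0, d1) -> (d0 * 32 + d1 + 512)>>

which becomes the result type of the new subview; a memref_cast then converts
it back to the op's original result type with its symbolic offset.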
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 7b54938b0c488..c864af8f5747d 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -143,3 +143,17 @@ func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
%1 = tensor_to_memref %0 : memref<?x?x16x32xi8>
return %1 : memref<?x?x16x32xi8>
}
+
+// CHECK-LABEL: func @subview_of_memcast
+// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
+// CHECK: %[[S:.+]] = subview %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
+// CHECK: %[[M:.+]] = memref_cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
+// CHECK: return %[[M]] : memref<16x32xi8, #{{.*}}>
+func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
+ memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
+ %0 = memref_cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
+ %1 = subview %0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] :
+ memref<?x?x16x32xi8> to
+ memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+ return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+}