[Mlir-commits] [mlir] 6c3c5f8 - [mlir][memref] Improve type inference for rank-reducing subviews
Matthias Springer
llvmlistbot at llvm.org
Tue Jul 5 07:49:18 PDT 2022
Author: Matthias Springer
Date: 2022-07-05T16:49:07+02:00
New Revision: 6c3c5f8069d97e635b1887a6f9ac410391b89fae
URL: https://github.com/llvm/llvm-project/commit/6c3c5f8069d97e635b1887a6f9ac410391b89fae
DIFF: https://github.com/llvm/llvm-project/commit/6c3c5f8069d97e635b1887a6f9ac410391b89fae.diff
LOG: [mlir][memref] Improve type inference for rank-reducing subviews
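The result shape of a rank-reducing subview cannot be inferred in the general case: the result rank alone is not enough, because when there are more unit-sized ("one") dimensions than dimensions to drop, it is ambiguous which of them should be dropped. The only thing that can be inferred is the layout map.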
This change also improves the bufferization patterns of tensor.extract_slice and tensor.insert_slice to fully support rank-reducing operations.
Differential Revision: https://reviews.llvm.org/D129144
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/test/Dialect/Tensor/bufferize.mlir
mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
mlir/unittests/Dialect/MemRef/InferShapeTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 097ce28413474..daeb7b896a2e4 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1645,12 +1645,20 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
ArrayRef<OpFoldResult> staticOffsets,
ArrayRef<OpFoldResult> staticSizes,
ArrayRef<OpFoldResult> staticStrides);
- static Type inferRankReducedResultType(unsigned resultRank,
+
+ /// A rank-reducing result type can be inferred from the desired result
+ /// shape. Only the layout map is inferred.
+ ///
+ /// Note: The result shape cannot be inferred with just the result rank
+ /// and the desired sizes. In case there are more "ones" among the sizes
+ /// than the difference in source/result rank, it is not clear which dims of
+ /// size one should be dropped.
+ static Type inferRankReducedResultType(ArrayRef<int64_t> resultShape,
MemRefType sourceMemRefType,
ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides);
- static Type inferRankReducedResultType(unsigned resultRank,
+ static Type inferRankReducedResultType(ArrayRef<int64_t> resultShape,
MemRefType sourceMemRefType,
ArrayRef<OpFoldResult> staticOffsets,
ArrayRef<OpFoldResult> staticSizes,
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp
index 6c6bcabb499d9..719797ac23473 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp
@@ -215,25 +215,10 @@ mlir::bufferization::insertSliceAnchoredAllocTensorEliminationStep(
/*rewriteFunc=*/
[](OpBuilder &b, Location loc, OpOperand &operand) {
auto insertOp = cast<tensor::InsertSliceOp>(operand.getOwner());
- // Expand offsets, sizes and strides to the full rank to handle the
- // rank-reducing case.
- SmallVector<OpFoldResult> mixedOffsets = insertOp.getMixedOffsets();
- SmallVector<OpFoldResult> mixedSizes = insertOp.getMixedSizes();
- SmallVector<OpFoldResult> mixedStrides = insertOp.getMixedStrides();
- OffsetSizeAndStrideOpInterface::expandToRank(
- insertOp.getDest(), mixedOffsets, mixedSizes, mixedStrides,
- [&](Value target, int64_t dim) -> OpFoldResult {
- auto shapedType = target.getType().cast<ShapedType>();
- if (shapedType.isDynamicDim(dim))
- return b.create<tensor::DimOp>(loc, target, dim).getResult();
- return b.getIndexAttr(shapedType.getDimSize(dim));
- });
- auto t = tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
- insertOp.getSourceType().getRank(),
- insertOp.getDest().getType().cast<RankedTensorType>(), mixedOffsets,
- mixedSizes, mixedStrides);
auto extractOp = b.create<tensor::ExtractSliceOp>(
- loc, t, insertOp.getDest(), mixedOffsets, mixedSizes, mixedStrides);
+ loc, insertOp.getSourceType(), insertOp.getDest(),
+ insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
+ insertOp.getMixedStrides());
return extractOp.getResult();
});
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 000bac1b3129b..8e54936cd43c7 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2145,7 +2145,7 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
staticSizes, staticStrides);
}
-Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
+Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
MemRefType sourceRankedTensorType,
ArrayRef<int64_t> offsets,
ArrayRef<int64_t> sizes,
@@ -2153,27 +2153,26 @@ Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
auto inferredType =
inferResultType(sourceRankedTensorType, offsets, sizes, strides)
.cast<MemRefType>();
- assert(inferredType.getRank() >= resultRank && "expected ");
- int rankDiff = inferredType.getRank() - resultRank;
- if (rankDiff > 0) {
- auto shape = inferredType.getShape();
- llvm::SmallBitVector dimsToProject =
- getPositionsOfShapeOne(rankDiff, shape);
- SmallVector<int64_t> projectedShape;
- for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
- if (!dimsToProject.test(pos))
- projectedShape.push_back(shape[pos]);
-
- AffineMap map =
- getProjectedMap(inferredType.getLayout().getAffineMap(), dimsToProject);
- inferredType =
- MemRefType::get(projectedShape, inferredType.getElementType(), map,
- inferredType.getMemorySpace());
- }
- return inferredType;
-}
-
-Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
+ assert(inferredType.getRank() >= resultShape.size() && "expected ");
+ if (inferredType.getRank() == resultShape.size())
+ return inferredType;
+
+ // Compute which dimensions are dropped.
+ Optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
+ computeRankReductionMask(inferredType.getShape(), resultShape);
+ assert(dimsToProject.hasValue() && "invalid rank reduction");
+ llvm::SmallBitVector dimsToProjectVector(inferredType.getRank());
+ for (unsigned dim : *dimsToProject)
+ dimsToProjectVector.set(dim);
+
+ // Compute layout map and result type.
+ AffineMap map = getProjectedMap(inferredType.getLayout().getAffineMap(),
+ dimsToProjectVector);
+ return MemRefType::get(resultShape, inferredType.getElementType(), map,
+ inferredType.getMemorySpace());
+}
+
+Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
MemRefType sourceRankedTensorType,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
@@ -2187,9 +2186,10 @@ Type SubViewOp::inferRankReducedResultType(unsigned resultRank,
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
ShapedType::kDynamicStrideOrOffset);
return SubViewOp::inferRankReducedResultType(
- resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
+ resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
staticStrides);
}
+
// Build a SubViewOp with mixed static and dynamic entries and custom result
// type. If the type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index 2c09145c4f3a4..51f6a69f5b6a4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -44,7 +44,7 @@ static void replaceUsesAndPropagateType(Operation *oldOp, Value val,
}
builder.setInsertionPoint(subviewUse);
Type newType = memref::SubViewOp::inferRankReducedResultType(
- subviewUse.getType().getRank(), val.getType().cast<MemRefType>(),
+ subviewUse.getType().getShape(), val.getType().cast<MemRefType>(),
extractFromI64ArrayAttr(subviewUse.static_offsets()),
extractFromI64ArrayAttr(subviewUse.static_sizes()),
extractFromI64ArrayAttr(subviewUse.static_strides()));
@@ -136,7 +136,7 @@ LogicalResult mlir::memref::multiBuffer(memref::AllocOp allocOp,
sizes.push_back(builder.getIndexAttr(size));
auto dstMemref =
memref::SubViewOp::inferRankReducedResultType(
- allocOp.getType().getRank(), newMemref, offsets, sizes, strides)
+ allocOp.getType().getShape(), newMemref, offsets, sizes, strides)
.cast<MemRefType>();
Value subview = builder.create<memref::SubViewOp>(loc, dstMemref, newAlloc,
offsets, sizes, strides);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 784bd8e1263bb..97da5969a3004 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -278,36 +278,24 @@ struct ExtractSliceOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+ SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
Location loc = extractSliceOp.getLoc();
- // Even if this op was decided to bufferize out-of-place, do not insert the
- // buffer copy yet. This is done later in this function.
+ // Get source buffer.
FailureOr<Value> srcMemref =
getBuffer(rewriter, extractSliceOp.getSource(), options);
if (failed(srcMemref))
return failure();
auto srcMemrefType = srcMemref->getType().cast<MemRefType>();
- auto dstTensorType =
- extractSliceOp.getResult().getType().cast<RankedTensorType>();
- // Expand offsets, sizes and strides to the full rank to handle the
- // rank-reducing case.
- SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
- SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
- OffsetSizeAndStrideOpInterface::expandToRank(
- *srcMemref, mixedOffsets, mixedSizes, mixedStrides,
- [&](Value target, int64_t dim) -> OpFoldResult {
- auto shapedType = target.getType().cast<ShapedType>();
- if (shapedType.isDynamicDim(dim))
- return rewriter.create<memref::DimOp>(loc, target, dim).result();
- return rewriter.getIndexAttr(shapedType.getDimSize(dim));
- });
- // Bufferize to subview.
- auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
- dstTensorType.getRank(), srcMemrefType,
- mixedOffsets, mixedSizes, mixedStrides)
- .cast<MemRefType>();
+ // Take a subview of the source buffer.
+ auto subviewMemRefType =
+ memref::SubViewOp::inferRankReducedResultType(
+ extractSliceOp.getType().getShape(), srcMemrefType, mixedOffsets,
+ mixedSizes, mixedStrides)
+ .cast<MemRefType>();
Value subView = rewriter.create<memref::SubViewOp>(
loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes,
mixedStrides);
@@ -690,30 +678,22 @@ struct InsertSliceOpInterface
// catastrophically bad scheduling decision.
// TODO: be very loud about it or even consider failing the pass.
auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+ SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
Location loc = insertSliceOp.getLoc();
+
+ // Get destination buffer.
FailureOr<Value> dstMemref =
getBuffer(rewriter, insertSliceOp.getDest(), options);
if (failed(dstMemref))
return failure();
- // Expand offsets, sizes and strides to the full rank to handle the
- // rank-reducing case.
- SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
- SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
- OffsetSizeAndStrideOpInterface::expandToRank(
- *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
- [&](Value target, int64_t dim) -> OpFoldResult {
- auto shapedType = target.getType().cast<ShapedType>();
- if (shapedType.isDynamicDim(dim))
- return rewriter.create<memref::DimOp>(loc, target, dim).result();
- return rewriter.getIndexAttr(shapedType.getDimSize(dim));
- });
- // Take a subview of the dst.
+ // Take a subview of the destination buffer.
auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
auto subviewMemRefType =
memref::SubViewOp::inferRankReducedResultType(
- insertSliceOp.getSourceType().getRank(), dstMemrefType,
+ insertSliceOp.getSourceType().getShape(), dstMemrefType,
mixedOffsets, mixedSizes, mixedStrides)
.cast<MemRefType>();
Value subView = rewriter.create<memref::SubViewOp>(
@@ -946,11 +926,22 @@ struct ParallelInsertSliceOpInterface
getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
if (failed(srcBuffer))
return failure();
+
+ // Take a subview of the destination buffer.
+ auto destBufferType = destBuffer->getType().cast<MemRefType>();
+ auto subviewMemRefType =
+ memref::SubViewOp::inferRankReducedResultType(
+ parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
+ parallelInsertSliceOp.getMixedOffsets(),
+ parallelInsertSliceOp.getMixedSizes(),
+ parallelInsertSliceOp.getMixedStrides())
+ .cast<MemRefType>();
Value subview = rewriter.create<memref::SubViewOp>(
- parallelInsertSliceOp.getLoc(), *destBuffer,
+ parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
parallelInsertSliceOp.getMixedOffsets(),
parallelInsertSliceOp.getMixedSizes(),
parallelInsertSliceOp.getMixedStrides());
+
// This memcpy will fold away if everything bufferizes in-place.
if (failed(options.createMemCpy(rewriter, parallelInsertSliceOp.getLoc(),
*srcBuffer, subview)))
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 198ece16baebe..6cddef218ca63 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -216,8 +216,10 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) {
+ SmallVector<int64_t> targetShape = llvm::to_vector(
+ llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
- 0, inputType, offsets, sizes, strides);
+ targetShape, inputType, offsets, sizes, strides);
return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>());
}
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 6a3c4e1f87a03..937588e045bba 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -292,7 +292,7 @@ func.func @tensor.extract_slice_rank_reducing(
// CHECK-SAME: %[[t1:.*]]: tensor<?x?xf32>, %[[t2:.*]]: tensor<?x10xf32>,
// CHECK-SAME: %[[idx1:.*]]: index, %[[idx2:.*]]: index
func.func @tensor.insert_slice(%t1: tensor<?x?xf32>, %t2: tensor<?x10xf32>,
- %idx1: index, %idx2: index) -> tensor<?x?xf32> {
+ %idx1: index, %idx2: index) -> tensor<?x?xf32> {
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
@@ -313,6 +313,40 @@ func.func @tensor.insert_slice(%t1: tensor<?x?xf32>, %t2: tensor<?x10xf32>,
// -----
+// CHECK: #[[$MAP11:.*]] = affine_map<()[s0] -> (s0)>
+
+// CHECK-LABEL: func @tensor.insert_slice_rank_reducing_1(
+func.func @tensor.insert_slice_rank_reducing_1(
+ %t1: tensor<?x?xf32>, %f: tensor<f32>, %idx1: index, %idx2: index)
+ -> tensor<?x?xf32>
+{
+ // CHECK: %[[alloc:.*]] = memref.alloc{{.*}} : memref<?x?xf32>
+ // CHECK: memref.subview %[[alloc]][%{{.*}}, %{{.*}}] [1, 1] [1, 1] : memref<?x?xf32> to memref<f32, #[[$MAP11]]>
+ // CHECK: memref.copy {{.*}} : memref<f32> to memref<f32, #[[$MAP11]]>
+ %0 = tensor.insert_slice %f into %t1[%idx1, %idx2][1, 1][1, 1]
+ : tensor<f32> into tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP12:.*]] = affine_map<(d0, d1, d2, d3, d4)[s0, s1, s2, s3, s4, s5] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5)>
+
+// CHECK-LABEL: func @tensor.insert_slice_rank_reducing_2(
+func.func @tensor.insert_slice_rank_reducing_2(
+ %t1: tensor<?x?x?x?x?x?x?xf32>, %t2: tensor<2x1x4x1x1xf32>, %i: index)
+ -> tensor<?x?x?x?x?x?x?xf32>
+{
+ // CHECK: %[[alloc:.*]] = memref.alloc{{.*}} : memref<?x?x?x?x?x?x?xf32>
+ // CHECK: memref.subview %[[alloc]][{{.*}}] [1, 2, 1, 4, 1, 1, 1] [1, 1, 1, 1, 1, 1, 1] : memref<?x?x?x?x?x?x?xf32> to memref<2x1x4x1x1xf32, #[[$MAP12]]>
+ // CHECK: memref.copy {{.*}} : memref<2x1x4x1x1xf32> to memref<2x1x4x1x1xf32, #[[$MAP12]]>
+ %0 = tensor.insert_slice %t2 into %t1[%i, %i, %i, %i, %i, %i, %i][1, 2, 1, 4, 1, 1, 1][1, 1, 1, 1, 1, 1, 1]
+ : tensor<2x1x4x1x1xf32> into tensor<?x?x?x?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?x?x?x?xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @tensor.insert(
// CHECK-SAME: %[[t1:.*]]: tensor<5xf32>, %[[idx1:.*]]: index,
// CHECK-SAME: %[[f:.*]]: f32
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index 7249d546de449..4b462f6e2a92b 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -193,3 +193,27 @@ func.func @rank_reducing(
}
return %5: tensor<?x1x6x8xf32>
}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+
+// CHECK-LABEL: func.func @rank_reducing_parallel_insert_slice
+func.func @rank_reducing_parallel_insert_slice(%in: tensor<100xf32>, %out: tensor<200x100xf32>) {
+ %c1 = arith.constant 1 : index
+ %num_threads = arith.constant 100 : index
+
+ // CHECK: scf.foreach_thread {{.*}} {
+ %result = scf.foreach_thread (%thread_idx) in (%num_threads) -> tensor<200x100xf32> {
+ %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32>
+ scf.foreach_thread.perform_concurrently {
+ // CHECK: memref.subview %{{.*}}[%{{.*}}] [1] [1] : memref<100xf32, #[[$MAP0]]> to memref<1xf32, #[[$MAP0]]>
+ // CHECK: memref.subview %{{.*}}[1, %{{.*}}] [1, 1] [1, 1] : memref<200x100xf32, #[[$MAP1]]> to memref<1xf32, #[[$MAP0]]>
+ tensor.parallel_insert_slice %1 into %out[1, %thread_idx][1, 1][1, 1] :
+ tensor<1xf32> into tensor<200x100xf32>
+ }
+ }
+ // CHECK: }
+ return
+}
diff --git a/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp b/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp
index 189975590cfbb..28dc768bda25a 100644
--- a/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp
+++ b/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp
@@ -21,7 +21,7 @@ TEST(InferShapeTest, inferRankReducedShapeIdentity) {
OpBuilder b(&ctx);
auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType());
auto reducedType = SubViewOp::inferRankReducedResultType(
- /*resultRank=*/1, sourceMemref, {2, 3}, {1, 2}, {1, 1});
+ /*resultShape=*/{2}, sourceMemref, {2, 3}, {1, 2}, {1, 1});
AffineExpr dim0;
bindDims(&ctx, dim0);
auto expectedType =
@@ -38,7 +38,7 @@ TEST(InferShapeTest, inferRankReducedShapeNonIdentity) {
auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType(),
AffineMap::get(2, 0, 1000 * dim0 + dim1));
auto reducedType = SubViewOp::inferRankReducedResultType(
- /*resultRank=*/1, sourceMemref, {2, 3}, {1, 2}, {1, 1});
+ /*resultShape=*/{2}, sourceMemref, {2, 3}, {1, 2}, {1, 1});
auto expectedType =
MemRefType::get({2}, b.getIndexType(), AffineMap::get(1, 0, dim0 + 2003));
EXPECT_EQ(reducedType, expectedType);
@@ -52,7 +52,7 @@ TEST(InferShapeTest, inferRankReducedShapeToScalar) {
auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType(),
AffineMap::get(2, 0, 1000 * dim0 + dim1));
auto reducedType = SubViewOp::inferRankReducedResultType(
- /*resultRank=*/0, sourceMemref, {2, 3}, {1, 1}, {1, 1});
+ /*resultShape=*/{}, sourceMemref, {2, 3}, {1, 1}, {1, 1});
auto expectedType =
MemRefType::get({}, b.getIndexType(),
AffineMap::get(0, 0, b.getAffineConstantExpr(2003)));
More information about the Mlir-commits mailing list