[Mlir-commits] [mlir] 7a99602 - [mlir] Convert memref_reshape to memref_reinterpret_cast.
Alexander Belyaev
llvmlistbot at llvm.org
Wed Oct 28 13:15:52 PDT 2020
Author: Alexander Belyaev
Date: 2020-10-28T21:15:32+01:00
New Revision: 7a996027b9847d9808cb5567e8a4553989e1dbcf
URL: https://github.com/llvm/llvm-project/commit/7a996027b9847d9808cb5567e8a4553989e1dbcf
DIFF: https://github.com/llvm/llvm-project/commit/7a996027b9847d9808cb5567e8a4553989e1dbcf.diff
LOG: [mlir] Convert memref_reshape to memref_reinterpret_cast.
Differential Revision: https://reviews.llvm.org/D90235
Added:
mlir/lib/Dialect/StandardOps/Transforms/ExpandMemRefReshape.cpp
mlir/test/Dialect/Standard/expand-memref-reshape.mlir
mlir/test/lib/Transforms/TestExpandMemRefReshape.cpp
mlir/test/mlir-cpu-runner/memref_reshape.mlir
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
mlir/test/Dialect/Standard/invalid.mlir
mlir/test/lib/Transforms/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index b7c66f4d20f8..c5ad72aa02fc 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2262,6 +2262,20 @@ def MemRefReinterpretCastOp:
I64ArrayAttr:$static_strides
);
let results = (outs AnyMemRef:$result);
+
+ let builders = [
+ // Build a ReinterpretCastOp with mixed static and dynamic entries.
+ // `staticOffset`, `staticSizes` and `staticStrides` populate the
+ // static_offsets/static_sizes/static_strides attributes, where the
+ // ShapedType dynamic sentinels mark dynamic entries; `offset`, `sizes` and
+ // `strides` supply the SSA values for those dynamic entries.
+ OpBuilder<
+ "MemRefType resultType, Value source, int64_t staticOffset, "
+ "ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides, "
+ "ValueRange offset, ValueRange sizes, ValueRange strides, "
+ "ArrayRef<NamedAttribute> attrs = {}">,
+ // Build a ReinterpretCastOp with all dynamic entries: every entry of the
+ // static_* attributes is set to the dynamic sentinel (one size/stride per
+ // result dimension) and all values come from `offset`, `sizes`, `strides`.
+ OpBuilder<
+ "MemRefType resultType, Value source, Value offset, ValueRange sizes, "
+ "ValueRange strides, ArrayRef<NamedAttribute> attrs = {}">,
+ ];
+
let extraClassDeclaration = extraBaseClassDeclaration # [{
// The result of the op is always a ranked memref.
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
@@ -2312,7 +2326,7 @@ def MemRefReshapeOp: Std_Op<"memref_reshape", [
let arguments = (ins
AnyRankedOrUnrankedMemRef:$source,
- MemRefRankOf<[AnySignlessInteger], [1]>:$shape
+ MemRefRankOf<[AnySignlessInteger, Index], [1]>:$shape
);
let results = (outs AnyRankedOrUnrankedMemRef:$result);
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
index 65a39dc9ad9b..714acdff997e 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -25,6 +25,9 @@ class OwningRewritePatternList;
/// Creates an instance of the ExpandAtomic pass.
std::unique_ptr<Pass> createExpandAtomicPass();
+/// Collects the pattern that rewrites `memref_reshape` ops whose shape operand
+/// has a statically-known size into `memref_reinterpret_cast` ops.
+void populateExpandMemRefReshapePattern(OwningRewritePatternList &patterns,
+ MLIRContext *ctx);
+
void populateExpandTanhPattern(OwningRewritePatternList &patterns,
MLIRContext *ctx);
diff --git a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
index 83bda1048739..83f1941d079a 100644
--- a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
@@ -13,5 +13,6 @@ add_mlir_conversion_library(MLIRStandardToLLVM
LINK_LIBS PUBLIC
MLIRLLVMIR
+ MLIRStandardOpsTransforms
MLIRTransforms
)
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 3932e77ef5e7..152dbd1e990e 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -16,6 +16,7 @@
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
@@ -3666,6 +3667,7 @@ void mlir::populateStdToLLVMConversionPatterns(
populateStdToLLVMFuncOpConversionPattern(converter, patterns);
populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
populateStdToLLVMMemoryConversionPatterns(converter, patterns);
+ // Rewrite memref_reshape ops with a statically-sized shape operand into
+ // memref_reinterpret_cast. NOTE(review): presumably this is how
+ // memref_reshape reaches LLVM, via the reinterpret_cast lowering -- confirm.
+ populateExpandMemRefReshapePattern(patterns, &converter.getContext());
}
/// Convert a non-empty list of types to be returned from a function into a
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 4638ae4e9268..48c3155cd105 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2269,6 +2269,34 @@ OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
// MemRefReinterpretCastOp
//===----------------------------------------------------------------------===//
+/// Build a MemRefReinterpretCastOp with mixed static and dynamic entries.
+/// `staticOffset`, `staticSizes` and `staticStrides` become the static_*
+/// attributes, where ShapedType::kDynamicStrideOrOffset and
+/// ShapedType::kDynamicSize mark dynamic entries; `offset`, `sizes` and
+/// `strides` hold the SSA values for those dynamic entries. `attrs` are added
+/// to the resulting operation state.
+void mlir::MemRefReinterpretCastOp::build(
+    OpBuilder &b, OperationState &result, MemRefType resultType, Value source,
+    int64_t staticOffset, ArrayRef<int64_t> staticSizes,
+    ArrayRef<int64_t> staticStrides, ValueRange offset, ValueRange sizes,
+    ValueRange strides, ArrayRef<NamedAttribute> attrs) {
+  // The offset is stored as a single-element I64 array attribute.
+  build(b, result, resultType, source, offset, sizes, strides,
+        b.getI64ArrayAttr(staticOffset), b.getI64ArrayAttr(staticSizes),
+        b.getI64ArrayAttr(staticStrides));
+  result.addAttributes(attrs);
+}
+
+/// Build a MemRefReinterpretCastOp with all dynamic entries: `staticOffset`,
+/// `staticSizes` and `staticStrides` are automatically filled with sentinel
+/// values that encode dynamic entries, one size and one stride per dimension
+/// of the *result* memref type.
+void mlir::MemRefReinterpretCastOp::build(OpBuilder &b, OperationState &result,
+                                          MemRefType resultType, Value source,
+                                          Value offset, ValueRange sizes,
+                                          ValueRange strides,
+                                          ArrayRef<NamedAttribute> attrs) {
+  int64_t rank = resultType.getRank();
+  SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
+  SmallVector<int64_t, 4> staticStridesVector(
+      rank, ShapedType::kDynamicStrideOrOffset);
+  build(b, result, resultType, source,
+        /*staticOffset=*/ShapedType::kDynamicStrideOrOffset, staticSizesVector,
+        staticStridesVector, offset, sizes, strides, attrs);
+}
+
/// Print of the form:
/// ```
/// `name` ssa-name to
@@ -2391,18 +2419,6 @@ static LogicalResult verify(MemRefReinterpretCastOp op) {
op.strides())))
return failure();
- // Extract source offset and strides.
- int64_t resultOffset;
- SmallVector<int64_t, 4> resultStrides;
- if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
- return failure();
-
- // Match offset in result memref type and in static_offsets attribute.
- int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front();
- if (resultOffset != expectedOffset)
- return op.emitError("expected result type with offset = ")
- << resultOffset << " instead of " << expectedOffset;
-
// Match sizes in result memref type and in static_sizes attribute.
for (auto &en :
llvm::enumerate(llvm::zip(resultType.getShape(),
@@ -2415,15 +2431,31 @@ static LogicalResult verify(MemRefReinterpretCastOp op) {
<< " in dim = " << en.index();
}
- // Match strides in result memref type and in static_strides attribute.
- for (auto &en : llvm::enumerate(llvm::zip(
- resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
- int64_t resultStride = std::get<0>(en.value());
- int64_t expectedStride = std::get<1>(en.value());
- if (resultStride != expectedStride)
- return op.emitError("expected result type with stride = ")
- << expectedStride << " instead of " << resultStride
- << " in dim = " << en.index();
+ // Match offset and strides in the static_offsets and static_strides
+ // attributes only if the result memref type has an explicit affine map.
+ // A result type without a layout map is accepted with any offset/strides,
+ // e.g. the all-dynamic strides produced by the memref_reshape expansion.
+ if (!resultType.getAffineMaps().empty()) {
+ int64_t resultOffset;
+ SmallVector<int64_t, 4> resultStrides;
+ if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
+ return failure();
+
+ // Match offset in result memref type and in static_offsets attribute.
+ // NOTE(review): this message prints the type's offset first and the
+ // attribute's offset second, the reverse of the stride message below;
+ // consider swapping (invalid.mlir's expected-error would change too).
+ int64_t expectedOffset =
+ extractFromI64ArrayAttr(op.static_offsets()).front();
+ if (resultOffset != expectedOffset)
+ return op.emitError("expected result type with offset = ")
+ << resultOffset << " instead of " << expectedOffset;
+
+ // Match strides in result memref type and in static_strides attribute.
+ for (auto &en : llvm::enumerate(llvm::zip(
+ resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
+ int64_t resultStride = std::get<0>(en.value());
+ int64_t expectedStride = std::get<1>(en.value());
+ if (resultStride != expectedStride)
+ return op.emitError("expected result type with stride = ")
+ << expectedStride << " instead of " << resultStride
+ << " in dim = " << en.index();
+ }
}
return success();
}
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
index abdd05f56387..aabb81cf3d06 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRStandardOpsTransforms
Bufferize.cpp
ExpandAtomic.cpp
+ ExpandMemRefReshape.cpp
ExpandTanh.cpp
FuncConversions.cpp
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandMemRefReshape.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandMemRefReshape.cpp
new file mode 100644
index 000000000000..a8444c64a1c8
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandMemRefReshape.cpp
@@ -0,0 +1,70 @@
+//===- ExpandMemRefReshape.cpp - Code to perform expanding memref_reshape -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements expansion of MemRefReshapeOp into
+// MemRefReinterpretCastOp.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Converts `memref_reshape` that has a target shape of a statically-known
+/// size to `memref_reinterpret_cast`.
+///
+/// The result gets offset 0 and the row-major (identity-layout) strides of the
+/// target shape. Dynamic target sizes are loaded from the shape operand at run
+/// time. Static target sizes are taken from the result type and encoded in the
+/// `static_sizes` attribute, since the memref_reinterpret_cast verifier
+/// requires `static_sizes` to agree with the result shape.
+struct MemRefReshapeOpConverter : public OpRewritePattern<MemRefReshapeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MemRefReshapeOp op,
+                                PatternRewriter &rewriter) const final {
+    // The length of the (1-D) shape operand is the result rank, so it has to
+    // be known statically.
+    auto shapeType = op.shape().getType().cast<MemRefType>();
+    if (!shapeType.hasStaticShape())
+      return failure();
+    int64_t rank = shapeType.getDimSize(0);
+
+    // The offset/strides computed below describe an identity layout, so only
+    // ranked results without an explicit layout map are handled.
+    auto resultType = op.getType().dyn_cast<MemRefType>();
+    if (!resultType || resultType.getRank() != rank ||
+        !resultType.getAffineMaps().empty())
+      return failure();
+
+    Location loc = op.getLoc();
+    // `sizes[i]` is set only for dynamic dimensions; `strides` is all-dynamic.
+    SmallVector<Value, 4> sizes(rank), strides(rank);
+    Value stride = rewriter.create<ConstantIndexOp>(loc, 1);
+    for (int64_t i = rank - 1; i >= 0; --i) {
+      Value size;
+      if (resultType.isDynamicDim(i)) {
+        // Dynamic extent: read it from the shape operand at run time.
+        Value index = rewriter.create<ConstantIndexOp>(loc, i);
+        size = rewriter.create<LoadOp>(loc, op.shape(), index);
+        if (!size.getType().isa<IndexType>())
+          size =
+              rewriter.create<IndexCastOp>(loc, size, rewriter.getIndexType());
+        sizes[i] = size;
+      } else if (i > 0) {
+        // Static extent: only needed as a constant for the stride product.
+        size = rewriter.create<ConstantIndexOp>(loc, resultType.getDimSize(i));
+      }
+      strides[i] = stride;
+      if (i > 0)
+        stride = rewriter.create<MulIOp>(loc, stride, size);
+    }
+
+    // Only dynamic sizes are passed as operands (in dimension order); static
+    // ones are carried by the `static_sizes` attribute built from the shape.
+    SmallVector<Value, 4> dynamicSizes;
+    for (Value size : sizes)
+      if (size)
+        dynamicSizes.push_back(size);
+    SmallVector<int64_t, 4> staticStrides(rank,
+                                          ShapedType::kDynamicStrideOrOffset);
+    rewriter.replaceOpWithNewOp<MemRefReinterpretCastOp>(
+        op, resultType, op.source(), /*staticOffset=*/0, resultType.getShape(),
+        staticStrides, /*offset=*/llvm::None, dynamicSizes, strides);
+    return success();
+  }
+};
+
+} // namespace
+
+/// Adds the `memref_reshape` -> `memref_reinterpret_cast` expansion pattern to
+/// `patterns`.
+void mlir::populateExpandMemRefReshapePattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<MemRefReshapeOpConverter>(ctx);
+}
diff --git a/mlir/test/Dialect/Standard/expand-memref-reshape.mlir b/mlir/test/Dialect/Standard/expand-memref-reshape.mlir
new file mode 100644
index 000000000000..0e6f4511c1b9
--- /dev/null
+++ b/mlir/test/Dialect/Standard/expand-memref-reshape.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -test-expand-memref-reshape | FileCheck %s
+
+// CHECK-LABEL: func @memref_reshape(
+func @memref_reshape(%input: memref<*xf32>,
+ %shape: memref<3xi32>) -> memref<?x?x?xf32> {
+ %result = memref_reshape %input(%shape)
+ : (memref<*xf32>, memref<3xi32>) -> memref<?x?x?xf32>
+ return %result : memref<?x?x?xf32>
+}
+// CHECK-SAME: [[SRC:%.*]]: memref<*xf32>,
+// CHECK-SAME: [[SHAPE:%.*]]: memref<3xi32>) -> memref<?x?x?xf32> {
+// CHECK: [[C2:%.*]] = constant 2 : index
+// CHECK: [[C1:%.*]] = constant 1 : index
+// CHECK: [[C0:%.*]] = constant 0 : index
+// CHECK: [[DIM_2:%.*]] = load [[SHAPE]]{{\[}}[[C2]]] : memref<3xi32>
+// CHECK: [[SIZE_2:%.*]] = index_cast [[DIM_2]] : i32 to index
+// CHECK: [[DIM_1:%.*]] = load [[SHAPE]]{{\[}}[[C1]]] : memref<3xi32>
+// CHECK: [[SIZE_1:%.*]] = index_cast [[DIM_1]] : i32 to index
+// CHECK: [[STRIDE_0:%.*]] = muli [[SIZE_2]], [[SIZE_1]] : index
+// CHECK: [[DIM_0:%.*]] = load [[SHAPE]]{{\[}}[[C0]]] : memref<3xi32>
+// CHECK: [[SIZE_0:%.*]] = index_cast [[DIM_0]] : i32 to index
+
+// CHECK: [[RESULT:%.*]] = memref_reinterpret_cast [[SRC]]
+// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], [[SIZE_2]]],
+// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[SIZE_2]], [[C1]]]
+// CHECK-SAME: : memref<*xf32> to memref<?x?x?xf32>
+
+// Static result dimensions must be encoded in `static_sizes` (not passed as
+// operands), otherwise the memref_reinterpret_cast verifier rejects the op.
+// CHECK-LABEL: func @memref_reshape_static_dim(
+func @memref_reshape_static_dim(%input: memref<*xf32>,
+                                %shape: memref<2xindex>) -> memref<?x4xf32> {
+  %result = memref_reshape %input(%shape)
+               : (memref<*xf32>, memref<2xindex>) -> memref<?x4xf32>
+  return %result : memref<?x4xf32>
+}
+// CHECK-SAME: [[SRC:%.*]]: memref<*xf32>,
+// CHECK-SAME: [[SHAPE:%.*]]: memref<2xindex>) -> memref<?x4xf32> {
+// CHECK-DAG: [[C0:%.*]] = constant 0 : index
+// CHECK-DAG: [[C1:%.*]] = constant 1 : index
+// CHECK-DAG: [[C4:%.*]] = constant 4 : index
+// CHECK: [[SIZE_0:%.*]] = load [[SHAPE]]{{\[}}[[C0]]] : memref<2xindex>
+// CHECK: [[RESULT:%.*]] = memref_reinterpret_cast [[SRC]]
+// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], 4],
+// CHECK-SAME: strides: {{\[}}[[C4]], [[C1]]]
+// CHECK-SAME: : memref<*xf32> to memref<?x4xf32>
diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
index 5bd94bd39d91..ea42028781ab 100644
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -140,10 +140,10 @@ func @memref_reinterpret_cast_incompatible_memory_space(%in: memref<*xf32>) {
// CHECK-LABEL: func @memref_reinterpret_cast_offset_mismatch
func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
- // expected-error @+1 {{expected result type with offset = 0 instead of 1}}
+ // expected-error @+1 {{expected result type with offset = 2 instead of 1}}
%out = memref_reinterpret_cast %in to
offset: [1], sizes: [10], strides: [1]
- : memref<?xf32> to memref<10xf32>
+ : memref<?xf32> to memref<10xf32, offset: 2, strides: [1]>
return
}
@@ -164,8 +164,8 @@ func @memref_reinterpret_cast_size_mismatch(%in: memref<*xf32>) {
func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
// expected-error @+1 {{expected result type with stride = 2 instead of 1 in dim = 0}}
%out = memref_reinterpret_cast %in to
- offset: [0], sizes: [10], strides: [2]
- : memref<?xf32> to memref<10xf32>
+ offset: [2], sizes: [10], strides: [2]
+ : memref<?xf32> to memref<10xf32, offset: 2, strides: [1]>
return
}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index e3e82f2febcf..aa22f3f7959c 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
add_mlir_library(MLIRTestTransforms
TestAffineLoopParametricTiling.cpp
TestBufferPlacement.cpp
+ TestExpandMemRefReshape.cpp
TestExpandTanh.cpp
TestCallGraph.cpp
TestConstantFold.cpp
diff --git a/mlir/test/lib/Transforms/TestExpandMemRefReshape.cpp b/mlir/test/lib/Transforms/TestExpandMemRefReshape.cpp
new file mode 100644
index 000000000000..ddb340a97ab4
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestExpandMemRefReshape.cpp
@@ -0,0 +1,37 @@
+//===- TestExpandMemRefReshape.cpp - Test expansion of memref_reshape -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains test passes for expanding memref reshape.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass that greedily applies only the memref_reshape expansion pattern
+/// (and folding) to every function it runs on.
+struct TestExpandMemRefReshapePass
+    : public PassWrapper<TestExpandMemRefReshapePass, FunctionPass> {
+  void runOnFunction() override {
+    MLIRContext *context = &getContext();
+    OwningRewritePatternList reshapePatterns;
+    populateExpandMemRefReshapePattern(reshapePatterns, context);
+    applyPatternsAndFoldGreedily(getOperation(), std::move(reshapePatterns));
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+/// Registers the pass with mlir-opt as `-test-expand-memref-reshape`.
+void registerTestExpandMemRefReshapePass() {
+  PassRegistration<TestExpandMemRefReshapePass> registration(
+      "test-expand-memref-reshape", "Test expanding memref reshape");
+}
+} // namespace mlir
diff --git a/mlir/test/mlir-cpu-runner/memref_reshape.mlir b/mlir/test/mlir-cpu-runner/memref_reshape.mlir
new file mode 100644
index 000000000000..96a8ae16ae6d
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/memref_reshape.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-std-to-llvm \
+// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+
+func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
+
+func @main() -> () {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+
+ // Initialize input.
+ %input = alloc() : memref<2x3xf32>
+ %dim_x = dim %input, %c0 : memref<2x3xf32>
+ %dim_y = dim %input, %c1 : memref<2x3xf32>
+ scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
+ %prod = muli %i, %dim_y : index
+ %val = addi %prod, %j : index
+ %val_i64 = index_cast %val : index to i64
+ %val_f32 = sitofp %val_i64 : i64 to f32
+ store %val_f32, %input[%i, %j] : memref<2x3xf32>
+ }
+ %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+ call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
+ // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
+ // CHECK-NEXT: [0, 1, 2]
+ // CHECK-NEXT: [3, 4, 5]
+
+ // Initialize shape.
+ %shape = alloc() : memref<2xindex>
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+ store %c3, %shape[%c0] : memref<2xindex>
+ store %c2, %shape[%c1] : memref<2xindex>
+
+ // Test cases.
+ call @reshape_ranked_memref_to_ranked(%input, %shape)
+ : (memref<2x3xf32>, memref<2xindex>) -> ()
+ call @reshape_unranked_memref_to_ranked(%input, %shape)
+ : (memref<2x3xf32>, memref<2xindex>) -> ()
+ return
+}
+
+func @reshape_ranked_memref_to_ranked(%input : memref<2x3xf32>,
+ %shape : memref<2xindex>) {
+ %output = memref_reshape %input(%shape)
+ : (memref<2x3xf32>, memref<2xindex>) -> memref<?x?xf32>
+
+ %unranked_output = memref_cast %output : memref<?x?xf32> to memref<*xf32>
+ call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+ // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data =
+ // CHECK: [0, 1],
+ // CHECK: [2, 3],
+ // CHECK: [4, 5]
+ return
+}
+
+func @reshape_unranked_memref_to_ranked(%input : memref<2x3xf32>,
+ %shape : memref<2xindex>) {
+ // Reshape through the unranked view so that the unranked-source path is
+ // actually exercised (reshaping %input would duplicate the ranked case).
+ %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+ %output = memref_reshape %unranked_input(%shape)
+ : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
+
+ %unranked_output = memref_cast %output : memref<?x?xf32> to memref<*xf32>
+ call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+ // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data =
+ // CHECK: [0, 1],
+ // CHECK: [2, 3],
+ // CHECK: [4, 5]
+ return
+}
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 2322de20230d..ac29305e3c6f 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -53,6 +53,7 @@ void registerTestConvertGPUKernelToHsacoPass();
void registerTestDominancePass();
void registerTestDialect(DialectRegistry &);
void registerTestDynamicPipelinePass();
+void registerTestExpandMemRefReshapePass();
void registerTestExpandTanhPass();
void registerTestFunc();
void registerTestGpuMemoryPromotionPass();
@@ -115,6 +116,7 @@ void registerTestPasses() {
registerTestDynamicPipelinePass();
registerTestFunc();
registerTestExpandTanhPass();
+ registerTestExpandMemRefReshapePass();
registerTestGpuMemoryPromotionPass();
registerTestInterfaces();
registerTestLinalgCodegenStrategy();
More information about the Mlir-commits
mailing list