[Mlir-commits] [mlir] 7a99602 - [mlir] Convert memref_reshape to memref_reinterpret_cast.

Alexander Belyaev llvmlistbot at llvm.org
Wed Oct 28 13:15:52 PDT 2020


Author: Alexander Belyaev
Date: 2020-10-28T21:15:32+01:00
New Revision: 7a996027b9847d9808cb5567e8a4553989e1dbcf

URL: https://github.com/llvm/llvm-project/commit/7a996027b9847d9808cb5567e8a4553989e1dbcf
DIFF: https://github.com/llvm/llvm-project/commit/7a996027b9847d9808cb5567e8a4553989e1dbcf.diff

LOG: [mlir] Convert memref_reshape to memref_reinterpret_cast.

Differential Revision: https://reviews.llvm.org/D90235

Added: 
    mlir/lib/Dialect/StandardOps/Transforms/ExpandMemRefReshape.cpp
    mlir/test/Dialect/Standard/expand-memref-reshape.mlir
    mlir/test/lib/Transforms/TestExpandMemRefReshape.cpp
    mlir/test/mlir-cpu-runner/memref_reshape.mlir

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
    mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
    mlir/test/Dialect/Standard/invalid.mlir
    mlir/test/lib/Transforms/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index b7c66f4d20f8..c5ad72aa02fc 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2262,6 +2262,20 @@ def MemRefReinterpretCastOp:
     I64ArrayAttr:$static_strides
   );
   let results = (outs AnyMemRef:$result);
+
+  let builders = [
+    // Build from static offset/sizes/strides values plus dynamic operands.
+    OpBuilder<
+      "MemRefType resultType, Value source, int64_t staticOffset, "
+      "ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides, "
+      "ValueRange offset, ValueRange sizes, ValueRange strides, "
+      "ArrayRef<NamedAttribute> attrs = {}">,
+    // Build with every offset, size and stride entry marked dynamic.
+    OpBuilder<
+      "MemRefType resultType, Value source, Value offset, ValueRange sizes, "
+      "ValueRange strides, ArrayRef<NamedAttribute> attrs = {}">,
+  ];
+
   let extraClassDeclaration = extraBaseClassDeclaration # [{
     // The result of the op is always a ranked memref.
     MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
@@ -2312,7 +2326,7 @@ def MemRefReshapeOp: Std_Op<"memref_reshape", [
 
   let arguments = (ins
     AnyRankedOrUnrankedMemRef:$source,
-    MemRefRankOf<[AnySignlessInteger], [1]>:$shape
+    MemRefRankOf<[AnySignlessInteger, Index], [1]>:$shape
   );
   let results = (outs AnyRankedOrUnrankedMemRef:$result);
 

diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
index 65a39dc9ad9b..714acdff997e 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -25,6 +25,11 @@ class OwningRewritePatternList;
 /// Creates an instance of the ExpandAtomic pass.
 std::unique_ptr<Pass> createExpandAtomicPass();
 
+/// Collects a pattern that expands `memref_reshape` ops whose shape operand
+/// has a static size into `memref_reinterpret_cast` ops.
+void populateExpandMemRefReshapePattern(OwningRewritePatternList &patterns,
+                                        MLIRContext *ctx);
+
 void populateExpandTanhPattern(OwningRewritePatternList &patterns,
                                MLIRContext *ctx);
 

diff  --git a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
index 83bda1048739..83f1941d079a 100644
--- a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt
@@ -13,5 +13,6 @@ add_mlir_conversion_library(MLIRStandardToLLVM
 
   LINK_LIBS PUBLIC
   MLIRLLVMIR
+  MLIRStandardOpsTransforms
   MLIRTransforms
   )

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 3932e77ef5e7..152dbd1e990e 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
@@ -3666,6 +3667,7 @@ void mlir::populateStdToLLVMConversionPatterns(
   populateStdToLLVMFuncOpConversionPattern(converter, patterns);
   populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
   populateStdToLLVMMemoryConversionPatterns(converter, patterns);
+  populateExpandMemRefReshapePattern(patterns, &converter.getContext());
 }
 
 /// Convert a non-empty list of types to be returned from a function into a

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 4638ae4e9268..48c3155cd105 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2269,6 +2269,36 @@ OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
 // MemRefReinterpretCastOp
 //===----------------------------------------------------------------------===//
 
+/// Build a MemRefReinterpretCastOp from explicit static offset/sizes/strides
+/// values and the dynamic `offset`, `sizes` and `strides` operands.
+void mlir::MemRefReinterpretCastOp::build(
+    OpBuilder &b, OperationState &result, MemRefType resultType, Value source,
+    int64_t staticOffset, ArrayRef<int64_t> staticSizes,
+    ArrayRef<int64_t> staticStrides, ValueRange offset, ValueRange sizes,
+    ValueRange strides, ArrayRef<NamedAttribute> attrs) {
+  build(b, result, resultType, source, offset, sizes, strides,
+        b.getI64ArrayAttr(staticOffset), b.getI64ArrayAttr(staticSizes),
+        b.getI64ArrayAttr(staticStrides));
+  result.addAttributes(attrs);
+}
+
+/// Build a MemRefReinterpretCastOp with all dynamic entries: `staticOffset`,
+/// `staticSizes` and `staticStrides` are automatically filled with
+/// result-memref-rank sentinel values that encode dynamic entries.
+void mlir::MemRefReinterpretCastOp::build(OpBuilder &b, OperationState &result,
+                                          MemRefType resultType, Value source,
+                                          Value offset, ValueRange sizes,
+                                          ValueRange strides,
+                                          ArrayRef<NamedAttribute> attrs) {
+  unsigned rank = resultType.getRank();
+  SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
+  SmallVector<int64_t, 4> staticStridesVector(
+      rank, ShapedType::kDynamicStrideOrOffset);
+  build(b, result, resultType, source,
+        /*staticOffset=*/ShapedType::kDynamicStrideOrOffset, staticSizesVector,
+        staticStridesVector, offset, sizes, strides, attrs);
+}
+
 /// Print of the form:
 /// ```
 ///   `name` ssa-name to
@@ -2391,18 +2421,6 @@ static LogicalResult verify(MemRefReinterpretCastOp op) {
           op.strides())))
     return failure();
 
-  // Extract source offset and strides.
-  int64_t resultOffset;
-  SmallVector<int64_t, 4> resultStrides;
-  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
-    return failure();
-
-  // Match offset in result memref type and in static_offsets attribute.
-  int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front();
-  if (resultOffset != expectedOffset)
-    return op.emitError("expected result type with offset = ")
-           << resultOffset << " instead of " << expectedOffset;
-
   // Match sizes in result memref type and in static_sizes attribute.
   for (auto &en :
        llvm::enumerate(llvm::zip(resultType.getShape(),
@@ -2415,15 +2433,31 @@ static LogicalResult verify(MemRefReinterpretCastOp op) {
              << " in dim = " << en.index();
   }
 
-  // Match strides in result memref type and in static_strides attribute.
-  for (auto &en : llvm::enumerate(llvm::zip(
-           resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
-    int64_t resultStride = std::get<0>(en.value());
-    int64_t expectedStride = std::get<1>(en.value());
-    if (resultStride != expectedStride)
-      return op.emitError("expected result type with stride = ")
-             << expectedStride << " instead of " << resultStride
-             << " in dim = " << en.index();
+  // Match offset and strides in static_offset and static_strides attributes if
+  // result memref type has an affine map specified.
+  if (!resultType.getAffineMaps().empty()) {
+    int64_t resultOffset;
+    SmallVector<int64_t, 4> resultStrides;
+    if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
+      return failure();
+
+    // Match offset in result memref type and in static_offsets attribute.
+    int64_t expectedOffset =
+        extractFromI64ArrayAttr(op.static_offsets()).front();
+    if (resultOffset != expectedOffset)
+      return op.emitError("expected result type with offset = ")
+             << expectedOffset << " instead of " << resultOffset;
+
+    // Match strides in result memref type and in static_strides attribute.
+    for (auto &en : llvm::enumerate(llvm::zip(
+             resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
+      int64_t resultStride = std::get<0>(en.value());
+      int64_t expectedStride = std::get<1>(en.value());
+      if (resultStride != expectedStride)
+        return op.emitError("expected result type with stride = ")
+               << expectedStride << " instead of " << resultStride
+               << " in dim = " << en.index();
+    }
   }
   return success();
 }

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
index abdd05f56387..aabb81cf3d06 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRStandardOpsTransforms
   Bufferize.cpp
   ExpandAtomic.cpp
+  ExpandMemRefReshape.cpp
   ExpandTanh.cpp
   FuncConversions.cpp
 

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandMemRefReshape.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandMemRefReshape.cpp
new file mode 100644
index 000000000000..a8444c64a1c8
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandMemRefReshape.cpp
@@ -0,0 +1,70 @@
+//===- ExpandMemRefReshape.cpp - Code to perform expanding memref_reshape -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements expansion of MemRefReshapeOp into
+// MemRefReinterpretCastOp.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Converts `memref_reshape` that has a target shape of a statically-known
+/// size to `memref_reinterpret_cast`.
+struct MemRefReshapeOpConverter : public OpRewritePattern<MemRefReshapeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MemRefReshapeOp op,
+                                PatternRewriter &rewriter) const final {
+    auto shapeType = op.shape().getType().cast<MemRefType>();
+    if (!shapeType.hasStaticShape())
+      return failure();
+
+    int64_t rank = shapeType.getDimSize(0);
+    SmallVector<Value, 4> sizes, strides;
+    sizes.resize(rank);
+    strides.resize(rank);
+
+    Location loc = op.getLoc();
+    // Row-major strides: strides[i] = product of sizes[i+1 .. rank-1].
+    Value stride = rewriter.create<ConstantIndexOp>(loc, 1);
+    for (int64_t i = rank - 1; i >= 0; --i) {
+      Value index = rewriter.create<ConstantIndexOp>(loc, i);
+      Value size = rewriter.create<LoadOp>(loc, op.shape(), index);
+      if (!size.getType().isa<IndexType>())
+        size = rewriter.create<IndexCastOp>(loc, size, rewriter.getIndexType());
+      sizes[i] = size;
+      strides[i] = stride;
+      if (i > 0)
+        stride = rewriter.create<MulIOp>(loc, stride, size);
+    }
+    SmallVector<int64_t, 4> staticSizes(rank, ShapedType::kDynamicSize);
+    SmallVector<int64_t, 4> staticStrides(rank,
+                                          ShapedType::kDynamicStrideOrOffset);
+    rewriter.replaceOpWithNewOp<MemRefReinterpretCastOp>(
+        op, op.getType(), op.source(), /*staticOffset=*/0, staticSizes,
+        staticStrides, /*offset=*/llvm::None, sizes, strides);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::populateExpandMemRefReshapePattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<MemRefReshapeOpConverter>(ctx);
+}

diff  --git a/mlir/test/Dialect/Standard/expand-memref-reshape.mlir b/mlir/test/Dialect/Standard/expand-memref-reshape.mlir
new file mode 100644
index 000000000000..0e6f4511c1b9
--- /dev/null
+++ b/mlir/test/Dialect/Standard/expand-memref-reshape.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -test-expand-memref-reshape | FileCheck %s
+
+// CHECK-LABEL: func @memref_reshape(
+func @memref_reshape(%input: memref<*xf32>,
+                     %shape: memref<3xi32>) -> memref<?x?x?xf32> {
+  %result = memref_reshape %input(%shape)
+               : (memref<*xf32>, memref<3xi32>) -> memref<?x?x?xf32>
+  return %result : memref<?x?x?xf32>
+}
+// CHECK-SAME: [[SRC:%.*]]: memref<*xf32>,
+// CHECK-SAME: [[SHAPE:%.*]]: memref<3xi32>) -> memref<?x?x?xf32> {
+// CHECK: [[C2:%.*]] = constant 2 : index
+// CHECK: [[C1:%.*]] = constant 1 : index
+// CHECK: [[C0:%.*]] = constant 0 : index
+// CHECK: [[DIM_2:%.*]] = load [[SHAPE]]{{\[}}[[C2]]] : memref<3xi32>
+// CHECK: [[SIZE_2:%.*]] = index_cast [[DIM_2]] : i32 to index
+// CHECK: [[DIM_1:%.*]] = load [[SHAPE]]{{\[}}[[C1]]] : memref<3xi32>
+// CHECK: [[SIZE_1:%.*]] = index_cast [[DIM_1]] : i32 to index
+// CHECK: [[STRIDE_0:%.*]] = muli [[SIZE_2]], [[SIZE_1]] : index
+// CHECK: [[DIM_0:%.*]] = load [[SHAPE]]{{\[}}[[C0]]] : memref<3xi32>
+// CHECK: [[SIZE_0:%.*]] = index_cast [[DIM_0]] : i32 to index
+
+// CHECK: [[RESULT:%.*]] = memref_reinterpret_cast [[SRC]]
+// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], [[SIZE_2]]],
+// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[SIZE_2]], [[C1]]]
+// CHECK-SAME: : memref<*xf32> to memref<?x?x?xf32>

diff  --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
index 5bd94bd39d91..ea42028781ab 100644
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -140,10 +140,10 @@ func @memref_reinterpret_cast_incompatible_memory_space(%in: memref<*xf32>) {
 
 // CHECK-LABEL: func @memref_reinterpret_cast_offset_mismatch
 func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
-  // expected-error @+1 {{expected result type with offset = 0 instead of 1}}
+  // expected-error @+1 {{expected result type with offset = 1 instead of 2}}
   %out = memref_reinterpret_cast %in to
            offset: [1], sizes: [10], strides: [1]
-         : memref<?xf32> to memref<10xf32>
+         : memref<?xf32> to memref<10xf32, offset: 2, strides: [1]>
   return
 }
 
@@ -164,8 +164,8 @@ func @memref_reinterpret_cast_size_mismatch(%in: memref<*xf32>) {
 func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
   // expected-error @+1 {{expected result type with stride = 2 instead of 1 in dim = 0}}
   %out = memref_reinterpret_cast %in to
-           offset: [0], sizes: [10], strides: [2]
-         : memref<?xf32> to memref<10xf32>
+           offset: [2], sizes: [10], strides: [2]
+         : memref<?xf32> to memref<10xf32, offset: 2, strides: [1]>
   return
 }
 

diff  --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index e3e82f2febcf..aa22f3f7959c 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
 add_mlir_library(MLIRTestTransforms
   TestAffineLoopParametricTiling.cpp
   TestBufferPlacement.cpp
+  TestExpandMemRefReshape.cpp
   TestExpandTanh.cpp
   TestCallGraph.cpp
   TestConstantFold.cpp

diff  --git a/mlir/test/lib/Transforms/TestExpandMemRefReshape.cpp b/mlir/test/lib/Transforms/TestExpandMemRefReshape.cpp
new file mode 100644
index 000000000000..ddb340a97ab4
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestExpandMemRefReshape.cpp
@@ -0,0 +1,37 @@
+//===- TestExpandMemRefReshape.cpp - Test expansion of memref_reshape -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a test pass that greedily applies the memref_reshape
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+struct TestExpandMemRefReshapePass
+    : public PassWrapper<TestExpandMemRefReshapePass, FunctionPass> {
+  void runOnFunction() override;
+};
+} // end anonymous namespace
+
+void TestExpandMemRefReshapePass::runOnFunction() {
+  OwningRewritePatternList patterns;
+  populateExpandMemRefReshapePattern(patterns, &getContext());
+  applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+namespace mlir {
+void registerTestExpandMemRefReshapePass() {
+  PassRegistration<TestExpandMemRefReshapePass> pass(
+      "test-expand-memref-reshape", "Test expanding memref reshape");
+}
+} // namespace mlir

diff  --git a/mlir/test/mlir-cpu-runner/memref_reshape.mlir b/mlir/test/mlir-cpu-runner/memref_reshape.mlir
new file mode 100644
index 000000000000..96a8ae16ae6d
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/memref_reshape.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-std-to-llvm \
+// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+
+func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
+
+func @main() -> () {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+
+  // Initialize input.
+  %input = alloc() : memref<2x3xf32>
+  %dim_x = dim %input, %c0 : memref<2x3xf32>
+  %dim_y = dim %input, %c1 : memref<2x3xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
+    %prod = muli %i, %dim_y : index
+    %val = addi %prod, %j : index
+    %val_i64 = index_cast %val : index to i64
+    %val_f32 = sitofp %val_i64 : i64 to f32
+    store %val_f32, %input[%i, %j] : memref<2x3xf32>
+  }
+  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
+  // CHECK-NEXT: [0,   1,   2]
+  // CHECK-NEXT: [3,   4,   5]
+
+  // Initialize shape.
+  %shape = alloc() : memref<2xindex>
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  store %c3, %shape[%c0] : memref<2xindex>
+  store %c2, %shape[%c1] : memref<2xindex>
+
+  // Test cases.
+  call @reshape_ranked_memref_to_ranked(%input, %shape)
+    : (memref<2x3xf32>, memref<2xindex>) -> ()
+  call @reshape_unranked_memref_to_ranked(%input, %shape)
+    : (memref<2x3xf32>, memref<2xindex>) -> ()
+  return
+}
+
+func @reshape_ranked_memref_to_ranked(%input : memref<2x3xf32>,
+                                      %shape : memref<2xindex>) {
+  %output = memref_reshape %input(%shape)
+                : (memref<2x3xf32>, memref<2xindex>) -> memref<?x?xf32>
+
+  %unranked_output = memref_cast %output : memref<?x?xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data =
+  // CHECK: [0,   1],
+  // CHECK: [2,   3],
+  // CHECK: [4,   5]
+  return
+}
+
+func @reshape_unranked_memref_to_ranked(%input : memref<2x3xf32>,
+                                        %shape : memref<2xindex>) {
+  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+  %output = memref_reshape %unranked_input(%shape)
+                : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
+
+  %unranked_output = memref_cast %output : memref<?x?xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data =
+  // CHECK: [0,   1],
+  // CHECK: [2,   3],
+  // CHECK: [4,   5]
+  return
+}

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 2322de20230d..ac29305e3c6f 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -53,6 +53,7 @@ void registerTestConvertGPUKernelToHsacoPass();
 void registerTestDominancePass();
 void registerTestDialect(DialectRegistry &);
 void registerTestDynamicPipelinePass();
+void registerTestExpandMemRefReshapePass();
 void registerTestExpandTanhPass();
 void registerTestFunc();
 void registerTestGpuMemoryPromotionPass();
@@ -115,6 +116,7 @@ void registerTestPasses() {
   registerTestDynamicPipelinePass();
   registerTestFunc();
+  registerTestExpandMemRefReshapePass();
   registerTestExpandTanhPass();
   registerTestGpuMemoryPromotionPass();
   registerTestInterfaces();
   registerTestLinalgCodegenStrategy();


        


More information about the Mlir-commits mailing list