[Mlir-commits] [mlir] b20e150 - [mlir] Use static shape knowledge when lowering memref.reshape
Benjamin Kramer
llvmlistbot at llvm.org
Tue May 11 09:21:15 PDT 2021
Author: Benjamin Kramer
Date: 2021-05-11T18:21:09+02:00
New Revision: b20e150c9be16f69c73f4cd2986053d13d0f376a
URL: https://github.com/llvm/llvm-project/commit/b20e150c9be16f69c73f4cd2986053d13d0f376a
DIFF: https://github.com/llvm/llvm-project/commit/b20e150c9be16f69c73f4cd2986053d13d0f376a.diff
LOG: [mlir] Use static shape knowledge when lowering memref.reshape
This is actually necessary for correctness, as memref.reinterpret_cast
doesn't verify that the output shape matches the static sizes.
Differential Revision: https://reviews.llvm.org/D102232
Added:
Modified:
mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
mlir/test/Dialect/Standard/expand-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
index fd1b36907ae1e..d4810ac40f31b 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
@@ -91,11 +91,19 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
Location loc = op.getLoc();
Value stride = rewriter.create<ConstantIndexOp>(loc, 1);
for (int i = rank - 1; i >= 0; --i) {
- Value index = rewriter.create<ConstantIndexOp>(loc, i);
- Value size = rewriter.create<memref::LoadOp>(loc, op.shape(), index);
- if (!size.getType().isa<IndexType>())
- size = rewriter.create<IndexCastOp>(loc, size, rewriter.getIndexType());
- sizes[i] = size;
+ Value size;
+ // Load dynamic sizes from the shape input, use constants for static dims.
+ if (op.getType().isDynamicDim(i)) {
+ Value index = rewriter.create<ConstantIndexOp>(loc, i);
+ size = rewriter.create<memref::LoadOp>(loc, op.shape(), index);
+ if (!size.getType().isa<IndexType>())
+ size =
+ rewriter.create<IndexCastOp>(loc, size, rewriter.getIndexType());
+ sizes[i] = size;
+ } else {
+ sizes[i] = rewriter.getIndexAttr(op.getType().getDimSize(i));
+ size = rewriter.create<ConstantOp>(loc, sizes[i].get<Attribute>());
+ }
strides[i] = stride;
if (i > 0)
stride = rewriter.create<MulIOp>(loc, stride, size);
diff --git a/mlir/test/Dialect/Standard/expand-ops.mlir b/mlir/test/Dialect/Standard/expand-ops.mlir
index f91f9b8ff373b..e133df36684fb 100644
--- a/mlir/test/Dialect/Standard/expand-ops.mlir
+++ b/mlir/test/Dialect/Standard/expand-ops.mlir
@@ -84,19 +84,17 @@ func @floordivi(%arg0: i32, %arg1: i32) -> (i32) {
// CHECK-LABEL: func @memref_reshape(
func @memref_reshape(%input: memref<*xf32>,
- %shape: memref<3xi32>) -> memref<?x?x?xf32> {
+ %shape: memref<3xi32>) -> memref<?x?x8xf32> {
%result = memref.reshape %input(%shape)
- : (memref<*xf32>, memref<3xi32>) -> memref<?x?x?xf32>
- return %result : memref<?x?x?xf32>
+ : (memref<*xf32>, memref<3xi32>) -> memref<?x?x8xf32>
+ return %result : memref<?x?x8xf32>
}
// CHECK-SAME: [[SRC:%.*]]: memref<*xf32>,
-// CHECK-SAME: [[SHAPE:%.*]]: memref<3xi32>) -> memref<?x?x?xf32> {
+// CHECK-SAME: [[SHAPE:%.*]]: memref<3xi32>) -> memref<?x?x8xf32> {
// CHECK: [[C1:%.*]] = constant 1 : index
-// CHECK: [[C2:%.*]] = constant 2 : index
-// CHECK: [[DIM_2:%.*]] = memref.load [[SHAPE]]{{\[}}[[C2]]] : memref<3xi32>
-// CHECK: [[SIZE_2:%.*]] = index_cast [[DIM_2]] : i32 to index
-// CHECK: [[STRIDE_1:%.*]] = muli [[C1]], [[SIZE_2]] : index
+// CHECK: [[C8:%.*]] = constant 8 : index
+// CHECK: [[STRIDE_1:%.*]] = muli [[C1]], [[C8]] : index
// CHECK: [[C1_:%.*]] = constant 1 : index
// CHECK: [[DIM_1:%.*]] = memref.load [[SHAPE]]{{\[}}[[C1_]]] : memref<3xi32>
@@ -108,6 +106,6 @@ func @memref_reshape(%input: memref<*xf32>,
// CHECK: [[SIZE_0:%.*]] = index_cast [[DIM_0]] : i32 to index
// CHECK: [[RESULT:%.*]] = memref.reinterpret_cast [[SRC]]
-// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], [[SIZE_2]]],
+// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], 8],
// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[STRIDE_1]], [[C1]]]
-// CHECK-SAME: : memref<*xf32> to memref<?x?x?xf32>
+// CHECK-SAME: : memref<*xf32> to memref<?x?x8xf32>
More information about the Mlir-commits
mailing list