[Mlir-commits] [mlir] cd93935 - [mlir][MemRef] Make sure types match when folding dim(reshape)
Benjamin Kramer
llvmlistbot at llvm.org
Tue Jun 15 03:50:21 PDT 2021
Author: Benjamin Kramer
Date: 2021-06-15T12:33:44+02:00
New Revision: cd939351467643a80490d036408b1036d39b9814
URL: https://github.com/llvm/llvm-project/commit/cd939351467643a80490d036408b1036d39b9814
DIFF: https://github.com/llvm/llvm-project/commit/cd939351467643a80490d036408b1036d39b9814.diff
LOG: [mlir][MemRef] Make sure types match when folding dim(reshape)
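memref.reshape accepts a shape operand whose element type is either an
integer type or index, but memref.dim always returns index. When folding
dim(reshape) into a load from the shape operand, insert an index_cast if
the loaded type differs from the dim result type.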
Differential Revision: https://reviews.llvm.org/D104287
Added:
Modified:
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/Standard/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index fe1a8e94b7c48..b9f4dc91634bc 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -770,8 +770,11 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
// Place the load directly after the reshape to ensure that the shape memref
// was not mutated.
rewriter.setInsertionPointAfter(reshape);
- rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
- llvm::makeArrayRef({dim.index()}));
+ Location loc = dim.getLoc();
+ Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
+ if (load.getType() != dim.getType())
+ load = rewriter.create<IndexCastOp>(loc, dim.getType(), load);
+ rewriter.replaceOp(dim, load);
return success();
}
};
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index d3e48d4c7edd2..24db1d295ffc3 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -122,6 +122,26 @@ func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
// -----
+// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> index_cast(memref.load %shp[%idx]) when %shp has i32 elements
+// CHECK-LABEL: func @dim_of_memref_reshape_i32(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
+// CHECK-NEXT: %[[IDX:.*]] = constant 3
+// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+// CHECK-NEXT: %[[CAST:.*]] = index_cast %[[DIM]]
+// CHECK-NOT: memref.dim
+// CHECK: return %[[CAST]] : index
+func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
+ -> index {
+ %c3 = constant 3 : index
+ %0 = memref.reshape %arg0(%arg1)
+ : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
+ %1 = memref.dim %0, %c3 : memref<*xf32>
+ return %1 : index
+}
+
+// -----
+
// Test case: Folding memref.dim(tensor.cast %0, %idx) -> memref.dim %0, %idx
// CHECK-LABEL: func @fold_dim_of_tensor.cast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
More information about the Mlir-commits mailing list