[Mlir-commits] [mlir] cd93935 - [mlir][MemRef] Make sure types match when folding dim(reshape)

Benjamin Kramer llvmlistbot at llvm.org
Tue Jun 15 03:50:21 PDT 2021


Author: Benjamin Kramer
Date: 2021-06-15T12:33:44+02:00
New Revision: cd939351467643a80490d036408b1036d39b9814

URL: https://github.com/llvm/llvm-project/commit/cd939351467643a80490d036408b1036d39b9814
DIFF: https://github.com/llvm/llvm-project/commit/cd939351467643a80490d036408b1036d39b9814.diff

LOG: [mlir][MemRef] Make sure types match when folding dim(reshape)

The shape operand of memref.reshape can have integer element types in
addition to index, but memref.dim always returns index.

Differential Revision: https://reviews.llvm.org/D104287

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index fe1a8e94b7c48..b9f4dc91634bc 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -770,8 +770,11 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
     // Place the load directly after the reshape to ensure that the shape memref
     // was not mutated.
     rewriter.setInsertionPointAfter(reshape);
-    rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
-                                        llvm::makeArrayRef({dim.index()}));
+    Location loc = dim.getLoc();
+    Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
+    if (load.getType() != dim.getType())
+      load = rewriter.create<IndexCastOp>(loc, dim.getType(), load);
+    rewriter.replaceOp(dim, load);
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index d3e48d4c7edd2..24db1d295ffc3 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -122,6 +122,26 @@ func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
 
 // -----
 
+// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_i32(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xi32>
+//  CHECK-NEXT:   %[[IDX:.*]] = constant 3
+//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+//  CHECK-NEXT:   %[[CAST:.*]] = index_cast %[[DIM]]
+//   CHECK-NOT:   memref.dim
+//       CHECK:   return %[[CAST]] : index
+func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
+    -> index {
+  %c3 = constant 3 : index
+  %0 = memref.reshape %arg0(%arg1)
+      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
+  %1 = memref.dim %0, %c3 : memref<*xf32>
+  return %1 : index
+}
+
+// -----
+
 // Test case: Folding memref.dim(tensor.cast %0, %idx) -> memref.dim %0, %idx
 // CHECK-LABEL: func @fold_dim_of_tensor.cast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
