[Mlir-commits] [mlir] 85ff270 - [mlir][std] Add DimOp folding for dim(tensor_load(m)) -> dim(m).

Nicolas Vasilache llvmlistbot at llvm.org
Wed Nov 4 05:10:25 PST 2020


Author: Nicolas Vasilache
Date: 2020-11-04T13:06:22Z
New Revision: 85ff2705cdea60e3cf8fc49af7588c78638ca04f

URL: https://github.com/llvm/llvm-project/commit/85ff2705cdea60e3cf8fc49af7588c78638ca04f
DIFF: https://github.com/llvm/llvm-project/commit/85ff2705cdea60e3cf8fc49af7588c78638ca04f.diff

LOG: [mlir][std] Add DimOp folding for dim(tensor_load(m)) -> dim(m).

Differential Revision: https://reviews.llvm.org/D90755

Added: 
    

Modified: 
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 666955c3c4c7..8c9a1bba3125 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1561,23 +1561,29 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
     }
   }
 
+  Operation *definingOp = memrefOrTensor().getDefiningOp();
+  // dim(tensor_load(memref)) -> dim(memref)
+  if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) {
+    setOperand(0, tensorLoadOp.memref());
+    return getResult();
+  }
+
   // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
   auto memrefType = argTy.dyn_cast<MemRefType>();
   if (!memrefType)
     return {};
 
   // The size at the given index is now known to be a dynamic size of a memref.
-  auto *memref = memrefOrTensor().getDefiningOp();
   unsigned unsignedIndex = index.getValue().getZExtValue();
-  if (auto alloc = dyn_cast_or_null<AllocOp>(memref))
+  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
     return *(alloc.getDynamicSizes().begin() +
              memrefType.getDynamicDimIndex(unsignedIndex));
 
-  if (auto view = dyn_cast_or_null<ViewOp>(memref))
+  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
     return *(view.getDynamicSizes().begin() +
              memrefType.getDynamicDimIndex(unsignedIndex));
 
-  if (auto subview = dyn_cast_or_null<SubViewOp>(memref)) {
+  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
     assert(subview.isDynamicSize(unsignedIndex) &&
            "Expected dynamic subview size");
     return subview.getDynamicSize(unsignedIndex);

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index cd22014e0de0..1589dc1af90e 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -31,3 +31,16 @@ func @no_fold_tensor_to_memref_of_tensor_load(%arg0: memref<?xf32, 2>) -> memref
     %1 = tensor_to_memref %0 : memref<?xf32, 7>
     return %1 : memref<?xf32, 7>
 }
+
+// Test case: Basic folding of dim(tensor_load(m)) -> dim(m).
+// CHECK-LABEL: func @dim_of_tensor_load(
+//  CHECK-SAME:     %[[MEMREF:[0-9a-z]*]]: memref<?xf32>
+//       CHECK:   %[[C0:.*]] = constant 0
+//       CHECK:   %[[D:.*]] = dim %[[MEMREF]], %[[C0]]
+//       CHECK:   return %[[D]] : index
+func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
+    %c0 = constant 0 : index
+    %0 = tensor_load %arg0 : memref<?xf32>
+    %1 = dim %0, %c0 : tensor<?xf32>
+    return %1 : index
+}


        


More information about the Mlir-commits mailing list