[Mlir-commits] [mlir] 85ff270 - [mlir][std] Add DimOp folding for dim(tensor_load(m)) -> dim(m).
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Nov 4 05:10:25 PST 2020
Author: Nicolas Vasilache
Date: 2020-11-04T13:06:22Z
New Revision: 85ff2705cdea60e3cf8fc49af7588c78638ca04f
URL: https://github.com/llvm/llvm-project/commit/85ff2705cdea60e3cf8fc49af7588c78638ca04f
DIFF: https://github.com/llvm/llvm-project/commit/85ff2705cdea60e3cf8fc49af7588c78638ca04f.diff
LOG: [mlir][std] Add DimOp folding for dim(tensor_load(m)) -> dim(m).
Differential Revision: https://reviews.llvm.org/D90755
Added:
Modified:
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 666955c3c4c7..8c9a1bba3125 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1561,23 +1561,29 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
}
}
+ Operation *definingOp = memrefOrTensor().getDefiningOp();
+ // dim(tensor_load(memref)) -> dim(memref)
+ if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) {
+ setOperand(0, tensorLoadOp.memref());
+ return getResult();
+ }
+
// Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
auto memrefType = argTy.dyn_cast<MemRefType>();
if (!memrefType)
return {};
// The size at the given index is now known to be a dynamic size of a memref.
- auto *memref = memrefOrTensor().getDefiningOp();
unsigned unsignedIndex = index.getValue().getZExtValue();
- if (auto alloc = dyn_cast_or_null<AllocOp>(memref))
+ if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
return *(alloc.getDynamicSizes().begin() +
memrefType.getDynamicDimIndex(unsignedIndex));
- if (auto view = dyn_cast_or_null<ViewOp>(memref))
+ if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
return *(view.getDynamicSizes().begin() +
memrefType.getDynamicDimIndex(unsignedIndex));
- if (auto subview = dyn_cast_or_null<SubViewOp>(memref)) {
+ if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
assert(subview.isDynamicSize(unsignedIndex) &&
"Expected dynamic subview size");
return subview.getDynamicSize(unsignedIndex);
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index cd22014e0de0..1589dc1af90e 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -31,3 +31,16 @@ func @no_fold_tensor_to_memref_of_tensor_load(%arg0: memref<?xf32, 2>) -> memref
%1 = tensor_to_memref %0 : memref<?xf32, 7>
return %1 : memref<?xf32, 7>
}
+
+// Test case: Basic folding of dim(tensor_load(m)) -> dim(m).
+// CHECK-LABEL: func @dim_of_tensor_load(
+// CHECK-SAME: %[[MEMREF:[0-9a-z]*]]: memref<?xf32>
+// CHECK: %[[C0:.*]] = constant 0
+// CHECK: %[[D:.*]] = dim %[[MEMREF]], %[[C0]]
+// CHECK: return %[[D]] : index
+func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
+ %c0 = constant 0 : index
+ %0 = tensor_load %arg0 : memref<?xf32>
+ %1 = dim %0, %c0 : tensor<?xf32>
+ return %1 : index
+}
More information about the Mlir-commits mailing list