[Mlir-commits] [mlir] 3598605 - [mlir][std] Fold dim(dynamic_tensor_from_elements, %cst)
Stephan Herhut
llvmlistbot at llvm.org
Tue Nov 17 05:40:14 PST 2020
Author: Stephan Herhut
Date: 2020-11-17T14:39:59+01:00
New Revision: 3598605c0b3658dbb6cac634cb92a0a131f2fe0b
URL: https://github.com/llvm/llvm-project/commit/3598605c0b3658dbb6cac634cb92a0a131f2fe0b
DIFF: https://github.com/llvm/llvm-project/commit/3598605c0b3658dbb6cac634cb92a0a131f2fe0b.diff
LOG: [mlir][std] Fold dim(dynamic_tensor_from_elements, %cst)
The shape of the result of a dynamic_tensor_from_elements op is determined by its
result type and its operands. We already fold dim operations when they reference
one of the statically sized dimensions. Now, also fold dim on the dynamically
sized dimensions by returning the dynamic-extent operand that corresponds to the
queried dimension.
Differential Revision: https://reviews.llvm.org/D91616
Added:
Modified:
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 629c8e4e6a5f..469889321805 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1484,6 +1484,25 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
return getResult();
}
+ // Fold dim to the operand of dynamic_tensor_from_elements.
+ if (auto fromElements =
+ dyn_cast_or_null<DynamicTensorFromElementsOp>(definingOp)) {
+ auto resultType =
+ fromElements.getResult().getType().cast<RankedTensorType>();
+ // The case where the type encodes the size of the dimension is handled
+ // above.
+ assert(resultType.getShape()[index.getInt()] ==
+ RankedTensorType::kDynamicSize);
+
+ // Find the operand of the fromElements that corresponds to this index.
+ auto dynExtents = fromElements.dynamicExtents().begin();
+ for (auto dim : resultType.getShape().take_front(index.getInt()))
+ if (dim == RankedTensorType::kDynamicSize)
+ dynExtents++;
+
+ return Value{*dynExtents};
+ }
+
// Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
auto memrefType = argTy.dyn_cast<MemRefType>();
if (!memrefType)
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 1589dc1af90e..1e2e4a5bf116 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -5,9 +5,9 @@
// CHECK-SAME: %[[TENSOR:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: return %[[TENSOR]]
func @tensor_load_of_tensor_to_memref(%arg0: tensor<?xf32>) -> tensor<?xf32> {
- %0 = tensor_to_memref %arg0 : memref<?xf32>
- %1 = tensor_load %0 : memref<?xf32>
- return %1 : tensor<?xf32>
+ %0 = tensor_to_memref %arg0 : memref<?xf32>
+ %1 = tensor_load %0 : memref<?xf32>
+ return %1 : tensor<?xf32>
}
// Test case: Basic folding of tensor_to_memref(tensor_load(m)) -> m
@@ -15,9 +15,9 @@ func @tensor_load_of_tensor_to_memref(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK-SAME: %[[MEMREF:.*]]: memref<?xf32>) -> memref<?xf32> {
// CHECK: return %[[MEMREF]]
func @tensor_to_memref_of_tensor_load(%arg0: memref<?xf32>) -> memref<?xf32> {
- %0 = tensor_load %arg0 : memref<?xf32>
- %1 = tensor_to_memref %0 : memref<?xf32>
- return %1 : memref<?xf32>
+ %0 = tensor_load %arg0 : memref<?xf32>
+ %1 = tensor_to_memref %0 : memref<?xf32>
+ return %1 : memref<?xf32>
}
// Test case: If the memrefs are not the same type, don't fold them.
@@ -27,9 +27,9 @@ func @tensor_to_memref_of_tensor_load(%arg0: memref<?xf32>) -> memref<?xf32> {
// CHECK: %[[MEMREF_ADDRSPACE7:.*]] = tensor_to_memref %[[TENSOR]] : memref<?xf32, 7>
// CHECK: return %[[MEMREF_ADDRSPACE7]]
func @no_fold_tensor_to_memref_of_tensor_load(%arg0: memref<?xf32, 2>) -> memref<?xf32, 7> {
- %0 = tensor_load %arg0 : memref<?xf32, 2>
- %1 = tensor_to_memref %0 : memref<?xf32, 7>
- return %1 : memref<?xf32, 7>
+ %0 = tensor_load %arg0 : memref<?xf32, 2>
+ %1 = tensor_to_memref %0 : memref<?xf32, 7>
+ return %1 : memref<?xf32, 7>
}
// Test case: Basic folding of dim(tensor_load(m)) -> dim(m).
@@ -39,8 +39,23 @@ func @no_fold_tensor_to_memref_of_tensor_load(%arg0: memref<?xf32, 2>) -> memref
// CHECK: %[[D:.*]] = dim %[[MEMREF]], %[[C0]]
// CHECK: return %[[D]] : index
func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
- %c0 = constant 0 : index
- %0 = tensor_load %arg0 : memref<?xf32>
- %1 = dim %0, %c0 : tensor<?xf32>
- return %1 : index
+ %c0 = constant 0 : index
+ %0 = tensor_load %arg0 : memref<?xf32>
+ %1 = dim %0, %c0 : tensor<?xf32>
+ return %1 : index
+}
+
+// Test case: Folding of dim(dynamic_tensor_from_elements %idx) -> %idx
+// CHECK-LABEL: func @dim_of_dynamic_tensor_from_elements(
+// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
+// CHECK-NOT: dim
+// CHECK: return %[[IDX1]] : index
+func @dim_of_dynamic_tensor_from_elements(%arg0: index, %arg1: index) -> index {
+ %c3 = constant 3 : index
+ %0 = dynamic_tensor_from_elements %arg0, %arg1 {
+ ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
+ yield %c3 : index
+ } : tensor<2x?x4x?x5xindex>
+ %1 = dim %0, %c3 : tensor<2x?x4x?x5xindex>
+ return %1 : index
}
More information about the Mlir-commits
mailing list