[Mlir-commits] [mlir] Bug fix in folder of MLIR op 'tensor.extract' (PR #75109)
llvmlistbot at llvm.org
Mon Dec 11 14:35:44 PST 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Rafael Ubal (rafaelubalmw)
<details>
<summary>Changes</summary>
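The folder for `tensor.extract` does not operate correctly when it consumes the result of a `tensor.from_elements` operation. When linearizing the indices, it scales each stride by the size of the current dimension instead of the sizes of the dimensions inside it, so the flattened index is wrong whenever adjacent dimension sizes differ.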
The existing unit test named `@extract_from_tensor.from_elements_3d` in `mlir/test/Dialect/Tensor/canonicalize.mlir` appears to be an attempt to exercise this code. However, that test builds its `tensor.from_elements` op exclusively from constants, so the op is folded away into a single constant tensor. As a result, the buggy code was never executed by the unit tests.
I have added a new unit test named `@extract_from_tensor.from_elements_variable_3d` whose `tensor.from_elements` operands come directly from function arguments, so the op cannot be folded away. The original folder code would have made this test fail.
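To make the stride problem concrete, here is a minimal standalone sketch of the two loops, written outside MLIR. The helper names `flatIndexBuggy` and `flatIndexFixed` are hypothetical and exist only for illustration. The loop bodies mirror the removed and added lines of the diff, with `shape[i]` standing in for `tensorType.getDimSize(i)`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of MLIR): mirrors the loop before the fix.
// The stride for dimension i is multiplied by the size of dimension i itself,
// rather than by the sizes of the dimensions inside it.
int64_t flatIndexBuggy(const std::vector<int64_t> &shape,
                       const std::vector<int64_t> &indices) {
  assert(shape.size() == indices.size() && "rank mismatch");
  const int rank = static_cast<int>(shape.size());
  int64_t flatIndex = 0;
  int64_t stride = 1;
  for (int i = rank - 1; i >= 0; --i) {
    if (i < rank - 1)
      stride *= shape[i];
    flatIndex += indices[i] * stride;
  }
  return flatIndex;
}

// Hypothetical helper (not part of MLIR): mirrors the loop after the fix.
// The stride is updated after use, so dimension i is scaled by the product
// of all inner dimension sizes (row-major linearization).
int64_t flatIndexFixed(const std::vector<int64_t> &shape,
                       const std::vector<int64_t> &indices) {
  assert(shape.size() == indices.size() && "rank mismatch");
  const int rank = static_cast<int>(shape.size());
  int64_t flatIndex = 0;
  int64_t stride = 1;
  for (int i = rank - 1; i >= 0; --i) {
    flatIndex += indices[i] * stride;
    stride *= shape[i];
  }
  return flatIndex;
}
```

For the `3x2x2` shape used in the new test, index `[1, 0, 0]` should map to flat position 4 (`1 * 2 * 2`). The pre-fix loop scales the outermost index by 6 instead, and for `[2, 0, 0]` it yields 12, which is past the end of a 12-element tensor.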
---
Full diff: https://github.com/llvm/llvm-project/pull/75109.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+1-2)
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+44)
``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ec4c41c0000a9c..a257e5f4d9dc22 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1116,9 +1116,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
     int flatIndex = 0;
     int stride = 1;
     for (int i = rank - 1; i >= 0; --i) {
-      if (i < rank - 1)
-        stride *= tensorType.getDimSize(i);
       flatIndex += indices[i] * stride;
+      stride *= tensorType.getDimSize(i);
     }
     // Prevent out of bounds accesses. This can happen in invalid code that
     // will never execute.
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 7d7d221c1e8e96..8542fc9567412b 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -242,6 +242,50 @@ func.func @extract_from_tensor.from_elements_3d()
// -----
+// CHECK-LABEL: func @extract_from_tensor.from_elements_variable_3d
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_2:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_3:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_4:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_5:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_6:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_7:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_8:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_9:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_10:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_11:[a-zA-Z0-9_]+]]: f32
+func.func @extract_from_tensor.from_elements_variable_3d(
+ %f0: f32, %f1: f32, %f2: f32, %f3: f32, %f4: f32, %f5: f32,
+ %f6: f32, %f7: f32, %f8: f32, %f9: f32, %f10: f32, %f11: f32)
+ -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
+
+ %tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
+ : tensor<3x2x2xf32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+
+ %r0 = tensor.extract %tensor[%c0, %c0, %c0] : tensor<3x2x2xf32>
+ %r1 = tensor.extract %tensor[%c0, %c0, %c1] : tensor<3x2x2xf32>
+ %r2 = tensor.extract %tensor[%c0, %c1, %c0] : tensor<3x2x2xf32>
+ %r3 = tensor.extract %tensor[%c0, %c1, %c1] : tensor<3x2x2xf32>
+ %r4 = tensor.extract %tensor[%c1, %c0, %c0] : tensor<3x2x2xf32>
+ %r5 = tensor.extract %tensor[%c1, %c0, %c1] : tensor<3x2x2xf32>
+ %r6 = tensor.extract %tensor[%c1, %c1, %c0] : tensor<3x2x2xf32>
+ %r7 = tensor.extract %tensor[%c1, %c1, %c1] : tensor<3x2x2xf32>
+ %r8 = tensor.extract %tensor[%c2, %c0, %c0] : tensor<3x2x2xf32>
+ %r9 = tensor.extract %tensor[%c2, %c0, %c1] : tensor<3x2x2xf32>
+ %r10 = tensor.extract %tensor[%c2, %c1, %c0] : tensor<3x2x2xf32>
+ %r11 = tensor.extract %tensor[%c2, %c1, %c1] : tensor<3x2x2xf32>
+ return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11
+ : f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
+}
+// CHECK: return %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]],
+// CHECK-SAME: %[[ARG_6]], %[[ARG_7]], %[[ARG_8]], %[[ARG_9]], %[[ARG_10]], %[[ARG_11]]
+
+// -----
+
// CHECK-LABEL: func.func @extract_from_elements_complex_i() -> tensor<3xcomplex<i32>> {
// CHECK-NEXT: %cst = arith.constant dense<[(1,2), (3,2), (1,2)]> : tensor<3xcomplex<i32>>
// CHECK-NEXT: return %cst : tensor<3xcomplex<i32>>
``````````
</details>
https://github.com/llvm/llvm-project/pull/75109