[Mlir-commits] [mlir] a8f3860 - [mlir][tensor] Fix bug in `tensor.extract(tensor.from_elements)` folder (#75109)
llvmlistbot at llvm.org
Tue Dec 12 07:36:57 PST 2023
Author: Rafael Ubal
Date: 2023-12-12T15:36:52Z
New Revision: a8f3860bcb644d724275f51e7db7b291c7a3f4df
URL: https://github.com/llvm/llvm-project/commit/a8f3860bcb644d724275f51e7db7b291c7a3f4df
DIFF: https://github.com/llvm/llvm-project/commit/a8f3860bcb644d724275f51e7db7b291c7a3f4df.diff
LOG: [mlir][tensor] Fix bug in `tensor.extract(tensor.from_elements)` folder (#75109)
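The folder for `tensor.extract` does not fold correctly when its operand is
the result of a `tensor.from_elements` operation. When linearizing the
extraction indices in row-major order, the loop multiplied the running
stride by the size of the current dimension before applying it instead of
after, so outer dimensions were scaled by the wrong sizes.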
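A minimal standalone sketch of the index linearization loop in `ExtractOp::fold`, assuming plain `std::vector<int64_t>` shapes and indices instead of MLIR types. `flatIndexBuggy` and `flatIndexFixed` are hypothetical names used only for illustration. For `tensor<3x2x2xf32>`, the pre-fix loop ends up with strides 6, 2, 1 rather than the row-major 4, 2, 1.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical standalone versions of the linearization loop in
// ExtractOp::fold, using plain vectors instead of MLIR types. Both attempt to
// compute the row-major position of `indices` in a tensor of shape `shape`.

// Pre-fix loop: the stride is scaled by shape[i] *before* it is applied, so
// dimension i uses the product of shape[i..rank-2] as its stride.
int64_t flatIndexBuggy(const std::vector<int64_t> &shape,
                       const std::vector<int64_t> &indices) {
  assert(shape.size() == indices.size() && "rank mismatch");
  const int64_t rank = static_cast<int64_t>(shape.size());
  int64_t flatIndex = 0;
  int64_t stride = 1;
  for (int64_t i = rank - 1; i >= 0; --i) {
    if (i < rank - 1)
      stride *= shape[i];
    flatIndex += indices[i] * stride;
  }
  return flatIndex;
}

// Fixed loop: apply the stride for dimension i, then scale it by shape[i],
// so dimension i uses the product of shape[i+1..rank-1] as its stride.
int64_t flatIndexFixed(const std::vector<int64_t> &shape,
                       const std::vector<int64_t> &indices) {
  assert(shape.size() == indices.size() && "rank mismatch");
  const int64_t rank = static_cast<int64_t>(shape.size());
  int64_t flatIndex = 0;
  int64_t stride = 1;
  for (int64_t i = rank - 1; i >= 0; --i) {
    flatIndex += indices[i] * stride;
    stride *= shape[i];
  }
  return flatIndex;
}
```

For example, `flatIndexBuggy({3, 2, 2}, {1, 0, 0})` returns 6 while `flatIndexFixed({3, 2, 2}, {1, 0, 0})` returns 4. Index [2, 1, 1] maps to 15 under the old loop, past the 12 elements, so it would hit the out-of-bounds check that follows the loop in the folder instead of folding.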
The existing unit test named `@extract_from_tensor.from_elements_3d` in
`mlir/test/Dialect/Tensor/canonicalize.mlir` appears intended to exercise
this code. However, that test creates a `tensor.from_elements` op
exclusively from constants, which is folded away into a single constant
tensor. Therefore, the buggy code was never executed in unit tests.
I have added a new unit test named
`@extract_from_tensor.from_elements_variable_3d` that makes sure the
`tensor.from_elements` op is not folded away by having its input
operands come directly from function arguments. The original folder code
would have made this test fail.
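The new test's CHECK lines expect each `%rN` to fold to `%fN`. The following sketch checks that expectation under stated assumptions: the shape 3x2x2 is hard-coded, and `foldedPositions` is a hypothetical helper that computes, for each of the twelve `tensor.extract` ops in program order, which `tensor.from_elements` operand a folder using the given strides would select. The strides {6, 2, 1} are what the old loop produces for this shape.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

using Strides3 = std::array<int64_t, 3>;

// Hypothetical helper: for the twelve tensor.extract ops of the new test, in
// program order (%r0 ... %r11), compute the position in the
// tensor.from_elements operand list (%f0 ... %f11) that a folder using
// `strides` would read from. The shape 3x2x2 is hard-coded.
std::vector<int64_t> foldedPositions(const Strides3 &strides) {
  std::vector<int64_t> positions;
  for (int64_t i = 0; i < 3; ++i)
    for (int64_t j = 0; j < 2; ++j)
      for (int64_t k = 0; k < 2; ++k)
        positions.push_back(i * strides[0] + j * strides[1] + k * strides[2]);
  return positions;
}

// Row-major strides for shape 3x2x2.
constexpr Strides3 kFixedStrides = {4, 2, 1};
// Strides the pre-fix loop produced for shape 3x2x2: {3 * 2, 2, 1}.
constexpr Strides3 kBuggyStrides = {6, 2, 1};

// Mirrors the test's CHECK lines: with correct strides, %rN folds to %fN.
void checkFixedMapping() {
  const std::vector<int64_t> positions = foldedPositions(kFixedStrides);
  assert(positions.size() == 12);
  for (int64_t n = 0; n < 12; ++n)
    assert(positions[n] == n);
}
```

Under `kBuggyStrides`, only `%r0` through `%r3` select the right operands; `%r4` through `%r7` select operands 6 through 9, and `%r8` through `%r11` compute positions 12 through 15, outside the operand list. This is consistent with the new test failing against the original folder.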
This bug notably affected the lowering of the `tosa.pad` op in the
`tosa-to-tensor` pass, whose generated code is likely to contain a
`tensor.from_elements` + `tensor.extract` op sequence.
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ec4c41c0000a9c..a257e5f4d9dc22 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1116,9 +1116,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
int flatIndex = 0;
int stride = 1;
for (int i = rank - 1; i >= 0; --i) {
- if (i < rank - 1)
- stride *= tensorType.getDimSize(i);
flatIndex += indices[i] * stride;
+ stride *= tensorType.getDimSize(i);
}
// Prevent out of bounds accesses. This can happen in invalid code that
// will never execute.
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 7d7d221c1e8e96..8542fc9567412b 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -242,6 +242,50 @@ func.func @extract_from_tensor.from_elements_3d()
// -----
+// CHECK-LABEL: func @extract_from_tensor.from_elements_variable_3d
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_2:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_3:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_4:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_5:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_6:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_7:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_8:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_9:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_10:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_11:[a-zA-Z0-9_]+]]: f32
+func.func @extract_from_tensor.from_elements_variable_3d(
+ %f0: f32, %f1: f32, %f2: f32, %f3: f32, %f4: f32, %f5: f32,
+ %f6: f32, %f7: f32, %f8: f32, %f9: f32, %f10: f32, %f11: f32)
+ -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
+
+ %tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
+ : tensor<3x2x2xf32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+
+ %r0 = tensor.extract %tensor[%c0, %c0, %c0] : tensor<3x2x2xf32>
+ %r1 = tensor.extract %tensor[%c0, %c0, %c1] : tensor<3x2x2xf32>
+ %r2 = tensor.extract %tensor[%c0, %c1, %c0] : tensor<3x2x2xf32>
+ %r3 = tensor.extract %tensor[%c0, %c1, %c1] : tensor<3x2x2xf32>
+ %r4 = tensor.extract %tensor[%c1, %c0, %c0] : tensor<3x2x2xf32>
+ %r5 = tensor.extract %tensor[%c1, %c0, %c1] : tensor<3x2x2xf32>
+ %r6 = tensor.extract %tensor[%c1, %c1, %c0] : tensor<3x2x2xf32>
+ %r7 = tensor.extract %tensor[%c1, %c1, %c1] : tensor<3x2x2xf32>
+ %r8 = tensor.extract %tensor[%c2, %c0, %c0] : tensor<3x2x2xf32>
+ %r9 = tensor.extract %tensor[%c2, %c0, %c1] : tensor<3x2x2xf32>
+ %r10 = tensor.extract %tensor[%c2, %c1, %c0] : tensor<3x2x2xf32>
+ %r11 = tensor.extract %tensor[%c2, %c1, %c1] : tensor<3x2x2xf32>
+ return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11
+ : f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
+}
+// CHECK: return %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]],
+// CHECK-SAME: %[[ARG_6]], %[[ARG_7]], %[[ARG_8]], %[[ARG_9]], %[[ARG_10]], %[[ARG_11]]
+
+// -----
+
// CHECK-LABEL: func.func @extract_from_elements_complex_i() -> tensor<3xcomplex<i32>> {
// CHECK-NEXT: %cst = arith.constant dense<[(1,2), (3,2), (1,2)]> : tensor<3xcomplex<i32>>
// CHECK-NEXT: return %cst : tensor<3xcomplex<i32>>