[Mlir-commits] [mlir] Bug fix in folder of MLIR op 'tensor.extract' (PR #75109)

llvmlistbot at llvm.org
Mon Dec 11 14:35:44 PST 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Rafael Ubal (rafaelubalmw)

<details>
<summary>Changes</summary>

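The folder for `tensor.extract` computes an incorrect linearized index when it consumes the result of a `tensor.from_elements` operation. The loop that accumulates the row-major stride multiplies it by the size of dimension `i` before applying it to `indices[i]`. Each index is therefore scaled by the product of the sizes of dimensions `i` through `rank - 2` instead of `i + 1` through `rank - 1`. Unless all dimensions have the same size, the folder can then return the wrong element, or fail to fold when the computed index is out of bounds.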
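The two versions of the stride computation can be compared in isolation. The following minimal C++ sketch reproduces the original and fixed loops from `ExtractOp::fold` over a plain `std::vector<int64_t>` shape. The helper names `stridesOld`, `stridesNew`, and `flatIndex` are hypothetical; the real folder reads dimension sizes via `tensorType.getDimSize(i)`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the original loop: the stride is multiplied
// by the size of dimension i *before* it is applied to index i.
std::vector<int64_t> stridesOld(const std::vector<int64_t> &shape) {
  int rank = static_cast<int>(shape.size());
  std::vector<int64_t> strides(shape.size());
  int64_t stride = 1;
  for (int i = rank - 1; i >= 0; --i) {
    if (i < rank - 1)
      stride *= shape[i];
    strides[i] = stride;
  }
  return strides;
}

// Hypothetical helper mirroring the fixed loop: index i is scaled by the
// product of the sizes of the inner dimensions i+1 .. rank-1 (row-major).
std::vector<int64_t> stridesNew(const std::vector<int64_t> &shape) {
  int rank = static_cast<int>(shape.size());
  std::vector<int64_t> strides(shape.size());
  int64_t stride = 1;
  for (int i = rank - 1; i >= 0; --i) {
    strides[i] = stride;
    stride *= shape[i];
  }
  return strides;
}

// Linearizes a multi-dimensional index as its dot product with the strides.
int64_t flatIndex(const std::vector<int64_t> &indices,
                  const std::vector<int64_t> &strides) {
  assert(indices.size() == strides.size() && "rank mismatch");
  int64_t flat = 0;
  for (std::size_t d = 0; d < indices.size(); ++d)
    flat += indices[d] * strides[d];
  return flat;
}
```

For `tensor<3x2x2xf32>`, `stridesOld({3, 2, 2})` yields `{6, 2, 1}` while `stridesNew({3, 2, 2})` yields the row-major `{4, 2, 1}`, so the index `[1, 0, 0]` linearizes to 6 instead of 4. For a shape whose dimensions are all equal, such as `{4, 4}`, both loops agree, which is why the bug only appears when dimension sizes differ.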

The existing unit test `@extract_from_tensor.from_elements_3d` in `mlir/test/Dialect/Tensor/canonicalize.mlir` appears intended to exercise this code. However, that test builds the `tensor.from_elements` op exclusively from constants, so the op is folded away into a single constant tensor and the `tensor.from_elements` branch of the folder is never reached. As a result, the buggy code was never executed by the unit tests.

I have added a new unit test, `@extract_from_tensor.from_elements_variable_3d`, which keeps the `tensor.from_elements` op from being folded away by taking its operands directly from function arguments. This test fails with the original folder code.
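The new test expects each `%rK` to fold to `%fK`: for `tensor<3x2x2xf32>` the row-major strides are 4, 2, and 1, so `[i, j, k]` maps to operand `4*i + 2*j + k`. The following C++ sketch models that lookup outside of MLIR and enumerates the test's 12 extracts in source order. The function names, the `useFixedLoop` switch, and the exact form of the out-of-bounds guard (rejecting any index outside `[0, numElements)`) are assumptions for illustration; the real guard follows the "Prevent out of bounds accesses" comment in `ExtractOp::fold`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model of the folder's lookup into the operands of
// tensor.from_elements. Returns the operand position the extract would fold
// to, or std::nullopt when the computed index is out of bounds (assumed to
// mirror the bounds guard in ExtractOp::fold).
std::optional<int64_t> foldToOperand(const std::vector<int64_t> &indices,
                                     const std::vector<int64_t> &shape,
                                     bool useFixedLoop) {
  assert(indices.size() == shape.size() && "rank mismatch");
  int64_t numElements = 1;
  for (int64_t d : shape)
    numElements *= d;

  int rank = static_cast<int>(shape.size());
  int64_t flatIndex = 0;
  int64_t stride = 1;
  for (int i = rank - 1; i >= 0; --i) {
    if (useFixedLoop) {
      flatIndex += indices[i] * stride; // fixed: use the stride, then grow it
      stride *= shape[i];
    } else {
      if (i < rank - 1)
        stride *= shape[i]; // original: grows with dim i before using it
      flatIndex += indices[i] * stride;
    }
  }
  if (flatIndex < 0 || flatIndex >= numElements)
    return std::nullopt;
  return flatIndex;
}

// Enumerates the 12 extracts of the new test in source order
// (%r0 ... %r11 on tensor<3x2x2xf32>) and records where each would fold.
std::vector<std::optional<int64_t>> foldTestExtracts(bool useFixedLoop) {
  const std::vector<int64_t> shape = {3, 2, 2};
  std::vector<std::optional<int64_t>> results;
  for (int64_t i = 0; i < 3; ++i)
    for (int64_t j = 0; j < 2; ++j)
      for (int64_t k = 0; k < 2; ++k)
        results.push_back(foldToOperand({i, j, k}, shape, useFixedLoop));
  return results;
}
```

With `useFixedLoop = true` every entry equals its position, matching the final `CHECK` lines. With the original loop, `%r0` through `%r3` still map correctly, `%r4` through `%r7` map to operands 6 through 9 (`%f6` through `%f9`), and `%r8` through `%r11` produce indices 12 through 15, which fall outside the 12 elements and are left unfolded, so the test fails.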

---
Full diff: https://github.com/llvm/llvm-project/pull/75109.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+1-2) 
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+44) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ec4c41c0000a9c..a257e5f4d9dc22 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1116,9 +1116,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
     int flatIndex = 0;
     int stride = 1;
     for (int i = rank - 1; i >= 0; --i) {
-      if (i < rank - 1)
-        stride *= tensorType.getDimSize(i);
       flatIndex += indices[i] * stride;
+      stride *= tensorType.getDimSize(i);
     }
     // Prevent out of bounds accesses. This can happen in invalid code that
     // will never execute.
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 7d7d221c1e8e96..8542fc9567412b 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -242,6 +242,50 @@ func.func @extract_from_tensor.from_elements_3d()
 
 // -----
 
+// CHECK-LABEL: func @extract_from_tensor.from_elements_variable_3d
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_2:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_3:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_4:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_5:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_6:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_7:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_8:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_9:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_10:[a-zA-Z0-9_]+]]: f32
+// CHECK-SAME: %[[ARG_11:[a-zA-Z0-9_]+]]: f32
+func.func @extract_from_tensor.from_elements_variable_3d(
+    %f0: f32, %f1: f32, %f2: f32, %f3: f32, %f4: f32, %f5: f32,
+    %f6: f32, %f7: f32, %f8: f32, %f9: f32, %f10: f32, %f11: f32)
+    -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
+
+  %tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
+         : tensor<3x2x2xf32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+
+  %r0 = tensor.extract %tensor[%c0, %c0, %c0] : tensor<3x2x2xf32>
+  %r1 = tensor.extract %tensor[%c0, %c0, %c1] : tensor<3x2x2xf32>
+  %r2 = tensor.extract %tensor[%c0, %c1, %c0] : tensor<3x2x2xf32>
+  %r3 = tensor.extract %tensor[%c0, %c1, %c1] : tensor<3x2x2xf32>
+  %r4 = tensor.extract %tensor[%c1, %c0, %c0] : tensor<3x2x2xf32>
+  %r5 = tensor.extract %tensor[%c1, %c0, %c1] : tensor<3x2x2xf32>
+  %r6 = tensor.extract %tensor[%c1, %c1, %c0] : tensor<3x2x2xf32>
+  %r7 = tensor.extract %tensor[%c1, %c1, %c1] : tensor<3x2x2xf32>
+  %r8 = tensor.extract %tensor[%c2, %c0, %c0] : tensor<3x2x2xf32>
+  %r9 = tensor.extract %tensor[%c2, %c0, %c1] : tensor<3x2x2xf32>
+  %r10 = tensor.extract %tensor[%c2, %c1, %c0] : tensor<3x2x2xf32>
+  %r11 = tensor.extract %tensor[%c2, %c1, %c1] : tensor<3x2x2xf32>
+  return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11
+         : f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
+}
+// CHECK: return %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]],
+// CHECK-SAME: %[[ARG_6]], %[[ARG_7]], %[[ARG_8]], %[[ARG_9]], %[[ARG_10]], %[[ARG_11]]
+
+// -----
+
 // CHECK-LABEL: func.func @extract_from_elements_complex_i() -> tensor<3xcomplex<i32>> {
 // CHECK-NEXT:  %cst = arith.constant dense<[(1,2), (3,2), (1,2)]> : tensor<3xcomplex<i32>>
 // CHECK-NEXT:  return %cst : tensor<3xcomplex<i32>>

``````````

</details>


https://github.com/llvm/llvm-project/pull/75109

