[Mlir-commits] [mlir] [mlir][linalg] Add test for masked vector.gather lowering in tensor.extract (PR #76298)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Dec 23 10:41:47 PST 2023


llvmbot wrote:


@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Prathamesh Tagore (meshtag)

<details>
<summary>Changes</summary>

This PR adds a test covering the case where vectorizing a `tensor.extract` inside a `linalg.generic` with a dynamically shaped output and a scalable vector size (`vector_sizes [[4]]`) [generates](https://github.com/llvm/llvm-project/blob/acacec3bbf4586ef9bc6c4f31707d3515d5215a1/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp#L1026) a masked `vector.gather`. This path was previously untested.
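
For readers less familiar with the ops involved, here is a minimal Python model of the per-lane semantics of a 1-D `vector.create_mask` and a masked `vector.gather`: active lanes load `base[offset + indices[i]]`, inactive lanes take the pass-through value and never touch `base`. This is an illustrative sketch, not MLIR code. The function names are hypothetical, and it simplifies by folding the outer `vector.mask` and the gather's own mask operand into a single mask (in this test the inner mask is all-true, so the outer mask is what matters).

```python
from typing import List, Sequence


def create_mask(num_active: int, length: int) -> List[bool]:
    """Model of a 1-D vector.create_mask: the first `num_active` lanes are
    set and the rest are cleared (values outside [0, length] clamp)."""
    return [i < num_active for i in range(length)]


def masked_gather(base: Sequence[float], offset: int, indices: Sequence[int],
                  mask: Sequence[bool], passthru: Sequence[float]) -> List[float]:
    """Model of a 1-D masked vector.gather.

    Active lanes load base[offset + indices[i]]; inactive lanes take
    passthru[i]. Because `base` is only indexed for active lanes, an
    out-of-range address in a masked-off lane is never dereferenced.
    """
    return [base[offset + idx] if m else pt
            for idx, m, pt in zip(indices, mask, passthru)]
```

With 8 lanes (for example, 4 × vscale with vscale = 2) and only 3 active lanes, `masked_gather([7.0, 8.0, 9.0], 0, [0] * 8, create_mask(3, 8), [0.0] * 8)` yields three copies of `7.0` followed by five pass-through zeros. The mask is what keeps the all-zero index vector in the test safe: when the output is empty, no lane reads `%arg0` at all.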


---
Full diff: https://github.com/llvm/llvm-project/pull/76298.diff


1 File Affected:

- (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+49) 


``````````diff
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index 3fd4fcd536624c..f1ac557a684b21 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -550,3 +550,52 @@ module attributes {transform.with_named_sequence} {
      transform.yield
    }
 }
+
+// -----
+
+func.func @vectorize_nd_tensor_extract_masked_gather(%arg0: tensor<?xf32>, 
+                                                     %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>],
+                   iterator_types = ["parallel"] }
+    outs(%arg1 : tensor<?xf32>) {
+    ^bb(%out: f32) :
+      %1 = tensor.extract %arg0[%c0] : tensor<?xf32>
+      linalg.yield %1 : f32
+    } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_masked_gather(
+//  CHECK-SAME:              %[[ARG0:.+]]: tensor<?xf32>, %[[ARG1:.+]]: tensor<?xf32>) -> tensor<?xf32> {
+//       CHECK:     %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:     %[[C0_0:.+]] = arith.constant 0 : index
+//       CHECK:     %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0_0]] : tensor<?xf32>
+//       CHECK:     %[[C0_1:.+]] = arith.constant 0 : index
+//       CHECK:     %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+//       CHECK:     %[[VAL0:.+]] = vector.create_mask %[[DIM]] : vector<[4]xi1>
+//       CHECK:     %[[VAL1:.+]] = vector.mask %[[VAL0]] 
+//  CHECK-SAME:       { vector.transfer_read %[[ARG1]][%[[C0_1]]], %[[CST]] {in_bounds = [true]} :
+//  CHECK-SAME:       tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
+//       CHECK:     %[[CST_2:.+]] = arith.constant dense<true> : vector<[4]xi1>
+//       CHECK:     %[[CST_3:.+]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
+//       CHECK:     %[[C0_4:.+]] = arith.constant 0 : index
+//       CHECK:     %[[CST_5:.+]] = arith.constant dense<0> : vector<[4]xindex>
+//       CHECK:     %[[VAL2:.+]] = vector.mask %[[VAL0]] { vector.gather %[[ARG0]][%[[C0_4]]] 
+//  CHECK-SAME:       [%[[CST_5]]], %[[CST_2]], %[[CST_3]] : tensor<?xf32>, vector<[4]xindex>, 
+//  CHECK-SAME:       vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32> } : 
+//  CHECK-SAME:       vector<[4]xi1> -> vector<[4]xf32>
+//       CHECK:     %[[C0_6:.+]] = arith.constant 0 : index
+//       CHECK:     %[[VAL3:.+]] = vector.mask %[[VAL0]] { vector.transfer_write %[[VAL2]], 
+//  CHECK-SAME:       %[[ARG1]][%[[C0_6]]] {in_bounds = [true]} : vector<[4]xf32>, tensor<?xf32> } : 
+//  CHECK-SAME:       vector<[4]xi1> -> tensor<?xf32>
+//       CHECK:     return %[[VAL3]] : tensor<?xf32>
+//       CHECK:   }
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [[4]] : !transform.any_op
+    transform.yield
+  }
+}

``````````
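
The CHECK lines above encode a single masked vector step: a mask built from `tensor.dim %arg1`, a masked `vector.gather` from `%arg0` with an all-zero index vector and zero pass-through, and a masked `vector.transfer_write` into `%arg1` (the masked `transfer_read` of `%arg1` is unused in this test). The Python sketch below replays that sequence and checks it against the scalar meaning of the original `linalg.generic`, which overwrites every output element with `arg0[0]`. The helper names and the runtime `vscale` parameter are assumptions for illustration. The model also assumes the output length does not exceed 4 × vscale, since `vector_sizes [[4]]` produces one scalable vector step rather than a loop.

```python
from typing import List, Sequence


def scalar_reference(arg0: Sequence[float], arg1: Sequence[float]) -> List[float]:
    """Scalar semantics of the test's linalg.generic: every output element
    becomes tensor.extract %arg0[%c0], i.e. arg0[0]."""
    return [arg0[0] for _ in arg1]


def vectorized_model(arg0: Sequence[float], arg1: Sequence[float],
                     vscale: int) -> List[float]:
    """Replays the CHECK-ed IR for vector<[4]xf32> (hypothetical model)."""
    vl = 4 * vscale                      # scalable vector length: 4 x vscale lanes
    dim = len(arg1)                      # tensor.dim %arg1, %c0
    assert dim <= vl, "a single vector step only covers dim <= 4 * vscale"
    mask = [i < dim for i in range(vl)]  # vector.create_mask %dim
    indices = [0] * vl                   # arith.constant dense<0> : vector<[4]xindex>
    passthru = [0.0] * vl                # arith.constant dense<0.0> : vector<[4]xf32>
    # vector.mask %mask { vector.gather %arg0[%c0][indices], all-true, passthru }
    gathered = [arg0[0 + idx] if m else pt
                for idx, m, pt in zip(indices, mask, passthru)]
    # vector.mask %mask { vector.transfer_write gathered, %arg1[%c0] }:
    # only active lanes are written; tensors are values, so build a copy.
    out = list(arg1)
    for i, m in enumerate(mask):
        if m:
            out[i] = gathered[i]
    return out
```

For example, `vectorized_model([5.0, 6.0], [0.0, 0.0, 0.0], vscale=1)` returns `[5.0, 5.0, 5.0]`, matching `scalar_reference` on the same inputs; with an empty output neither function reads `arg0`.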

</details>


https://github.com/llvm/llvm-project/pull/76298

