[Mlir-commits] [mlir] [mlir][linalg] Add test for masked vector.gather lowering in tensor.extract (PR #76298)

Prathamesh Tagore llvmlistbot at llvm.org
Sat Dec 23 10:41:17 PST 2023


https://github.com/meshtag created https://github.com/llvm/llvm-project/pull/76298

This PR adds a test covering the case where a masked `vector.gather` is [generated](https://github.com/llvm/llvm-project/blob/acacec3bbf4586ef9bc6c4f31707d3515d5215a1/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp#L1026) during the vectorization of `tensor.extract`. This code path previously had no test coverage.


>From 6b8c3e3af5b43a66ced52e9c57a29f4850982774 Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Sat, 23 Dec 2023 18:35:14 +0000
Subject: [PATCH] Add test

---
 .../Linalg/vectorize-tensor-extract.mlir      | 49 +++++++++++++++++++
 1 file changed, 49 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index 3fd4fcd536624c..f1ac557a684b21 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -550,3 +550,52 @@ module attributes {transform.with_named_sequence} {
      transform.yield
    }
 }
+
+// -----
+// Masked vector.gather from tensor.extract (dynamic outs, scalable size [4]).
+func.func @vectorize_nd_tensor_extract_masked_gather(%arg0: tensor<?xf32>,
+                                                     %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>],
+                   iterator_types = ["parallel"] }
+    outs(%arg1 : tensor<?xf32>) {
+    ^bb(%out: f32) :
+      %1 = tensor.extract %arg0[%c0] : tensor<?xf32>
+      linalg.yield %1 : f32
+    } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_masked_gather(
+//  CHECK-SAME:              %[[ARG0:.+]]: tensor<?xf32>, %[[ARG1:.+]]: tensor<?xf32>) -> tensor<?xf32> {
+//       CHECK:     %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:     %[[C0_0:.+]] = arith.constant 0 : index
+//       CHECK:     %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0_0]] : tensor<?xf32>
+//       CHECK:     %[[C0_1:.+]] = arith.constant 0 : index
+//       CHECK:     %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+//       CHECK:     %[[VAL0:.+]] = vector.create_mask %[[DIM]] : vector<[4]xi1>
+//       CHECK:     %[[VAL1:.+]] = vector.mask %[[VAL0]]
+//  CHECK-SAME:       { vector.transfer_read %[[ARG1]][%[[C0_1]]], %[[CST]] {in_bounds = [true]} :
+//  CHECK-SAME:       tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
+//       CHECK:     %[[CST_2:.+]] = arith.constant dense<true> : vector<[4]xi1>
+//       CHECK:     %[[CST_3:.+]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
+//       CHECK:     %[[C0_4:.+]] = arith.constant 0 : index
+//       CHECK:     %[[CST_5:.+]] = arith.constant dense<0> : vector<[4]xindex>
+//       CHECK:     %[[VAL2:.+]] = vector.mask %[[VAL0]] { vector.gather %[[ARG0]][%[[C0_4]]]
+//  CHECK-SAME:       [%[[CST_5]]], %[[CST_2]], %[[CST_3]] : tensor<?xf32>, vector<[4]xindex>,
+//  CHECK-SAME:       vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32> } :
+//  CHECK-SAME:       vector<[4]xi1> -> vector<[4]xf32>
+//       CHECK:     %[[C0_6:.+]] = arith.constant 0 : index
+//       CHECK:     %[[VAL3:.+]] = vector.mask %[[VAL0]] { vector.transfer_write %[[VAL2]],
+//  CHECK-SAME:       %[[ARG1]][%[[C0_6]]] {in_bounds = [true]} : vector<[4]xf32>, tensor<?xf32> } :
+//  CHECK-SAME:       vector<[4]xi1> -> tensor<?xf32>
+//       CHECK:     return %[[VAL3]] : tensor<?xf32>
+//       CHECK:   }
+// Match the linalg.generic and vectorize it with scalable vector size [4].
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [[4]] : !transform.any_op
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list