[Mlir-commits] [mlir] [mlir][linalg] Add test for masked vector.gather lowering in tensor.extract (PR #76298)
Prathamesh Tagore
llvmlistbot at llvm.org
Sat Dec 23 10:41:17 PST 2023
https://github.com/meshtag created https://github.com/llvm/llvm-project/pull/76298
This PR adds a test to cover the case when a masked `vector.gather` is [generated](https://github.com/llvm/llvm-project/blob/acacec3bbf4586ef9bc6c4f31707d3515d5215a1/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp#L1026) during the lowering of `tensor.extract`. This was previously not covered.
From 6b8c3e3af5b43a66ced52e9c57a29f4850982774 Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Sat, 23 Dec 2023 18:35:14 +0000
Subject: [PATCH] Add test
---
.../Linalg/vectorize-tensor-extract.mlir | 49 +++++++++++++++++++
1 file changed, 49 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index 3fd4fcd536624c..f1ac557a684b21 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -550,3 +550,52 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+func.func @vectorize_nd_tensor_extract_masked_gather(%arg0: tensor<?xf32>,
+ %arg1: tensor<?xf32>) -> tensor<?xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"] }
+ outs(%arg1 : tensor<?xf32>) {
+ ^bb(%out: f32) :
+ %1 = tensor.extract %arg0[%c0] : tensor<?xf32>
+ linalg.yield %1 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_masked_gather(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>, %[[ARG1:.+]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[C0_0:.+]] = arith.constant 0 : index
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0_0]] : tensor<?xf32>
+// CHECK: %[[C0_1:.+]] = arith.constant 0 : index
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL0:.+]] = vector.create_mask %[[DIM]] : vector<[4]xi1>
+// CHECK: %[[VAL1:.+]] = vector.mask %[[VAL0]]
+// CHECK-SAME: { vector.transfer_read %[[ARG1]][%[[C0_1]]], %[[CST]] {in_bounds = [true]} :
+// CHECK-SAME: tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
+// CHECK: %[[CST_2:.+]] = arith.constant dense<true> : vector<[4]xi1>
+// CHECK: %[[CST_3:.+]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
+// CHECK: %[[C0_4:.+]] = arith.constant 0 : index
+// CHECK: %[[CST_5:.+]] = arith.constant dense<0> : vector<[4]xindex>
+// CHECK: %[[VAL2:.+]] = vector.mask %[[VAL0]] { vector.gather %[[ARG0]][%[[C0_4]]]
+// CHECK-SAME: [%[[CST_5]]], %[[CST_2]], %[[CST_3]] : tensor<?xf32>, vector<[4]xindex>,
+// CHECK-SAME: vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32> } :
+// CHECK-SAME: vector<[4]xi1> -> vector<[4]xf32>
+// CHECK: %[[C0_6:.+]] = arith.constant 0 : index
+// CHECK: %[[VAL3:.+]] = vector.mask %[[VAL0]] { vector.transfer_write %[[VAL2]],
+// CHECK-SAME: %[[ARG1]][%[[C0_6]]] {in_bounds = [true]} : vector<[4]xf32>, tensor<?xf32> } :
+// CHECK-SAME: vector<[4]xi1> -> tensor<?xf32>
+// CHECK: return %[[VAL3]] : tensor<?xf32>
+// CHECK: }
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [[4]] : !transform.any_op
+ transform.yield
+ }
+}
More information about the Mlir-commits mailing list