[Mlir-commits] [mlir] 9cb9081 - [mlir][vector] Extend vector.gather e2e test (#187071)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 18 02:23:21 PDT 2026


Author: Andrzej Warzyński
Date: 2026-03-18T09:23:17Z
New Revision: 9cb9081049a4cc4b049175591d4bb60012f149ba

URL: https://github.com/llvm/llvm-project/commit/9cb9081049a4cc4b049175591d4bb60012f149ba
DIFF: https://github.com/llvm/llvm-project/commit/9cb9081049a4cc4b049175591d4bb60012f149ba.diff

LOG: [mlir][vector] Extend vector.gather e2e test (#187071)

Extend the vector.gather e2e test to cover both available lowering
paths:

* Direct lowering to LLVM (via -test-lower-to-llvm)
* Lowering via vector.load (via -test-vector-gather-lowering)

This is a follow-up to https://github.com/llvm/llvm-project/pull/184706,
which updated a pattern used by -test-vector-gather-lowering.

The test is extended to operate on 2D memrefs so that the changes
in https://github.com/llvm/llvm-project/pull/184706 are meaningfully
exercised.

Added: 
    

Modified: 
    mlir/test/Integration/Dialect/Vector/CPU/gather.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/gather.mlir b/mlir/test/Integration/Dialect/Vector/CPU/gather.mlir
index ab2e713e83a6c..daf641d03eb75 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/gather.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/gather.mlir
@@ -1,13 +1,28 @@
-// RUN: mlir-opt %s -test-lower-to-llvm  | \
-// RUN: mlir-runner -e entry -entry-point-result=void \
-// RUN:   -shared-libs=%mlir_c_runner_utils | \
-// RUN: FileCheck %s
+// DEFINE: %{entry_point} = main
+// DEFINE: %{run} = mlir-runner -e entry -entry-point-result=void \
+// DEFINE:         -shared-libs=%native_mlir_runner_utils,%native_mlir_c_runner_utils
 
-func.func @gather8(%base: memref<?xf32>, %indices: vector<8xi32>,
+/// TEST 1. Verify default compilation (direct lowering of `vector.gather` to LLVM)
+// DEFINE: %{compile} = mlir-opt %s -test-lower-to-llvm
+// RUN: %{compile} | %{run} | FileCheck %s
+
+/// TEST 2. Verify compilation via `test-vector-gather-lowering` (`vector.gather`
+/// lowered to LLVM via `vector.load`)
+// REDEFINE: %{compile} = mlir-opt %s --test-vector-gather-lowering | mlir-opt -test-lower-to-llvm
+// RUN: %{compile} | %{run} | FileCheck %s
+
+/// TEST 3. Verify that `test-vector-gather-lowering` will indeed produce
+/// `vector.load`
+// REDEFINE: %{compile} = mlir-opt %s --test-vector-gather-lowering
+// RUN: %{compile} | FileCheck %s -check-prefix CHECK-IR 
+
+func.func @gather8(%base: memref<?x?xf32>, %indices: vector<8xi32>,
               %mask: vector<8xi1>, %pass_thru: vector<8xf32>) -> vector<8xf32> {
   %c0 = arith.constant 0: index
-  %g = vector.gather %base[%c0][%indices], %mask, %pass_thru
-    : memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  /// Verify that the lowering via vector.load does indeed generate vector.load
+  // CHECK-IR-COUNT-4: vector.load
+  %g = vector.gather %base[%c0, %c0][%indices], %mask, %pass_thru
+    : memref<?x?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
   return %g : vector<8xf32>
 }
 
@@ -16,29 +31,36 @@ func.func @entry() {
   %c0 = arith.constant 0: index
   %c1 = arith.constant 1: index
   %c10 = arith.constant 10: index
-  %A = memref.alloc(%c10) : memref<?xf32>
+  %c5 = arith.constant 5: index
+  %A = memref.alloc(%c10, %c5) : memref<?x?xf32>
   scf.for %i = %c0 to %c10 step %c1 {
-    %i32 = arith.index_cast %i : index to i32
-    %fi = arith.sitofp %i32 : i32 to f32
-    memref.store %fi, %A[%i] : memref<?xf32>
+    scf.for %j = %c0 to %c5 step %c1 {
+      %off = arith.muli %i, %c10 : index
+      %val_index = arith.addi %j, %off : index
+      %val_i32 = arith.index_cast %val_index : index to i32
+      %val = arith.sitofp %val_i32 : i32 to f32
+      memref.store %val, %A[%i, %j] : memref<?x?xf32>
+    }
   }
+  %A_cast = memref.cast %A : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%A_cast) : (memref<*xf32>) -> ()
 
   // Set up idx vector.
   %i0 = arith.constant 0: i32
-  %i1 = arith.constant 1: i32
-  %i2 = arith.constant 2: i32
-  %i3 = arith.constant 3: i32
-  %i4 = arith.constant 4: i32
-  %i5 = arith.constant 5: i32
-  %i6 = arith.constant 6: i32
-  %i9 = arith.constant 9: i32
   %0 = vector.broadcast %i0 : i32 to vector<8xi32>
+  %i6 = arith.constant 16: i32
   %1 = vector.insert %i6, %0[1] : i32 into vector<8xi32>
+  %i1 = arith.constant 11: i32
   %2 = vector.insert %i1, %1[2] : i32 into vector<8xi32>
+  %i3 = arith.constant 33: i32
   %3 = vector.insert %i3, %2[3] : i32 into vector<8xi32>
+  %i5 = arith.constant 5: i32
   %4 = vector.insert %i5, %3[4] : i32 into vector<8xi32>
+  %i4 = arith.constant 44: i32
   %5 = vector.insert %i4, %4[5] : i32 into vector<8xi32>
+  %i9 = arith.constant 19: i32
   %6 = vector.insert %i9, %5[6] : i32 into vector<8xi32>
+  %i2 = arith.constant 22: i32
   %idx = vector.insert %i2, %6[7] : i32 into vector<8xi32>
 
   // Set up pass thru vector.
@@ -57,35 +79,30 @@ func.func @entry() {
   //
 
   %g1 = call @gather8(%A, %idx, %all, %pass)
-    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
+    : (memref<?x?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
     -> (vector<8xf32>)
   vector.print %g1 : vector<8xf32>
-  // CHECK: ( 0, 6, 1, 3, 5, 4, 9, 2 )
+  // CHECK: ( 0, 31, 21, 63, 10, 84, 34, 42 )
 
   %g2 = call @gather8(%A, %idx, %none, %pass)
-    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
+    : (memref<?x?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
     -> (vector<8xf32>)
   vector.print %g2 : vector<8xf32>
   // CHECK: ( -7, -7, -7, -7, -7, -7, -7, -7 )
 
   %g3 = call @gather8(%A, %idx, %some, %pass)
-    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
+    : (memref<?x?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
     -> (vector<8xf32>)
   vector.print %g3 : vector<8xf32>
-  // CHECK: ( 0, 6, 1, 3, -7, -7, -7, -7 )
+  // CHECK: ( 0, 31, 21, 63, -7, -7, -7, -7 )
 
   %g4 = call @gather8(%A, %idx, %more, %pass)
-    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
+    : (memref<?x?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
     -> (vector<8xf32>)
   vector.print %g4 : vector<8xf32>
-  // CHECK: ( 0, 6, 1, 3, -7, -7, -7, 2 )
-
-  %g5 = call @gather8(%A, %idx, %all, %pass)
-    : (memref<?xf32>, vector<8xi32>, vector<8xi1>, vector<8xf32>)
-    -> (vector<8xf32>)
-  vector.print %g5 : vector<8xf32>
-  // CHECK: ( 0, 6, 1, 3, 5, 4, 9, 2 )
+  // CHECK: ( 0, 31, 21, 63, -7, -7, -7, 42 )
 
-  memref.dealloc %A : memref<?xf32>
+  memref.dealloc %A : memref<?x?xf32>
   return
 }
+func.func private @printMemrefF32(%ptr : memref<*xf32>)


        


More information about the Mlir-commits mailing list