[Mlir-commits] [mlir] cc145f4 - [mlir][vector] Disable Gather1DToConditionalLoads for scalable vectors (#96049)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 20 00:07:45 PDT 2024


Author: Cullen Rhodes
Date: 2024-06-20T08:07:43+01:00
New Revision: cc145f40530667d65220536a3e03eabe9fdd46cf

URL: https://github.com/llvm/llvm-project/commit/cc145f40530667d65220536a3e03eabe9fdd46cf
DIFF: https://github.com/llvm/llvm-project/commit/cc145f40530667d65220536a3e03eabe9fdd46cf.diff

LOG: [mlir][vector] Disable Gather1DToConditionalLoads for scalable vectors (#96049)

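This pattern scalarizes vector.gather operations, which is incorrect for
scalable vectors: their element count is not a compile-time constant, so
the gather cannot be unrolled into a fixed number of conditional loads.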

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
    mlir/test/Dialect/Vector/vector-gather-lowering.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 90128126d0fa1..dd027d107d16a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -189,6 +189,9 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
     if (resultTy.getRank() != 1)
       return rewriter.notifyMatchFailure(op, "unsupported rank");
 
+    if (resultTy.isScalable())
+      return rewriter.notifyMatchFailure(op, "not a fixed-width vector");
+
     Location loc = op.getLoc();
     Type elemTy = resultTy.getElementType();
     // Vector type with a single element. Used to generate `vector.loads`.

diff  --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index d047ac629d87e..c2eb88afa4dbf 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -206,3 +206,13 @@ func.func @strided_gather(%base : memref<100x3xf32>,
 // CHECK:           scf.if %[[MASK_3]] -> (vector<4xf32>)
 // CHECK:             %[[M_3:.*]] = vector.load %[[COLLAPSED]][%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
 // CHECK:             %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32>
+
+// CHECK-LABEL: @scalable_gather_1d
+// CHECK-NOT: extract
+// CHECK: vector.gather
+// CHECK-NOT: extract
+func.func @scalable_gather_1d(%base: tensor<?xf32>, %v: vector<[2]xindex>, %mask: vector<[2]xi1>, %pass_thru: vector<[2]xf32>) -> vector<[2]xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<[2]xindex>, vector<[2]xi1>, vector<[2]xf32> into vector<[2]xf32>
+  return %0 : vector<[2]xf32>
+}


        


More information about the Mlir-commits mailing list