[Mlir-commits] [mlir] 9931ee6 - [mlir][vector] Fix FlattenGather for scalable vectors (#96074)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 24 00:36:10 PDT 2024
Author: Cullen Rhodes
Date: 2024-06-24T08:36:06+01:00
New Revision: 9931ee61d99c101db653ae21706f1edce4b39781
URL: https://github.com/llvm/llvm-project/commit/9931ee61d99c101db653ae21706f1edce4b39781
DIFF: https://github.com/llvm/llvm-project/commit/9931ee61d99c101db653ae21706f1edce4b39781.diff
LOG: [mlir][vector] Fix FlattenGather for scalable vectors (#96074)
This pattern flattens vector.gather ops by unrolling the outermost
dimension for vectors of rank >= 2. There are two issues with this pattern for
scalable vectors:
1. The unrolling doesn't take vscale into account: it iterates over the
static (minimum) size of the leading dim only. A constraint is added to
disable this pattern for vectors whose leading dimension is scalable.
2. The scalable flags of the trailing dims are dropped when building the
type of the new, rank-reduced gather. Fixed by propagating the flags
(via `VectorType::Builder(resultTy).dropDim(0)`).
Depends on #96049.
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
mlir/test/Dialect/Vector/vector-gather-lowering.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index dd027d107d16a..a0df03c7e808b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -55,6 +55,8 @@ namespace {
/// ```
///
/// When applied exhaustively, this will produce a sequence of 1-d gather ops.
+///
+/// Supports vector types with a fixed leading dimension.
struct FlattenGather : OpRewritePattern<vector::GatherOp> {
using OpRewritePattern::OpRewritePattern;
@@ -64,6 +66,11 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
if (resultTy.getRank() < 2)
return rewriter.notifyMatchFailure(op, "already flat");
+ // Unrolling doesn't take vscale into account. Pattern is disabled for
+ // vectors with leading scalable dim(s).
+ if (resultTy.getScalableDims().front())
+ return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
+
Location loc = op.getLoc();
Value indexVec = op.getIndexVec();
Value maskVec = op.getMask();
@@ -72,8 +79,7 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
Value result = rewriter.create<arith::ConstantOp>(
loc, resultTy, rewriter.getZeroAttr(resultTy));
- Type subTy = VectorType::get(resultTy.getShape().drop_front(),
- resultTy.getElementType());
+ VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
int64_t thisIdx[1] = {i};
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index c2eb88afa4dbf..5ad3a23e0ba15 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -74,6 +74,43 @@ func.func @gather_memref_1d_i32_index(%base: memref<?xf32>, %v: vector<2xi32>, %
return %0 : vector<2x3xf32>
}
+// CHECK-LABEL: @scalable_gather_memref_2d
+// CHECK-SAME: %[[BASE:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[IDXVEC:.*]]: vector<2x[3]xindex>,
+// CHECK-SAME: %[[MASK:.*]]: vector<2x[3]xi1>,
+// CHECK-SAME: %[[PASS:.*]]: vector<2x[3]xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x[3]xf32>
+// CHECK: %[[IDXVEC0:.*]] = vector.extract %[[IDXVEC]][0] : vector<[3]xindex> from vector<2x[3]xindex>
+// CHECK: %[[MASK0:.*]] = vector.extract %[[MASK]][0] : vector<[3]xi1> from vector<2x[3]xi1>
+// CHECK: %[[PASS0:.*]] = vector.extract %[[PASS]][0] : vector<[3]xf32> from vector<2x[3]xf32>
+// CHECK: %[[GATHER0:.*]] = vector.gather %[[BASE]]{{\[}}%[[C0]], %[[C1]]] {{\[}}%[[IDXVEC0]]], %[[MASK0]], %[[PASS0]] : memref<?x?xf32>, vector<[3]xindex>, vector<[3]xi1>, vector<[3]xf32> into vector<[3]xf32>
+// CHECK: %[[INS0:.*]] = vector.insert %[[GATHER0]], %[[INIT]] [0] : vector<[3]xf32> into vector<2x[3]xf32>
+// CHECK: %[[IDXVEC1:.*]] = vector.extract %[[IDXVEC]][1] : vector<[3]xindex> from vector<2x[3]xindex>
+// CHECK: %[[MASK1:.*]] = vector.extract %[[MASK]][1] : vector<[3]xi1> from vector<2x[3]xi1>
+// CHECK: %[[PASS1:.*]] = vector.extract %[[PASS]][1] : vector<[3]xf32> from vector<2x[3]xf32>
+// CHECK: %[[GATHER1:.*]] = vector.gather %[[BASE]]{{\[}}%[[C0]], %[[C1]]] {{\[}}%[[IDXVEC1]]], %[[MASK1]], %[[PASS1]] : memref<?x?xf32>, vector<[3]xindex>, vector<[3]xi1>, vector<[3]xf32> into vector<[3]xf32>
+// CHECK: %[[INS1:.*]] = vector.insert %[[GATHER1]], %[[INS0]] [1] : vector<[3]xf32> into vector<2x[3]xf32>
+// CHECK-NEXT: return %[[INS1]] : vector<2x[3]xf32>
+func.func @scalable_gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x[3]xindex>, %mask: vector<2x[3]xi1>, %pass_thru: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : memref<?x?xf32>, vector<2x[3]xindex>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
+ return %0 : vector<2x[3]xf32>
+}
+
+// CHECK-LABEL: @scalable_gather_cant_unroll
+// CHECK-NOT: extract
+// CHECK: vector.gather
+// CHECK-NOT: extract
+func.func @scalable_gather_cant_unroll(%base: memref<?x?xf32>, %v: vector<[4]x8xindex>, %mask: vector<[4]x8xi1>, %pass_thru: vector<[4]x8xf32>) -> vector<[4]x8xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : memref<?x?xf32>, vector<[4]x8xindex>, vector<[4]x8xi1>, vector<[4]x8xf32> into vector<[4]x8xf32>
+ return %0 : vector<[4]x8xf32>
+}
+
// CHECK-LABEL: @gather_tensor_1d
// CHECK-SAME: ([[BASE:%.+]]: tensor<?xf32>, [[IDXVEC:%.+]]: vector<2xindex>, [[MASK:%.+]]: vector<2xi1>, [[PASS:%.+]]: vector<2xf32>)
// CHECK-DAG: [[M0:%.+]] = vector.extract [[MASK]][0] : i1 from vector<2xi1>
More information about the Mlir-commits
mailing list