[Mlir-commits] [mlir] [mlir][vector] Fix dominance error in warp vector distribution (PR #77771)

Matthias Springer llvmlistbot at llvm.org
Thu Jan 11 05:57:38 PST 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/77771

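This commit fixes a dominance error in the `vector-warp-distribute.mlir` test that is reported when `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` is enabled. The `WarpOpExtractElement` pattern used the position operand of the extract op, which is defined inside the warp op's region, as an operand of the `affine.apply` ops it creates after the warp op. The position (if present) is now yielded from the warp op together with the source vector, and the new ops use the corresponding warp op result instead.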

```
within split at /usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Vector/vector-warp-distribute.mlir:1 offset :18:10: error: operand #0 does not dominate this use
    %1 = vector.extract %0[9] : f32 from vector<64xf32>
         ^
within split at /usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Vector/vector-warp-distribute.mlir:1 offset :18:10: note: see current operation: %1 = "affine.apply"(%8) <{map = affine_map<()[s0] -> (s0 ceildiv 2)>}> : (index) -> index
within split at /usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Vector/vector-warp-distribute.mlir:1 offset :18:10: note: operand defined here (op in a child region)
"func.func"() <{function_type = (index) -> f32, sym_name = "vector_extract_1d"}> ({
^bb0(%arg0: index):
  %0:2 = "vector.warp_execute_on_lane_0"(%arg0) <{warp_size = 32 : i64}> ({
    %7 = "some_def"() : () -> vector<64xf32>
    %8 = "arith.constant"() <{value = 9 : index}> : () -> index
    %9 = "vector.extractelement"(%7, %8) : (vector<64xf32>, index) -> f32
    "vector.yield"(%9, %7) : (f32, vector<64xf32>) -> ()
  }) : (index) -> (f32, vector<2xf32>)
  %1 = "affine.apply"(%8) <{map = affine_map<()[s0] -> (s0 ceildiv 2)>}> : (index) -> index
  %2 = "affine.apply"(%8) <{map = affine_map<()[s0] -> (s0 mod 2)>}> : (index) -> index
  %3 = "vector.extractelement"(%0#1, %2) : (vector<2xf32>, index) -> f32
  %4 = "arith.index_cast"(%1) : (index) -> i32
  %5 = "arith.constant"() <{value = 32 : i32}> : () -> i32
  %6:2 = "gpu.shuffle"(%3, %4, %5) <{mode = #gpu<shuffle_mode idx>}> : (f32, i32, i32) -> (f32, i1)
  "func.return"(%6#0) : (f32) -> ()
}) : () -> ()
LLVM ERROR: IR failed to verify after pattern application
```

>From 804e2044fbb80609143238a35309230c648cde44 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 11 Jan 2024 13:53:07 +0000
Subject: [PATCH] [mlir][vector] Fix dominance error in warp vector
 distribution

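This commit fixes a dominance error in the `vector-warp-distribute.mlir` test that is reported when `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` is enabled. The position operand of the extract op is defined inside the warp op's region but was used by ops created after the warp op. It is now yielded from the warp op (if present) alongside the source vector, and the yielded result is used instead.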

```
within split at /usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Vector/vector-warp-distribute.mlir:1 offset :18:10: error: operand #0 does not dominate this use
    %1 = vector.extract %0[9] : f32 from vector<64xf32>
         ^
within split at /usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Vector/vector-warp-distribute.mlir:1 offset :18:10: note: see current operation: %1 = "affine.apply"(%8) <{map = affine_map<()[s0] -> (s0 ceildiv 2)>}> : (index) -> index
within split at /usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Vector/vector-warp-distribute.mlir:1 offset :18:10: note: operand defined here (op in a child region)
"func.func"() <{function_type = (index) -> f32, sym_name = "vector_extract_1d"}> ({
^bb0(%arg0: index):
  %0:2 = "vector.warp_execute_on_lane_0"(%arg0) <{warp_size = 32 : i64}> ({
    %7 = "some_def"() : () -> vector<64xf32>
    %8 = "arith.constant"() <{value = 9 : index}> : () -> index
    %9 = "vector.extractelement"(%7, %8) : (vector<64xf32>, index) -> f32
    "vector.yield"(%9, %7) : (f32, vector<64xf32>) -> ()
  }) : (index) -> (f32, vector<2xf32>)
  %1 = "affine.apply"(%8) <{map = affine_map<()[s0] -> (s0 ceildiv 2)>}> : (index) -> index
  %2 = "affine.apply"(%8) <{map = affine_map<()[s0] -> (s0 mod 2)>}> : (index) -> index
  %3 = "vector.extractelement"(%0#1, %2) : (vector<2xf32>, index) -> f32
  %4 = "arith.index_cast"(%1) : (index) -> i32
  %5 = "arith.constant"() <{value = 32 : i32}> : () -> i32
  %6:2 = "gpu.shuffle"(%3, %4, %5) <{mode = #gpu<shuffle_mode idx>}> : (f32, i32, i32) -> (f32, i1)
  "func.return"(%6#0) : (f32) -> ()
}) : () -> ()
LLVM ERROR: IR failed to verify after pattern application
```
---
 .../Vector/Transforms/VectorDistribute.cpp     | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 074356ab425377..ec6f1dea2f5454 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1329,11 +1329,17 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     } else {
       distributedVecType = extractSrcType;
     }
-    // Yield source vector from warp op.
+    // Yield source vector and position (if present) from warp op.
+    SmallVector<Value> additionalResults{extractOp.getVector()};
+    SmallVector<Type> additionalResultTypes{distributedVecType};
+    if (static_cast<bool>(extractOp.getPosition())) {
+      additionalResults.push_back(extractOp.getPosition());
+      additionalResultTypes.push_back(extractOp.getPosition().getType());
+    }
     Location loc = extractOp.getLoc();
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {extractOp.getVector()}, {distributedVecType},
+        rewriter, warpOp, additionalResults, additionalResultTypes,
         newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
@@ -1362,14 +1368,16 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
     // tid of extracting thread: pos / elementsPerLane
     Value broadcastFromTid = rewriter.create<affine::AffineApplyOp>(
-        loc, sym0.ceilDiv(elementsPerLane), extractOp.getPosition());
+        loc, sym0.ceilDiv(elementsPerLane),
+        newWarpOp->getResult(newRetIndices[1]));
     // Extract at position: pos % elementsPerLane
     Value pos =
         elementsPerLane == 1
             ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
             : rewriter
-                  .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
-                                                 extractOp.getPosition())
+                  .create<affine::AffineApplyOp>(
+                      loc, sym0 % elementsPerLane,
+                      newWarpOp->getResult(newRetIndices[1]))
                   .getResult();
     Value extracted =
         rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);



More information about the Mlir-commits mailing list