[Mlir-commits] [mlir] [mlir][vector] Fix dominance error in warp vector distribution (PR #77771)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 11 05:58:06 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
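This commit fixes a dominance error that makes a test in `vector-warp-distribute.mlir` fail when `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` is enabled. The position operand of the extract op is defined inside the warp op region, so it is now yielded from the warp op (together with the source vector) instead of being used directly outside the region.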
```
within split at /usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Vector/vector-warp-distribute.mlir:1 offset :18:10: error: operand #0 does not dominate this use
%1 = vector.extract %0[9] : f32 from vector<64xf32>
^
within split at /usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Vector/vector-warp-distribute.mlir:1 offset :18:10: note: see current operation: %1 = "affine.apply"(%8) <{map = affine_map<()[s0] -> (s0 ceildiv 2)>}> : (index) -> index
within split at /usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Vector/vector-warp-distribute.mlir:1 offset :18:10: note: operand defined here (op in a child region)
"func.func"() <{function_type = (index) -> f32, sym_name = "vector_extract_1d"}> ({
^bb0(%arg0: index):
%0:2 = "vector.warp_execute_on_lane_0"(%arg0) <{warp_size = 32 : i64}> ({
%7 = "some_def"() : () -> vector<64xf32>
%8 = "arith.constant"() <{value = 9 : index}> : () -> index
%9 = "vector.extractelement"(%7, %8) : (vector<64xf32>, index) -> f32
"vector.yield"(%9, %7) : (f32, vector<64xf32>) -> ()
}) : (index) -> (f32, vector<2xf32>)
%1 = "affine.apply"(%8) <{map = affine_map<()[s0] -> (s0 ceildiv 2)>}> : (index) -> index
%2 = "affine.apply"(%8) <{map = affine_map<()[s0] -> (s0 mod 2)>}> : (index) -> index
%3 = "vector.extractelement"(%0#1, %2) : (vector<2xf32>, index) -> f32
%4 = "arith.index_cast"(%1) : (index) -> i32
%5 = "arith.constant"() <{value = 32 : i32}> : () -> i32
%6:2 = "gpu.shuffle"(%3, %4, %5) <{mode = #gpu<shuffle_mode idx>}> : (f32, i32, i32) -> (f32, i1)
"func.return"(%6#0) : (f32) -> ()
}) : () -> ()
LLVM ERROR: IR failed to verify after pattern application
```
---
Full diff: https://github.com/llvm/llvm-project/pull/77771.diff
1 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+13-5)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 074356ab425377..ec6f1dea2f5454 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1329,11 +1329,17 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
} else {
distributedVecType = extractSrcType;
}
- // Yield source vector from warp op.
+ // Yield source vector and position (if present) from warp op.
+ SmallVector<Value> additionalResults{extractOp.getVector()};
+ SmallVector<Type> additionalResultTypes{distributedVecType};
+ if (static_cast<bool>(extractOp.getPosition())) {
+ additionalResults.push_back(extractOp.getPosition());
+ additionalResultTypes.push_back(extractOp.getPosition().getType());
+ }
Location loc = extractOp.getLoc();
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, {extractOp.getVector()}, {distributedVecType},
+ rewriter, warpOp, additionalResults, additionalResultTypes,
newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
@@ -1362,14 +1368,16 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
// tid of extracting thread: pos / elementsPerLane
Value broadcastFromTid = rewriter.create<affine::AffineApplyOp>(
- loc, sym0.ceilDiv(elementsPerLane), extractOp.getPosition());
+ loc, sym0.ceilDiv(elementsPerLane),
+ newWarpOp->getResult(newRetIndices[1]));
// Extract at position: pos % elementsPerLane
Value pos =
elementsPerLane == 1
? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
: rewriter
- .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
- extractOp.getPosition())
+ .create<affine::AffineApplyOp>(
+ loc, sym0 % elementsPerLane,
+ newWarpOp->getResult(newRetIndices[1]))
.getResult();
Value extracted =
rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);
``````````
</details>
https://github.com/llvm/llvm-project/pull/77771
More information about the Mlir-commits
mailing list