[Mlir-commits] [mlir] [mlir][vector] Refactor WarpOpScfForOp to support unused or swapped forOp results. (PR #147620)
Charitha Saumya
llvmlistbot at llvm.org
Fri Jul 11 12:47:52 PDT 2025
https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/147620
>From 586839035dfaaf45d66ad6b1184f94465c10906f Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 2 Jul 2025 17:29:08 +0000
Subject: [PATCH 01/10] working but bug in dead result
---
.../Vector/Transforms/VectorDistribute.cpp | 66 ++++++++++++++-----
1 file changed, 50 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index af90ed8f5deaf..28c957bf61921 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
@@ -1777,24 +1778,42 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
if (llvm::is_contained(distTypes, Type{}))
return failure();
+ llvm::errs() << "escpaing values size: " << escapingValues.size() << "\n";
+
+ SmallVector<Value> yieldedValuesFromWarpOp;
+ // All init args of the forOp are yielded from the original warp op.
+ for (Value initArg : forOp.getInitArgs()) {
+ yieldedValuesFromWarpOp.push_back(initArg);
+ // find distributed type for the init arg.
+ Type distType = initArg.getType();
+ if (auto vecType = dyn_cast<VectorType>(distType)) {
+ AffineMap map = distributionMapFn(initArg);
+ distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+ }
+ distTypes.push_back(distType);
+ }
+ // All escaping values are yielded from the original warp op.
+ yieldedValuesFromWarpOp.insert(yieldedValuesFromWarpOp.end(),
+ escapingValues.begin(),
+ escapingValues.end());
+
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
- newRetIndices);
+ rewriter, warpOp, yieldedValuesFromWarpOp, distTypes, newRetIndices);
yield = cast<gpu::YieldOp>(
newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
SmallVector<Value> newOperands;
SmallVector<unsigned> resultIdx;
- // Collect all the outputs coming from the forOp.
+ // Collect the new init args coming from the new warp op.
+ for (size_t i = 0; i < forOp.getInitArgs().size(); ++i)
+ newOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
for (OpOperand &yieldOperand : yield->getOpOperands()) {
if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
continue;
- auto forResult = cast<OpResult>(yieldOperand.get());
- newOperands.push_back(
- newWarpOp.getResult(yieldOperand.getOperandNumber()));
+ OpResult forResult = cast<OpResult>(yieldOperand.get());
+ resultIdx.push_back(forResult.getResultNumber());
yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
- resultIdx.push_back(yieldOperand.getOperandNumber());
}
OpBuilder::InsertionGuard g(rewriter);
@@ -1812,8 +1831,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
forOp.getResultTypes().end());
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
- for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
- warpInput.push_back(newWarpOp.getResult(retIdx));
+ for (size_t i = forOp.getInitArgs().size(); i < newRetIndices.size(); ++i) {
+ warpInput.push_back(newWarpOp.getResult(i));
argIndexMapping[escapingValues[i]] = warpInputType.size();
warpInputType.push_back(inputTypes[i]);
}
@@ -1826,24 +1845,37 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
for (Value args : innerWarp.getBody()->getArguments()) {
argMapping.push_back(args);
}
- argMapping.resize(forOp.getBody()->getNumArguments());
+ auto forOpCopy = cast<scf::ForOp>(rewriter.clone(*forOp.getOperation()));
+ argMapping.resize(forOpCopy.getBody()->getNumArguments());
SmallVector<Value> yieldOperands;
- for (Value operand : forOp.getBody()->getTerminator()->getOperands())
+ for (Value operand : forOpCopy.getBody()->getTerminator()->getOperands())
yieldOperands.push_back(operand);
- rewriter.eraseOp(forOp.getBody()->getTerminator());
- rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
+
+ rewriter.eraseOp(forOpCopy.getBody()->getTerminator());
+ rewriter.mergeBlocks(forOpCopy.getBody(), innerWarp.getBody(), argMapping);
rewriter.setInsertionPointToEnd(innerWarp.getBody());
rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
rewriter.setInsertionPointAfter(innerWarp);
if (!innerWarp.getResults().empty())
- rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
- rewriter.eraseOp(forOp);
+ rewriter.create<scf::YieldOp>(forOpCopy.getLoc(), innerWarp.getResults());
+ // forOpCopy->getParentOp()->getParentOp()->print(llvm::outs());
+ // llvm::outs() << "\n";
+ // llvm::errs() << "erasing for op\n";
+
+ rewriter.eraseOp(forOpCopy);
// Replace the warpOp result coming from the original ForOp.
+ // print resultIdx for debugging.
+ llvm::errs() << "resultIdx: ";
+ for (auto idx : resultIdx)
+ llvm::errs() << idx << " ";
+ llvm::errs() << "\n";
for (const auto &res : llvm::enumerate(resultIdx)) {
rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
newForOp.getResult(res.index()));
- newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
+ // newForOp->setOperand(res.index() + 3,
+ // newWarpOp.getResult(res.value()));
}
+ rewriter.eraseOp(forOp);
newForOp.walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
auto it = argIndexMapping.find(operand.get());
@@ -1852,6 +1884,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
operand.set(innerWarp.getBodyRegion().getArgument(it->second));
}
});
+ newForOp->getParentOp()->print(llvm::outs());
+ llvm::outs() << "\n";
// Finally, hoist out any now uniform code from the inner warp op.
mlir::vector::moveScalarUniformCode(innerWarp);
>From 4c363175e0c5a0d6cddfe7ac3532051f76d88039 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 2 Jul 2025 17:59:39 +0000
Subject: [PATCH 02/10] working version
---
.../Vector/Transforms/VectorDistribute.cpp | 40 ++++++++++++-------
1 file changed, 25 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 28c957bf61921..dae62d2cecc04 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1796,26 +1796,34 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
yieldedValuesFromWarpOp.insert(yieldedValuesFromWarpOp.end(),
escapingValues.begin(),
escapingValues.end());
-
- SmallVector<size_t> newRetIndices;
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
- rewriter, warpOp, yieldedValuesFromWarpOp, distTypes, newRetIndices);
- yield = cast<gpu::YieldOp>(
- newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-
- SmallVector<Value> newOperands;
+ // record result mapping.
SmallVector<unsigned> resultIdx;
- // Collect the new init args coming from the new warp op.
- for (size_t i = 0; i < forOp.getInitArgs().size(); ++i)
- newOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
for (OpOperand &yieldOperand : yield->getOpOperands()) {
if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
continue;
OpResult forResult = cast<OpResult>(yieldOperand.get());
resultIdx.push_back(forResult.getResultNumber());
- yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
+ // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
}
+ // SmallVector<size_t> newRetIndices;
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+ rewriter, warpOp, yieldedValuesFromWarpOp, distTypes);
+ yield = cast<gpu::YieldOp>(
+ newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+
+ SmallVector<Value> newOperands;
+ // Collect the new init args coming from the new warp op.
+ for (size_t i = 0; i < forOp.getInitArgs().size(); ++i)
+ newOperands.push_back(newWarpOp.getResult(i));
+ // for (OpOperand &yieldOperand : yield->getOpOperands()) {
+ // if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
+ // continue;
+ // OpResult forResult = cast<OpResult>(yieldOperand.get());
+ // resultIdx.push_back(forResult.getResultNumber());
+ // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
+ // }
+
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
@@ -1831,7 +1839,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
forOp.getResultTypes().end());
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
- for (size_t i = forOp.getInitArgs().size(); i < newRetIndices.size(); ++i) {
+ for (size_t i = forOp.getInitArgs().size(); i < newWarpOp->getNumResults();
+ ++i) {
warpInput.push_back(newWarpOp.getResult(i));
argIndexMapping[escapingValues[i]] = warpInputType.size();
warpInputType.push_back(inputTypes[i]);
@@ -1870,12 +1879,13 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
llvm::errs() << idx << " ";
llvm::errs() << "\n";
for (const auto &res : llvm::enumerate(resultIdx)) {
- rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
- newForOp.getResult(res.index()));
+ rewriter.replaceAllUsesExcept(warpOp.getResult(res.value()),
+ newForOp.getResult(res.index()), newForOp);
// newForOp->setOperand(res.index() + 3,
// newWarpOp.getResult(res.value()));
}
rewriter.eraseOp(forOp);
+ rewriter.eraseOp(warpOp);
newForOp.walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
auto it = argIndexMapping.find(operand.get());
>From 3595f1758ad2c71f2f265253cf0a66ec3bdc94d2 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 2 Jul 2025 22:10:56 +0000
Subject: [PATCH 03/10] working version refined
---
.../Vector/Transforms/VectorDistribute.cpp | 75 +++++++++++++------
1 file changed, 52 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index dae62d2cecc04..52a55d104c0bd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -20,6 +20,7 @@
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
@@ -1778,37 +1779,53 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
if (llvm::is_contained(distTypes, Type{}))
return failure();
- llvm::errs() << "escpaing values size: " << escapingValues.size() << "\n";
+ // record result mapping.
+ SmallVector<unsigned> resultIdx;
+ llvm::SmallDenseMap<unsigned, unsigned> forResultToWarpResultMapping;
+ for (OpOperand &yieldOperand : yield->getOpOperands()) {
+ if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
+ continue;
+ OpResult forResult = cast<OpResult>(yieldOperand.get());
+ resultIdx.push_back(forResult.getResultNumber());
+ forResultToWarpResultMapping[forResult.getResultNumber()] =
+ yieldOperand.getOperandNumber();
+ // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
+ }
+
+ // llvm::errs() << "escpaing values size: " << escapingValues.size() <<
+ // "\n";
SmallVector<Value> yieldedValuesFromWarpOp;
+ SmallVector<Type> yieldedTypesFromWarpOp;
// All init args of the forOp are yielded from the original warp op.
- for (Value initArg : forOp.getInitArgs()) {
+ for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
yieldedValuesFromWarpOp.push_back(initArg);
// find distributed type for the init arg.
Type distType = initArg.getType();
if (auto vecType = dyn_cast<VectorType>(distType)) {
- AffineMap map = distributionMapFn(initArg);
- distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+ if (forResultToWarpResultMapping.contains(i)) {
+ // If the init arg is yielded from the warp op, we need to compute the
+ // distributed type.
+ distType =
+ warpOp.getResult(forResultToWarpResultMapping[i]).getType();
+ } else {
+ AffineMap map = distributionMapFn(initArg);
+ distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+ }
}
- distTypes.push_back(distType);
+ // llvm::errs() << "distributed type: " << distType << "\n";
+ yieldedTypesFromWarpOp.push_back(distType);
}
// All escaping values are yielded from the original warp op.
yieldedValuesFromWarpOp.insert(yieldedValuesFromWarpOp.end(),
escapingValues.begin(),
escapingValues.end());
- // record result mapping.
- SmallVector<unsigned> resultIdx;
- for (OpOperand &yieldOperand : yield->getOpOperands()) {
- if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
- continue;
- OpResult forResult = cast<OpResult>(yieldOperand.get());
- resultIdx.push_back(forResult.getResultNumber());
- // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
- }
+ yieldedTypesFromWarpOp.insert(yieldedTypesFromWarpOp.end(),
+ distTypes.begin(), distTypes.end());
// SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
- rewriter, warpOp, yieldedValuesFromWarpOp, distTypes);
+ rewriter, warpOp, yieldedValuesFromWarpOp, yieldedTypesFromWarpOp);
yield = cast<gpu::YieldOp>(
newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
@@ -1839,15 +1856,27 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
forOp.getResultTypes().end());
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
+ // llvm::errs() << "setting arg index mapping\n";
for (size_t i = forOp.getInitArgs().size(); i < newWarpOp->getNumResults();
++i) {
warpInput.push_back(newWarpOp.getResult(i));
- argIndexMapping[escapingValues[i]] = warpInputType.size();
- warpInputType.push_back(inputTypes[i]);
+ argIndexMapping[escapingValues[i - forOp.getInitArgs().size()]] =
+ warpInputType.size();
+ warpInputType.push_back(inputTypes[i - forOp.getInitArgs().size()]);
}
+ // for (auto [i, r] : llvm::enumerate(
+ // newWarpOp.getResults().drop_front(forOp.getInitArgs().size())))
+ // {
+ // warpInput.push_back(r);
+ // argIndexMapping[escapingValues[i]] = warpInputType.size();
+ // warpInputType.push_back(inputTypes[i]);
+ // }
+ // llvm::errs() << "go here\n";
auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
newWarpOp.getWarpSize(), warpInput, warpInputType);
+ // newForOp->getParentOp()->print(llvm::outs());
+ // llvm::outs() << "\n";
SmallVector<Value> argMapping;
argMapping.push_back(newForOp.getInductionVar());
@@ -1874,10 +1903,10 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
rewriter.eraseOp(forOpCopy);
// Replace the warpOp result coming from the original ForOp.
// print resultIdx for debugging.
- llvm::errs() << "resultIdx: ";
- for (auto idx : resultIdx)
- llvm::errs() << idx << " ";
- llvm::errs() << "\n";
+ // llvm::errs() << "resultIdx: ";
+ // for (auto idx : resultIdx)
+ // llvm::errs() << idx << " ";
+ // llvm::errs() << "\n";
for (const auto &res : llvm::enumerate(resultIdx)) {
rewriter.replaceAllUsesExcept(warpOp.getResult(res.value()),
newForOp.getResult(res.index()), newForOp);
@@ -1894,8 +1923,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
operand.set(innerWarp.getBodyRegion().getArgument(it->second));
}
});
- newForOp->getParentOp()->print(llvm::outs());
- llvm::outs() << "\n";
+ // newForOp->getParentOp()->print(llvm::outs());
+ // llvm::outs() << "\n";
// Finally, hoist out any now uniform code from the inner warp op.
mlir::vector::moveScalarUniformCode(innerWarp);
>From ba94ee21098465ea466c79acaaf01953b34cfc70 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 8 Jul 2025 00:21:19 +0000
Subject: [PATCH 04/10] working failing case now
---
.../Vector/Transforms/VectorDistribute.cpp | 81 +++++++++++++------
1 file changed, 58 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 52a55d104c0bd..adfed18a625b3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -19,8 +19,10 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
@@ -1779,22 +1781,35 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
if (llvm::is_contained(distTypes, Type{}))
return failure();
+ SmallVector<Value> nonForYieldedValues;
+ // SmallVector<Type> nonForYieldedTypes;
+ SmallVector<unsigned> nonForResultIndices;
+
// record result mapping.
- SmallVector<unsigned> resultIdx;
- llvm::SmallDenseMap<unsigned, unsigned> forResultToWarpResultMapping;
+ DenseMap<unsigned, unsigned> forResultMapping;
+ DenseMap<unsigned, unsigned> warpResultMapping;
+ // llvm::SmallDenseMap<unsigned, unsigned> forResultToWarpResultMapping;
for (OpOperand &yieldOperand : yield->getOpOperands()) {
- if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
+ if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
+ nonForYieldedValues.push_back(yieldOperand.get());
+ // nonForYieldedTypes.push_back(
+ // warpOp.getResult(yieldOperand.getOperandNumber()).getType());
+ nonForResultIndices.push_back(yieldOperand.getOperandNumber());
continue;
+ }
OpResult forResult = cast<OpResult>(yieldOperand.get());
- resultIdx.push_back(forResult.getResultNumber());
- forResultToWarpResultMapping[forResult.getResultNumber()] =
- yieldOperand.getOperandNumber();
+ forResultMapping[yieldOperand.getOperandNumber()] =
+ forResult.getResultNumber();
+ // forResultToWarpResultMapping[forResult.getResultNumber()] =
+ // yieldOperand.getOperandNumber();
// yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
}
+ // llvm::errs() << "non for yielded values size: "
+ // << nonForYieldedValues.size() << "\n";
+
// llvm::errs() << "escpaing values size: " << escapingValues.size() <<
// "\n";
-
SmallVector<Value> yieldedValuesFromWarpOp;
SmallVector<Type> yieldedTypesFromWarpOp;
// All init args of the forOp are yielded from the original warp op.
@@ -1803,15 +1818,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
// find distributed type for the init arg.
Type distType = initArg.getType();
if (auto vecType = dyn_cast<VectorType>(distType)) {
- if (forResultToWarpResultMapping.contains(i)) {
- // If the init arg is yielded from the warp op, we need to compute the
- // distributed type.
- distType =
- warpOp.getResult(forResultToWarpResultMapping[i]).getType();
- } else {
- AffineMap map = distributionMapFn(initArg);
- distType = getDistributedType(vecType, map, warpOp.getWarpSize());
- }
+ // if (forResultToWarpResultMapping.contains(i)) {
+ // // If the init arg is yielded from the warp op, we need to compute
+ // the
+ // // distributed type.
+ // distType =
+ // warpOp.getResult(forResultToWarpResultMapping[i]).getType();
+ // } else {
+ AffineMap map = distributionMapFn(initArg);
+ distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+ // }
}
// llvm::errs() << "distributed type: " << distType << "\n";
yieldedTypesFromWarpOp.push_back(distType);
@@ -1823,12 +1839,23 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
yieldedTypesFromWarpOp.insert(yieldedTypesFromWarpOp.end(),
distTypes.begin(), distTypes.end());
+ for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
+ warpResultMapping[nonForResultIndices[i]] =
+ yieldedValuesFromWarpOp.size();
+ yieldedValuesFromWarpOp.push_back(v);
+ yieldedTypesFromWarpOp.push_back(
+ warpOp.getResult(nonForResultIndices[i]).getType());
+ }
+
// SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
rewriter, warpOp, yieldedValuesFromWarpOp, yieldedTypesFromWarpOp);
yield = cast<gpu::YieldOp>(
newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ // newWarpOp->print(llvm::outs());
+ // llvm::outs() << "\n";
+
SmallVector<Value> newOperands;
// Collect the new init args coming from the new warp op.
for (size_t i = 0; i < forOp.getInitArgs().size(); ++i)
@@ -1857,12 +1884,13 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
forOp.getResultTypes().end());
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
// llvm::errs() << "setting arg index mapping\n";
- for (size_t i = forOp.getInitArgs().size(); i < newWarpOp->getNumResults();
- ++i) {
+ unsigned escapingValuesStartIdx = forOp.getInitArgs().size();
+ for (size_t i = escapingValuesStartIdx;
+ i < escapingValuesStartIdx + escapingValues.size(); ++i) {
warpInput.push_back(newWarpOp.getResult(i));
- argIndexMapping[escapingValues[i - forOp.getInitArgs().size()]] =
+ argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
warpInputType.size();
- warpInputType.push_back(inputTypes[i - forOp.getInitArgs().size()]);
+ warpInputType.push_back(inputTypes[i - escapingValuesStartIdx]);
}
// for (auto [i, r] : llvm::enumerate(
// newWarpOp.getResults().drop_front(forOp.getInitArgs().size())))
@@ -1907,9 +1935,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
// for (auto idx : resultIdx)
// llvm::errs() << idx << " ";
// llvm::errs() << "\n";
- for (const auto &res : llvm::enumerate(resultIdx)) {
- rewriter.replaceAllUsesExcept(warpOp.getResult(res.value()),
- newForOp.getResult(res.index()), newForOp);
+ for (auto [origIdx, newIdx] : forResultMapping) {
+ rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+ newForOp.getResult(newIdx), newForOp);
+ // newForOp->setOperand(res.index() + 3,
+ // newWarpOp.getResult(res.value()));
+ }
+
+ for (auto [origIdx, newIdx] : warpResultMapping) {
+ rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
+ newWarpOp.getResult(newIdx));
// newForOp->setOperand(res.index() + 3,
// newWarpOp.getResult(res.value()));
}
>From 28ef9c9400695acb954c34bc48c9a43b710fb92b Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 8 Jul 2025 23:27:46 +0000
Subject: [PATCH 05/10] add comments and tests
---
.../Vector/Transforms/VectorDistribute.cpp | 217 ++++++++----------
.../Vector/vector-warp-distribute.mlir | 79 +++++++
2 files changed, 172 insertions(+), 124 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index adfed18a625b3..b49c2063b075d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1749,19 +1749,18 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
: WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
+ auto newWarpOpYield = cast<gpu::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
// Only pick up forOp if it is the last op in the region.
- Operation *lastNode = yield->getPrevNode();
+ Operation *lastNode = newWarpOpYield->getPrevNode();
auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
if (!forOp)
return failure();
// Collect Values that come from the warp op but are outside the forOp.
- // Those Value needs to be returned by the original warpOp and passed to
- // the new op.
+ // Those Value needs to be returned by the new warp op.
llvm::SmallSetVector<Value, 32> escapingValues;
- SmallVector<Type> inputTypes;
- SmallVector<Type> distTypes;
+ SmallVector<Type> escapingValueInputTypes;
+ SmallVector<Type> escapingValuedistTypes;
mlir::visitUsedValuesDefinedAbove(
forOp.getBodyRegion(), [&](OpOperand *operand) {
Operation *parent = operand->get().getParentRegion()->getParentOp();
@@ -1773,183 +1772,155 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
AffineMap map = distributionMapFn(operand->get());
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
}
- inputTypes.push_back(operand->get().getType());
- distTypes.push_back(distType);
+ escapingValueInputTypes.push_back(operand->get().getType());
+ escapingValuedistTypes.push_back(distType);
}
});
- if (llvm::is_contained(distTypes, Type{}))
+ if (llvm::is_contained(escapingValuedistTypes, Type{}))
return failure();
-
+ // Warp op can yield two types of values:
+ // 1. Values that are not results of the forOp:
+ // These values must also be yielded by the new warp op. Also, we need to
+ // record the index mapping for these values to replace them later.
+ // 2. Values that are results of the forOp:
+ // In this case, we record the index mapping between the warp op result
+ // index and matching forOp result index.
SmallVector<Value> nonForYieldedValues;
- // SmallVector<Type> nonForYieldedTypes;
SmallVector<unsigned> nonForResultIndices;
-
- // record result mapping.
DenseMap<unsigned, unsigned> forResultMapping;
- DenseMap<unsigned, unsigned> warpResultMapping;
- // llvm::SmallDenseMap<unsigned, unsigned> forResultToWarpResultMapping;
- for (OpOperand &yieldOperand : yield->getOpOperands()) {
+ for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands()) {
+ // Yielded value is not a result of the forOp.
if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
nonForYieldedValues.push_back(yieldOperand.get());
- // nonForYieldedTypes.push_back(
- // warpOp.getResult(yieldOperand.getOperandNumber()).getType());
nonForResultIndices.push_back(yieldOperand.getOperandNumber());
continue;
}
OpResult forResult = cast<OpResult>(yieldOperand.get());
forResultMapping[yieldOperand.getOperandNumber()] =
forResult.getResultNumber();
- // forResultToWarpResultMapping[forResult.getResultNumber()] =
- // yieldOperand.getOperandNumber();
- // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
}
- // llvm::errs() << "non for yielded values size: "
- // << nonForYieldedValues.size() << "\n";
-
- // llvm::errs() << "escpaing values size: " << escapingValues.size() <<
- // "\n";
- SmallVector<Value> yieldedValuesFromWarpOp;
- SmallVector<Type> yieldedTypesFromWarpOp;
- // All init args of the forOp are yielded from the original warp op.
+ // Newly created warp op will yield values in following order:
+ // 1. All init args of the forOp.
+ // 2. All escaping values.
+ // 3. All non-for yielded values.
+ SmallVector<Value> newWarpOpYieldValues;
+ SmallVector<Type> newWarpOpDistTypes;
for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
- yieldedValuesFromWarpOp.push_back(initArg);
- // find distributed type for the init arg.
+ newWarpOpYieldValues.push_back(initArg);
+ // Compute the distributed type for this init arg.
Type distType = initArg.getType();
if (auto vecType = dyn_cast<VectorType>(distType)) {
- // if (forResultToWarpResultMapping.contains(i)) {
- // // If the init arg is yielded from the warp op, we need to compute
- // the
- // // distributed type.
- // distType =
- // warpOp.getResult(forResultToWarpResultMapping[i]).getType();
- // } else {
AffineMap map = distributionMapFn(initArg);
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
- // }
}
- // llvm::errs() << "distributed type: " << distType << "\n";
- yieldedTypesFromWarpOp.push_back(distType);
+ newWarpOpDistTypes.push_back(distType);
}
- // All escaping values are yielded from the original warp op.
- yieldedValuesFromWarpOp.insert(yieldedValuesFromWarpOp.end(),
- escapingValues.begin(),
- escapingValues.end());
- yieldedTypesFromWarpOp.insert(yieldedTypesFromWarpOp.end(),
- distTypes.begin(), distTypes.end());
-
+ // Insert escaping values and their distributed types.
+ newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
+ escapingValues.begin(), escapingValues.end());
+ newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
+ escapingValuedistTypes.begin(),
+ escapingValuedistTypes.end());
+ // Next, we insert all non-for yielded values and their distributed types.
+ // We also create a mapping between the non-for yielded value index and the
+ // corresponding new warp op yield value index (needed to update users
+ // later).
+ DenseMap<unsigned, unsigned> warpResultMapping;
for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
- warpResultMapping[nonForResultIndices[i]] =
- yieldedValuesFromWarpOp.size();
- yieldedValuesFromWarpOp.push_back(v);
- yieldedTypesFromWarpOp.push_back(
+ warpResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size();
+ newWarpOpYieldValues.push_back(v);
+ newWarpOpDistTypes.push_back(
warpOp.getResult(nonForResultIndices[i]).getType());
}
-
- // SmallVector<size_t> newRetIndices;
+ // Create the new warp op with the updated yield values and types.
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
- rewriter, warpOp, yieldedValuesFromWarpOp, yieldedTypesFromWarpOp);
- yield = cast<gpu::YieldOp>(
+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+ newWarpOpYield = cast<gpu::YieldOp>(
newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
- // newWarpOp->print(llvm::outs());
- // llvm::outs() << "\n";
-
- SmallVector<Value> newOperands;
- // Collect the new init args coming from the new warp op.
- for (size_t i = 0; i < forOp.getInitArgs().size(); ++i)
- newOperands.push_back(newWarpOp.getResult(i));
- // for (OpOperand &yieldOperand : yield->getOpOperands()) {
- // if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
- // continue;
- // OpResult forResult = cast<OpResult>(yieldOperand.get());
- // resultIdx.push_back(forResult.getResultNumber());
- // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
- // }
+ // Next, we create a new for op with the init args yielded by the new
+ // warp op.
+ unsigned escapingValuesStartIdx =
+ forOp.getInitArgs().size(); // ForOp init args are positioned before
+ // escaping values in the new warp op.
+ SmallVector<Value> newForOpOperands;
+ for (size_t i = 0; i < escapingValuesStartIdx; ++i)
+ newForOpOperands.push_back(newWarpOp.getResult(i));
+ // Create a new for op outside the new warp op region.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
-
- // Create a new for op outside the region with a WarpExecuteOnLane0Op
- // region inside.
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), newOperands);
+ forOp.getStep(), newForOpOperands);
+ // Next, we insert a new warp op (called inner warp op) inside the
+ // newly created for op. This warp op will contain all ops that were
+ // contained within the original for op body.
rewriter.setInsertionPointToStart(newForOp.getBody());
- SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
- newForOp.getRegionIterArgs().end());
- SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
- forOp.getResultTypes().end());
+ SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
+ newForOp.getRegionIterArgs().end());
+ SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
+ forOp.getResultTypes().end());
+ // Escaping values are forwarded to the inner warp op as its (additional)
+ // arguments. We keep track of the mapping between these values and their
+ // argument index in the inner warp op (to replcace uses later).
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
- // llvm::errs() << "setting arg index mapping\n";
- unsigned escapingValuesStartIdx = forOp.getInitArgs().size();
for (size_t i = escapingValuesStartIdx;
i < escapingValuesStartIdx + escapingValues.size(); ++i) {
- warpInput.push_back(newWarpOp.getResult(i));
+ innerWarpInput.push_back(newWarpOp.getResult(i));
argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
- warpInputType.size();
- warpInputType.push_back(inputTypes[i - escapingValuesStartIdx]);
+ innerWarpInputType.size();
+ innerWarpInputType.push_back(
+ escapingValueInputTypes[i - escapingValuesStartIdx]);
}
- // for (auto [i, r] : llvm::enumerate(
- // newWarpOp.getResults().drop_front(forOp.getInitArgs().size())))
- // {
- // warpInput.push_back(r);
- // argIndexMapping[escapingValues[i]] = warpInputType.size();
- // warpInputType.push_back(inputTypes[i]);
- // }
- // llvm::errs() << "go here\n";
+ // Create the inner warp op with the new input values and types.
auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
- newWarpOp.getWarpSize(), warpInput, warpInputType);
- // newForOp->getParentOp()->print(llvm::outs());
- // llvm::outs() << "\n";
+ newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType);
+ // Inline the for op body into the inner warp op body.
SmallVector<Value> argMapping;
argMapping.push_back(newForOp.getInductionVar());
- for (Value args : innerWarp.getBody()->getArguments()) {
+ for (Value args : innerWarp.getBody()->getArguments())
argMapping.push_back(args);
- }
- auto forOpCopy = cast<scf::ForOp>(rewriter.clone(*forOp.getOperation()));
- argMapping.resize(forOpCopy.getBody()->getNumArguments());
+
+ argMapping.resize(forOp.getBody()->getNumArguments());
SmallVector<Value> yieldOperands;
- for (Value operand : forOpCopy.getBody()->getTerminator()->getOperands())
+ for (Value operand : forOp.getBody()->getTerminator()->getOperands())
yieldOperands.push_back(operand);
- rewriter.eraseOp(forOpCopy.getBody()->getTerminator());
- rewriter.mergeBlocks(forOpCopy.getBody(), innerWarp.getBody(), argMapping);
+ rewriter.eraseOp(forOp.getBody()->getTerminator());
+ rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
+
+ // Insert a gpu yieldOp at the end of the inner warp op body that yields
+ // original forOp results.
rewriter.setInsertionPointToEnd(innerWarp.getBody());
rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
rewriter.setInsertionPointAfter(innerWarp);
+ // Insert a scf.yield op at the end of the new for op body that yields
+ // the inner warp op results.
if (!innerWarp.getResults().empty())
- rewriter.create<scf::YieldOp>(forOpCopy.getLoc(), innerWarp.getResults());
- // forOpCopy->getParentOp()->getParentOp()->print(llvm::outs());
- // llvm::outs() << "\n";
- // llvm::errs() << "erasing for op\n";
-
- rewriter.eraseOp(forOpCopy);
- // Replace the warpOp result coming from the original ForOp.
- // print resultIdx for debugging.
- // llvm::errs() << "resultIdx: ";
- // for (auto idx : resultIdx)
- // llvm::errs() << idx << " ";
- // llvm::errs() << "\n";
- for (auto [origIdx, newIdx] : forResultMapping) {
+ rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
+
+ // Update the users of original warp op results that were coming from the
+ // original forOp to the corresponding new forOp result.
+ for (auto [origIdx, newIdx] : forResultMapping)
rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
newForOp.getResult(newIdx), newForOp);
- // newForOp->setOperand(res.index() + 3,
- // newWarpOp.getResult(res.value()));
- }
-
- for (auto [origIdx, newIdx] : warpResultMapping) {
+ // Similarly, update any users of the warp op results that were not
+ // results of the forOp.
+ for (auto [origIdx, newIdx] : warpResultMapping)
rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
newWarpOp.getResult(newIdx));
- // newForOp->setOperand(res.index() + 3,
- // newWarpOp.getResult(res.value()));
- }
+ // Remove the original warp op and for op, they should not have any uses
+ // at this point.
rewriter.eraseOp(forOp);
rewriter.eraseOp(warpOp);
+ // Update any users of escaping values that were forwarded to the
+ // inner warp op. These values are now arguments of the inner warp op.
newForOp.walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
auto it = argIndexMapping.find(operand.get());
@@ -1958,8 +1929,6 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
operand.set(innerWarp.getBodyRegion().getArgument(it->second));
}
});
- // newForOp->getParentOp()->print(llvm::outs());
- // llvm::outs() << "\n";
// Finally, hoist out any now uniform code from the inner warp op.
mlir::vector::moveScalarUniformCode(innerWarp);
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 7cfbcdf101d11..3982783c764df 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -584,6 +584,85 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
return
}
+// -----
+// CHECK-PROP-LABEL: func.func @warp_scf_for_unused_for_result(
+// CHECK-PROP: %[[W0:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP: %[[INI0:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP: gpu.yield %[[INI0]], %[[INI1]] : vector<128xf32>, vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1) -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP: %[[W1:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP: %[[ACC0:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>, index) -> vector<128xf32>
+// CHECK-PROP: %[[ACC1:.*]] = "some_def"(%{{.*}}) : (index, vector<128xf32>, vector<128xf32>) -> vector<128xf32>
+// CHECK-PROP: gpu.yield %[[ACC1]], %[[ACC0]] : vector<128xf32>, vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: scf.yield %[[W1]]#0, %[[W1]]#1 : vector<4xf32>, vector<4xf32>
+// CHECK-PROP: }
+// CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> ()
+func.func @warp_scf_for_unused_for_result(%arg0: index) {
+ %c128 = arith.constant 128 : index
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
+ %ini = "some_def"() : () -> (vector<128xf32>)
+ %ini1 = "some_def"() : () -> (vector<128xf32>)
+ %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
+ %add = arith.addi %arg3, %c1 : index
+ %1 = "some_def"(%arg5, %add) : (vector<128xf32>, index) -> (vector<128xf32>)
+ %acc = "some_def"(%add, %arg4, %1) : (index, vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
+ scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
+ }
+ gpu.yield %3#0 : vector<128xf32>
+ }
+ "some_use"(%0) : (vector<4xf32>) -> ()
+ return
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @warp_scf_for_swapped_for_results(
+// CHECK-PROP: %[[W0:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP-NEXT: %[[INI0:.*]] = "some_def"() : () -> vector<256xf32>
+// CHECK-PROP-NEXT: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP-NEXT: %[[INI2:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP-NEXT: gpu.yield %[[INI0]], %[[INI1]], %[[INI2]] : vector<256xf32>, vector<128xf32>, vector<128xf32>
+// CHECK-PROP-NEXT: }
+// CHECK-PROP-NEXT: %[[F0:.*]]:3 = scf.for {{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1, %{{.*}} = %[[W0]]#2) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP-NEXT: %[[W1:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} :
+// CHECK-PROP-SAME: vector<8xf32>, vector<4xf32>, vector<4xf32>) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP-NEXT: ^bb0(%{{.*}}: vector<256xf32>, %{{.*}}: vector<128xf32>, %{{.*}}: vector<128xf32>):
+// CHECK-PROP-NEXT: %[[T3:.*]] = "some_def_1"(%{{.*}}) : (vector<256xf32>) -> vector<256xf32>
+// CHECK-PROP-NEXT: %[[T4:.*]] = "some_def_2"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32>
+// CHECK-PROP-NEXT: %[[T5:.*]] = "some_def_3"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32>
+// CHECK-PROP-NEXT: gpu.yield %[[T3]], %[[T4]], %[[T5]] : vector<256xf32>, vector<128xf32>, vector<128xf32>
+// CHECK-PROP-NEXT: }
+// CHECK-PROP-NEXT: scf.yield %[[W1]]#0, %[[W1]]#1, %[[W1]]#2 : vector<8xf32>, vector<4xf32>, vector<4xf32>
+// CHECK-PROP-NEXT: }
+// CHECK-PROP-NEXT: "some_use_1"(%[[F0]]#2) : (vector<4xf32>) -> ()
+// CHECK-PROP-NEXT: "some_use_2"(%[[F0]]#1) : (vector<4xf32>) -> ()
+// CHECK-PROP-NEXT: "some_use_3"(%[[F0]]#0) : (vector<8xf32>) -> ()
+func.func @warp_scf_for_swapped_for_results(%arg0: index) {
+ %c128 = arith.constant 128 : index
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %0:3 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>, vector<8xf32>) {
+ %ini1 = "some_def"() : () -> (vector<256xf32>)
+ %ini2 = "some_def"() : () -> (vector<128xf32>)
+ %ini3 = "some_def"() : () -> (vector<128xf32>)
+ %3:3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini1, %arg5 = %ini2, %arg6 = %ini3) -> (vector<256xf32>, vector<128xf32>, vector<128xf32>) {
+ %acc1 = "some_def_1"(%arg4) : (vector<256xf32>) -> (vector<256xf32>)
+ %acc2 = "some_def_2"(%arg5) : (vector<128xf32>) -> (vector<128xf32>)
+ %acc3 = "some_def_3"(%arg6) : (vector<128xf32>) -> (vector<128xf32>)
+ scf.yield %acc1, %acc2, %acc3 : vector<256xf32>, vector<128xf32>, vector<128xf32>
+ }
+ gpu.yield %3#2, %3#1, %3#0 : vector<128xf32>, vector<128xf32>, vector<256xf32>
+ }
+ "some_use_1"(%0#0) : (vector<4xf32>) -> ()
+ "some_use_2"(%0#1) : (vector<4xf32>) -> ()
+ "some_use_3"(%0#2) : (vector<8xf32>) -> ()
+ return
+}
+
// -----
// CHECK-PROP-LABEL: func @vector_reduction(
>From 537ca0e285c9a220a0dc0d53e24f51c86e81d5e7 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 9 Jul 2025 02:46:33 +0000
Subject: [PATCH 06/10] add missing logic
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 24 ++++++++++++++++---
1 file changed, 21 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index c072557c2bd22..ef257307de569 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -34,6 +34,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
namespace mlir {
namespace xegpu {
@@ -876,15 +877,32 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
// Step 3: Apply subgroup to workitem distribution patterns.
RewritePatternSet patterns(&getContext());
xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
- // TODO: distributionFn and shuffleFn are not used at this point.
+ // distributionFn is used by vector distribution patterns to determine the
+ // distributed vector type for a given vector value. In XeGPU subgroup
+ // distribution context, we compute this based on lane layout.
auto distributionFn = [](Value val) {
VectorType vecType = dyn_cast<VectorType>(val.getType());
int64_t vecRank = vecType ? vecType.getRank() : 0;
- OpBuilder builder(val.getContext());
if (vecRank == 0)
return AffineMap::get(val.getContext());
- return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
+ // Get the layout of the vector type.
+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val);
+ // If no layout is specified, assume the inner most dimension is distributed
+ // for now.
+ if (!layout)
+ return AffineMap::getMultiDimMapWithTargets(
+ vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
+ SmallVector<unsigned int> distributedDims;
+ // Get the distributed dimensions based on the layout.
+ ArrayRef<int> laneLayout = layout.getLaneLayout().asArrayRef();
+ for (unsigned i = 0; i < laneLayout.size(); ++i) {
+ if (laneLayout[i] > 1)
+ distributedDims.push_back(i);
+ }
+ return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
+ val.getContext());
};
+ // TODO: shuffleFn is not used.
auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx,
int64_t warpSz) { return Value(); };
vector::populatePropagateWarpVectorDistributionPatterns(
>From 164e9d619880ae24388234edcd9297b66c31d209 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 9 Jul 2025 17:57:13 +0000
Subject: [PATCH 07/10] address comments
---
.../Vector/Transforms/VectorDistribute.cpp | 88 +++++++++----------
1 file changed, 44 insertions(+), 44 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 3ce134fe5f3ce..7d3d6b98666a1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1751,13 +1751,13 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
PatternRewriter &rewriter) const override {
auto newWarpOpYield = cast<gpu::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
- // Only pick up forOp if it is the last op in the region.
+ // Only pick up `ForOp` if it is the last op in the region.
Operation *lastNode = newWarpOpYield->getPrevNode();
auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
if (!forOp)
return failure();
- // Collect Values that come from the warp op but are outside the forOp.
- // Those Value needs to be returned by the new warp op.
+ // Collect Values that come from the `WarpOp` but are outside the `ForOp`.
+ // Those Values need to be returned by the new warp op.
llvm::SmallSetVector<Value, 32> escapingValues;
SmallVector<Type> escapingValueInputTypes;
SmallVector<Type> escapingValuedistTypes;
@@ -1779,16 +1779,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
if (llvm::is_contained(escapingValuedistTypes, Type{}))
return failure();
- // Warp op can yield two types of values:
- // 1. Values that are not results of the forOp:
- // These values must also be yielded by the new warp op. Also, we need to
- // record the index mapping for these values to replace them later.
- // 2. Values that are results of the forOp:
- // In this case, we record the index mapping between the warp op result
- // index and matching forOp result index.
+ // `WarpOp` can yield two types of values:
+ // 1. Values that are not results of the `ForOp`:
+ // These values must also be yielded by the new `WarpOp`. Also, we need
+ // to record the index mapping for these values to replace them later.
+ // 2. Values that are results of the `ForOp`:
+ // In this case, we record the index mapping between the `WarpOp` result
+ // index and matching `ForOp` result index.
SmallVector<Value> nonForYieldedValues;
SmallVector<unsigned> nonForResultIndices;
- DenseMap<unsigned, unsigned> forResultMapping;
+ llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands()) {
// Yielded value is not a result of the forOp.
if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
@@ -1801,10 +1801,10 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
forResult.getResultNumber();
}
- // Newly created warp op will yield values in following order:
- // 1. All init args of the forOp.
+ // Newly created `WarpOp` will yield values in following order:
+ // 1. All init args of the `ForOp`.
// 2. All escaping values.
- // 3. All non-for yielded values.
+ // 3. All non-`ForOp` yielded values.
SmallVector<Value> newWarpOpYieldValues;
SmallVector<Type> newWarpOpDistTypes;
for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
@@ -1823,50 +1823,50 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
escapingValuedistTypes.begin(),
escapingValuedistTypes.end());
- // Next, we insert all non-for yielded values and their distributed types.
- // We also create a mapping between the non-for yielded value index and the
- // corresponding new warp op yield value index (needed to update users
- // later).
- DenseMap<unsigned, unsigned> warpResultMapping;
+ // Next, we insert all non-`ForOp` yielded values and their distributed
+ // types. We also create a mapping between the non-`ForOp` yielded value
+ // index and the corresponding new `WarpOp` yield value index (needed to
+ // update users later).
+ llvm::SmallDenseMap<unsigned, unsigned> warpResultMapping;
for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
warpResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size();
newWarpOpYieldValues.push_back(v);
newWarpOpDistTypes.push_back(
warpOp.getResult(nonForResultIndices[i]).getType());
}
- // Create the new warp op with the updated yield values and types.
+ // Create the new `WarpOp` with the updated yield values and types.
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
newWarpOpYield = cast<gpu::YieldOp>(
newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
- // Next, we create a new for op with the init args yielded by the new
- // warp op.
- unsigned escapingValuesStartIdx =
- forOp.getInitArgs().size(); // ForOp init args are positioned before
- // escaping values in the new warp op.
+ // Next, we create a new `ForOp` with the init args yielded by the new
+ // `WarpOp`.
+ const unsigned escapingValuesStartIdx =
+ forOp.getInitArgs().size(); // `ForOp` init args are positioned before
+ // escaping values in the new `WarpOp`.
SmallVector<Value> newForOpOperands;
for (size_t i = 0; i < escapingValuesStartIdx; ++i)
newForOpOperands.push_back(newWarpOp.getResult(i));
- // Create a new for op outside the new warp op region.
+ // Create a new `ForOp` outside the new `WarpOp` region.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newForOpOperands);
- // Next, we insert a new warp op (called inner warp op) inside the
- // newly created for op. This warp op will contain all ops that were
- // contained within the original for op body.
+ // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
+ // newly created `ForOp`. This `WarpOp` will contain all ops that were
+ // contained within the original `ForOp` body.
rewriter.setInsertionPointToStart(newForOp.getBody());
SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
newForOp.getRegionIterArgs().end());
SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
forOp.getResultTypes().end());
- // Escaping values are forwarded to the inner warp op as its (additional)
+ // Escaping values are forwarded to the inner `WarpOp` as its (additional)
// arguments. We keep track of the mapping between these values and their
- // argument index in the inner warp op (to replcace uses later).
+ // argument index in the inner `WarpOp` (to replace users later).
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
for (size_t i = escapingValuesStartIdx;
i < escapingValuesStartIdx + escapingValues.size(); ++i) {
@@ -1876,12 +1876,12 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
innerWarpInputType.push_back(
escapingValueInputTypes[i - escapingValuesStartIdx]);
}
- // Create the inner warp op with the new input values and types.
+ // Create the inner `WarpOp` with the new input values and types.
auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType);
- // Inline the for op body into the inner warp op body.
+ // Inline the `ForOp` body into the inner `WarpOp` body.
SmallVector<Value> argMapping;
argMapping.push_back(newForOp.getInductionVar());
for (Value args : innerWarp.getBody()->getArguments())
@@ -1895,32 +1895,32 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
rewriter.eraseOp(forOp.getBody()->getTerminator());
rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
- // Insert a gpu yieldOp at the end of the inner warp op body that yields
- // original forOp results.
+ // Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
+ // original `ForOp` results.
rewriter.setInsertionPointToEnd(innerWarp.getBody());
rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
rewriter.setInsertionPointAfter(innerWarp);
- // Insert a scf.yield op at the end of the new for op body that yields
- // the inner warp op results.
+ // Insert a scf.yield op at the end of the new `ForOp` body that yields
+ // the inner `WarpOp` results.
if (!innerWarp.getResults().empty())
rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
- // Update the users of original warp op results that were coming from the
- // original forOp to the corresponding new forOp result.
+ // Update the users of original `WarpOp` results that were coming from the
+ // original `ForOp` to the corresponding new `ForOp` result.
for (auto [origIdx, newIdx] : forResultMapping)
rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
newForOp.getResult(newIdx), newForOp);
- // Similarly, update any users of the warp op results that were not
- // results of the forOp.
+ // Similarly, update any users of the `WarpOp` results that were not
+ // results of the `ForOp`.
for (auto [origIdx, newIdx] : warpResultMapping)
rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
newWarpOp.getResult(newIdx));
- // Remove the original warp op and for op, they should not have any uses
+ // Remove the original `WarpOp` and `ForOp`, they should not have any uses
// at this point.
rewriter.eraseOp(forOp);
rewriter.eraseOp(warpOp);
// Update any users of escaping values that were forwarded to the
- // inner warp op. These values are now arguments of the inner warp op.
+ // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
newForOp.walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
auto it = argIndexMapping.find(operand.get());
@@ -1930,7 +1930,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
}
});
- // Finally, hoist out any now uniform code from the inner warp op.
+ // Finally, hoist out any now uniform code from the inner `WarpOp`.
mlir::vector::moveScalarUniformCode(innerWarp);
return success();
}
>From 683fad8092e6c40a1bc95bf2f249653236f72960 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 10 Jul 2025 22:18:19 +0000
Subject: [PATCH 08/10] address comments
---
.../lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 7d3d6b98666a1..f4928ee6c4221 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1760,7 +1760,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
// Those Values need to be returned by the new warp op.
llvm::SmallSetVector<Value, 32> escapingValues;
SmallVector<Type> escapingValueInputTypes;
- SmallVector<Type> escapingValuedistTypes;
+ SmallVector<Type> escapingValueDistTypes;
mlir::visitUsedValuesDefinedAbove(
forOp.getBodyRegion(), [&](OpOperand *operand) {
Operation *parent = operand->get().getParentRegion()->getParentOp();
@@ -1773,11 +1773,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
}
escapingValueInputTypes.push_back(operand->get().getType());
- escapingValuedistTypes.push_back(distType);
+ escapingValueDistTypes.push_back(distType);
}
});
- if (llvm::is_contained(escapingValuedistTypes, Type{}))
+ if (llvm::is_contained(escapingValueDistTypes, Type{}))
return failure();
// `WarpOp` can yield two types of values:
// 1. Values that are not results of the `ForOp`:
@@ -1821,8 +1821,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
escapingValues.begin(), escapingValues.end());
newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
- escapingValuedistTypes.begin(),
- escapingValuedistTypes.end());
+ escapingValueDistTypes.begin(),
+ escapingValueDistTypes.end());
// Next, we insert all non-`ForOp` yielded values and their distributed
// types. We also create a mapping between the non-`ForOp` yielded value
// index and the corresponding new `WarpOp` yield value index (needed to
>From 2c2703eb9511ff8b2df44de23ffa0efb15308772 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 11 Jul 2025 16:26:37 +0000
Subject: [PATCH 09/10] address comments
---
.../Dialect/Vector/Transforms/VectorDistribute.cpp | 14 ++++++--------
1 file changed, 6 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index f4928ee6c4221..af7c34f354668 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1749,10 +1749,10 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
: WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto newWarpOpYield = cast<gpu::YieldOp>(
+ auto warpOpYield = cast<gpu::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
// Only pick up `ForOp` if it is the last op in the region.
- Operation *lastNode = newWarpOpYield->getPrevNode();
+ Operation *lastNode = warpOpYield->getPrevNode();
auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
if (!forOp)
return failure();
@@ -1789,7 +1789,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
SmallVector<Value> nonForYieldedValues;
SmallVector<unsigned> nonForResultIndices;
llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
- for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands()) {
+ for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
// Yielded value is not a result of the forOp.
if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
nonForYieldedValues.push_back(yieldOperand.get());
@@ -1827,9 +1827,9 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
// types. We also create a mapping between the non-`ForOp` yielded value
// index and the corresponding new `WarpOp` yield value index (needed to
// update users later).
- llvm::SmallDenseMap<unsigned, unsigned> warpResultMapping;
+ llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
- warpResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size();
+ nonForResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size();
newWarpOpYieldValues.push_back(v);
newWarpOpDistTypes.push_back(
warpOp.getResult(nonForResultIndices[i]).getType());
@@ -1837,8 +1837,6 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
// Create the new `WarpOp` with the updated yield values and types.
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
- newWarpOpYield = cast<gpu::YieldOp>(
- newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
// Next, we create a new `ForOp` with the init args yielded by the new
// `WarpOp`.
@@ -1912,7 +1910,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
newForOp.getResult(newIdx), newForOp);
// Similarly, update any users of the `WarpOp` results that were not
// results of the `ForOp`.
- for (auto [origIdx, newIdx] : warpResultMapping)
+ for (auto [origIdx, newIdx] : nonForResultMapping)
rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
newWarpOp.getResult(newIdx));
// Remove the original `WarpOp` and `ForOp`, they should not have any uses
>From 6297e47a62c8215fb25cd06f165241c10763d23c Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 11 Jul 2025 19:47:05 +0000
Subject: [PATCH 10/10] address comments
---
.../Dialect/Vector/Transforms/VectorDistribute.cpp | 12 ++++--------
1 file changed, 4 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 15c14bef37e76..e62031412eab6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -17,12 +17,8 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Value.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
@@ -1787,11 +1783,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
// index and the corresponding new `WarpOp` yield value index (needed to
// update users later).
llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
- for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
- nonForResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size();
+ for (auto [i, v] :
+ llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
+ nonForResultMapping[i] = newWarpOpYieldValues.size();
newWarpOpYieldValues.push_back(v);
- newWarpOpDistTypes.push_back(
- warpOp.getResult(nonForResultIndices[i]).getType());
+ newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
}
// Create the new `WarpOp` with the updated yield values and types.
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
More information about the Mlir-commits
mailing list