[Mlir-commits] [mlir] [mlir][xegpu] Add XeGPU subgroup map propagation analysis for XeGPU SIMT distribution. (PR #130240)
Charitha Saumya
llvmlistbot at llvm.org
Fri Mar 14 09:47:34 PDT 2025
https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/130240
>From ff972397d8c1f8acf51781fd26d42ae9ddd26db2 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 28 Feb 2025 23:57:37 +0000
Subject: [PATCH 01/27] save work
---
.../mlir/Dialect/XeGPU/Transforms/Passes.td | 10 +
.../Dialect/XeGPU/Transforms/Transforms.h | 2 +
.../Dialect/XeGPU/Transforms/CMakeLists.txt | 1 +
.../Transforms/XeGPUSubgroupDistribute.cpp | 249 ++++++++++++++++++
4 files changed, 262 insertions(+)
create mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 1ecd6ce95322b..cb9d403566645 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -23,4 +23,14 @@ def XeGPUFoldAliasOps : Pass<"xegpu-fold-alias-ops"> {
];
}
+def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
+ let summary = "Distribute XeGPU ops to work items";
+ let description = [{
+ The pass distributes subgroup-level (SIMD) XeGPU ops to work items.
+ }];
+ let dependentDialects = [
+ "memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect"
+ ];
+}
+
#endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 63ea26df06937..86b95721df60c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -16,6 +16,8 @@ namespace xegpu {
/// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
+/// Appends patterns for distributing XeGPU ops to work items into `patterns`.
+void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 7fb64d3b97b87..124e904edb543 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRXeGPUTransforms
XeGPUFoldAliasOps.cpp
+ XeGPUSubgroupDistribute.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
new file mode 100644
index 0000000000000..99995d92e24b6
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -0,0 +1,249 @@
+//===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-subgroup-distribute"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+constexpr unsigned subgroupSize = 16;
+constexpr unsigned packedASizeInBits = 16;
+constexpr unsigned packedBSizeInBits = 32;
+
+namespace {
+struct Layout2D {
+ SmallVector<int64_t, 2> layout;
+ Layout2D() = default;
+ Layout2D(int64_t x, int64_t y) { layout.insert(layout.end(), {x, y}); }
+ bool operator==(const Layout2D &rhs) const {
+ return this->layout == rhs.layout;
+ }
+ bool operator<(const Layout2D &rhs) const {
+ return this->layout < rhs.layout;
+ }
+ void print(llvm::raw_ostream &os) const {
+ os << "{";
+ llvm::interleave(
+ layout, os, [&](int64_t a) { os << a; }, ", ");
+ os << "}";
+ }
+};
+
+using WiLayout = Layout2D;
+using WiData = Layout2D;
+
+struct SGMapInfo {
+ Layout2D wiLayout;
+ Layout2D wiData;
+ SGMapInfo() = default;
+ SGMapInfo(const Layout2D &layout, const Layout2D &data, unsigned bitWidth)
+ : wiLayout(layout), wiData(data) {}
+ bool operator==(const SGMapInfo &rhs) const {
+ return this->wiLayout == rhs.wiLayout && this->wiData == rhs.wiData;
+ }
+ bool operator<(const SGMapInfo &rhs) const {
+ return this->wiLayout < rhs.wiLayout || this->wiData < rhs.wiData;
+ }
+ void print(llvm::raw_ostream &os) const {
+ os << "{";
+ os << "layout: ";
+ wiLayout.print(os);
+ os << ", ";
+ os << "data: ";
+ wiData.print(os);
+ os << "}";
+ }
+};
+
+struct SGMapLatticeValue {
+private:
+ std::set<SGMapInfo> layouts;
+
+public:
+ SGMapLatticeValue() = default;
+ SGMapLatticeValue(const SGMapLatticeValue &other) = default;
+ SGMapLatticeValue(const WiLayout &layout, const WiData &data) {
+ layouts.insert(SGMapInfo(layout, data, 16));
+ }
+
+ bool operator==(const SGMapLatticeValue &other) const {
+ return this->layouts == other.layouts;
+ }
+
+ /// This function depends on a partial ordering of the lattice values.
+ static SGMapLatticeValue meet(const SGMapLatticeValue &lhs,
+ const SGMapLatticeValue &rhs) {
+ SGMapLatticeValue res = lhs;
+ (void)res.addLayouts(rhs.layouts);
+ return res;
+ }
+
+ static SGMapLatticeValue join(const SGMapLatticeValue &lhs,
+ const SGMapLatticeValue &rhs) {
+ // Should not be triggered by this analysis, but required by `Lattice<T>`
+ llvm_unreachable("Join should not be triggered by this test");
+ }
+
+ ChangeResult addLayouts(const std::set<SGMapInfo> &layouts) {
+ int sizeBefore = this->layouts.size();
+ this->layouts.insert(layouts.begin(), layouts.end());
+ int sizeAfter = this->layouts.size();
+ return sizeBefore == sizeAfter ? ChangeResult::NoChange
+ : ChangeResult::Change;
+ }
+
+ void print(raw_ostream &os) const {
+ os << "[";
+ llvm::interleave(
+ layouts, os, [&](const SGMapInfo &a) { a.print(os); }, ", ");
+ os << "]";
+ }
+
+ void clear() { layouts.clear(); }
+
+ std::set<SGMapInfo> getLayouts() const { return layouts; }
+};
+
+struct SGMap : public Lattice<SGMapLatticeValue> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SGMap)
+ using Lattice::Lattice;
+};
+
+static SGMapLatticeValue getSGMapForDPASOperand(Type operandTy,
+ unsigned operandNum) {
+ int packingFactorForB = packedBSizeInBits / operandTy.getIntOrFloatBitWidth();
+ int packingFactorForA =
+ operandTy.getIntOrFloatBitWidth() < packedBSizeInBits
+ ? packedASizeInBits / operandTy.getIntOrFloatBitWidth()
+ : 1;
+ return SGMapLatticeValue(WiLayout(1, subgroupSize),
+ WiData(operandNum == 1 ? packingFactorForB : 1,
+ operandNum == 0 ? packingFactorForA : 1));
+}
+
+class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMap> {
+public:
+ SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
+ : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
+ using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
+
+ LogicalResult visitOperation(Operation *op, ArrayRef<SGMap *> operands,
+ ArrayRef<const SGMap *> results) override;
+
+ void visitBranchOperand(OpOperand &operand) override{};
+
+ void visitCallOperand(OpOperand &operand) override{};
+
+ void visitExternalCall(CallOpInterface call, ArrayRef<SGMap *> operands,
+ ArrayRef<const SGMap *> results) override{};
+
+ void setToExitState(SGMap *lattice) override {
+ (void)lattice->meet(SGMapLatticeValue());
+ }
+};
+} // namespace
+
+LogicalResult
+SGMapPropagation::visitOperation(Operation *op, ArrayRef<SGMap *> operands,
+ ArrayRef<const SGMap *> results) {
+ /// Handle dpas
+ if (auto dpas = dyn_cast<xegpu::DpasOp>(op)) {
+ auto aTy = dpas.getLhsType().getElementType();
+ auto bTy = dpas.getRhsType().getElementType();
+ propagateIfChanged(operands[0],
+ operands[0]->meet(getSGMapForDPASOperand(aTy, 0)));
+ propagateIfChanged(operands[1],
+ operands[1]->meet(getSGMapForDPASOperand(bTy, 1)));
+ if (operands.size() > 2) {
+ auto cTy = dpas.getAccType().getElementType();
+ propagateIfChanged(operands[2],
+ operands[2]->meet(getSGMapForDPASOperand(cTy, 2)));
+ }
+ return success();
+ }
+ for (const SGMap *r : results) {
+ // For each operand assume a default layout.
+ for (SGMap *operand : operands) {
+ meet(operand, *r);
+ }
+ addDependency(const_cast<SGMap *>(r), getProgramPointAfter(op));
+ }
+ return success();
+}
+
+void xegpu::populateXeGPUSubgroupDistributePatterns(
+ RewritePatternSet &patterns) {}
+
+namespace {
+
+class RunSGMapPropagation {
+public:
+ RunSGMapPropagation(Operation *op) {
+ SymbolTableCollection symbolTable;
+
+ solver.load<DeadCodeAnalysis>();
+ solver.load<SparseConstantPropagation>();
+ solver.load<SGMapPropagation>(symbolTable);
+ (void)solver.initializeAndRun(op);
+ }
+
+ SGMapLatticeValue getSGMap(Value val) {
+ auto *state = solver.lookupState<SGMap>(val);
+ if (!state)
+ return {};
+ return state->getValue();
+ }
+
+private:
+ DataFlowSolver solver;
+};
+
+struct XeGPUSubgroupDistributePass final
+ : public xegpu::impl::XeGPUSubgroupDistributeBase<
+ XeGPUSubgroupDistributePass> {
+ void runOnOperation() override;
+};
+
+} // namespace
+
+void XeGPUSubgroupDistributePass::runOnOperation() {
+ Operation *op = getOperation();
+
+ RunSGMapPropagation solver(op);
+
+ // Print analysis results
+ auto &os = llvm::outs();
+ op->walk([&](Operation *op) {
+ if (op->getResults().empty())
+ return;
+ auto layouts = solver.getSGMap(op->getResult(0));
+ os << "SGMap for " << op->getName() << ": ";
+ layouts.print(os);
+ os << "\n";
+ });
+}
>From d0c1144513b0c6c077c0b545889d8e0d90cf394d Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Sat, 1 Mar 2025 00:04:34 +0000
Subject: [PATCH 02/27] save work
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 33 ++++++++++++++-----
1 file changed, 25 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 99995d92e24b6..d2f43b3448f95 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -9,11 +9,13 @@
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
@@ -238,12 +240,27 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
// Print analysis results
auto &os = llvm::outs();
- op->walk([&](Operation *op) {
- if (op->getResults().empty())
- return;
- auto layouts = solver.getSGMap(op->getResult(0));
- os << "SGMap for " << op->getName() << ": ";
- layouts.print(os);
- os << "\n";
- });
+ // check if op is a function
+ // llvm::errs() << op->getName() << "\n";
+ if (auto modOp = dyn_cast<ModuleOp>(op)) {
+ for (auto funcOp : modOp.getOps<func::FuncOp>()) {
+ os << "SGMap for " << funcOp.getName() << ":\n";
+ // Function args
+ for (auto arg : funcOp.getArguments()) {
+ auto layouts = solver.getSGMap(arg);
+ os << "SGMap for " << arg << ": ";
+ layouts.print(os);
+ os << "\n";
+ }
+ // Function ops
+ funcOp.walk([&](Operation *op) {
+ if (op->getResults().empty())
+ return;
+ auto layouts = solver.getSGMap(op->getResult(0));
+ os << "SGMap for " << op->getName() << ": ";
+ layouts.print(os);
+ os << "\n";
+ });
+ }
+ }
}
>From 0b8752c921b17a99ab58f48cebc7564cd1387256 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 4 Mar 2025 18:14:40 +0000
Subject: [PATCH 03/27] save work
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 361 +++++++++++-------
.../XeGPU/subgroup-map-propagation.mlir | 186 +++++++++
2 files changed, 402 insertions(+), 145 deletions(-)
create mode 100644 mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d2f43b3448f95..c80ba2ca5f3d1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -11,12 +11,15 @@
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
@@ -38,217 +41,264 @@ constexpr unsigned packedASizeInBits = 16;
constexpr unsigned packedBSizeInBits = 32;
namespace {
-struct Layout2D {
- SmallVector<int64_t, 2> layout;
- Layout2D() = default;
- Layout2D(int64_t x, int64_t y) { layout.insert(layout.end(), {x, y}); }
- bool operator==(const Layout2D &rhs) const {
- return this->layout == rhs.layout;
- }
- bool operator<(const Layout2D &rhs) const {
- return this->layout < rhs.layout;
- }
- void print(llvm::raw_ostream &os) const {
- os << "{";
- llvm::interleave(
- layout, os, [&](int64_t a) { os << a; }, ", ");
- os << "}";
- }
+struct Layout {
+ SmallVector<int64_t, 3> layout;
+ Layout() = default;
+ Layout(const Layout &other) = default;
+ Layout(std::initializer_list<int64_t> list) : layout(list) {}
+ void print(llvm::raw_ostream &os) const;
+ size_t size() const { return layout.size(); }
};
-using WiLayout = Layout2D;
-using WiData = Layout2D;
-
-struct SGMapInfo {
- Layout2D wiLayout;
- Layout2D wiData;
- SGMapInfo() = default;
- SGMapInfo(const Layout2D &layout, const Layout2D &data, unsigned bitWidth)
- : wiLayout(layout), wiData(data) {}
- bool operator==(const SGMapInfo &rhs) const {
- return this->wiLayout == rhs.wiLayout && this->wiData == rhs.wiData;
- }
- bool operator<(const SGMapInfo &rhs) const {
- return this->wiLayout < rhs.wiLayout || this->wiData < rhs.wiData;
- }
- void print(llvm::raw_ostream &os) const {
- os << "{";
- os << "layout: ";
- wiLayout.print(os);
- os << ", ";
- os << "data: ";
- wiData.print(os);
- os << "}";
- }
-};
+void Layout::print(llvm::raw_ostream &os) const {
+ os << "[";
+ llvm::interleaveComma(layout, os);
+ os << "]";
+}
+
+using WiLayout = Layout;
+using WiData = Layout;
-struct SGMapLatticeValue {
+struct SGMap {
private:
- std::set<SGMapInfo> layouts;
+ WiLayout layout;
+ WiData data;
public:
- SGMapLatticeValue() = default;
- SGMapLatticeValue(const SGMapLatticeValue &other) = default;
- SGMapLatticeValue(const WiLayout &layout, const WiData &data) {
- layouts.insert(SGMapInfo(layout, data, 16));
+ SGMap() = default;
+ SGMap(const SGMap &other) = default;
+ SGMap(const WiLayout &layout, const WiData &data)
+ : layout(layout), data(data) {}
+
+ // Two lattice values are equal if they have `some` layout. The actual
+ // content of the layout does not matter.
+ bool operator==(const SGMap &other) const {
+ return this->isAssigned() == other.isAssigned();
}
- bool operator==(const SGMapLatticeValue &other) const {
- return this->layouts == other.layouts;
- }
+ static SGMap meet(const SGMap &lhs, const SGMap &rhs);
- /// This function depends on a partial ordering of the lattice values.
- static SGMapLatticeValue meet(const SGMapLatticeValue &lhs,
- const SGMapLatticeValue &rhs) {
- SGMapLatticeValue res = lhs;
- (void)res.addLayouts(rhs.layouts);
- return res;
- }
+ static SGMap join(const SGMap &lhs, const SGMap &rhs);
- static SGMapLatticeValue join(const SGMapLatticeValue &lhs,
- const SGMapLatticeValue &rhs) {
- // Should not be triggered by this analysis, but required by `Lattice<T>`
- llvm_unreachable("Join should not be triggered by this test");
- }
+ void print(raw_ostream &os) const;
- ChangeResult addLayouts(const std::set<SGMapInfo> &layouts) {
- int sizeBefore = this->layouts.size();
- this->layouts.insert(layouts.begin(), layouts.end());
- int sizeAfter = this->layouts.size();
- return sizeBefore == sizeAfter ? ChangeResult::NoChange
- : ChangeResult::Change;
- }
+ bool isAssigned() const { return layout.size() > 0 && data.size() > 0; }
- void print(raw_ostream &os) const {
- os << "[";
- llvm::interleave(
- layouts, os, [&](const SGMapInfo &a) { a.print(os); }, ", ");
- os << "]";
- }
+ SGMap getTransposedLayout(ArrayRef<int64_t> permutation) const;
+};
+
+void SGMap::print(raw_ostream &os) const {
+ if (isAssigned()) {
+ os << "Layout: ";
+ layout.print(os);
+ os << ", Data: ";
+ data.print(os);
+ } else
+ os << "Not initialized";
+}
- void clear() { layouts.clear(); }
+SGMap SGMap::meet(const SGMap &lhs, const SGMap &rhs) {
+ if (!lhs.isAssigned())
+ return rhs;
+ return lhs;
+}
- std::set<SGMapInfo> getLayouts() const { return layouts; }
-};
+SGMap SGMap::join(const SGMap &lhs, const SGMap &rhs) {
+ // Should not be triggered by this analysis, but required by `Lattice<T>`
+ llvm_unreachable("Join should not be triggered by this test");
+}
-struct SGMap : public Lattice<SGMapLatticeValue> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SGMap)
+SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
+ if (!isAssigned())
+ return {};
+ WiLayout newLayout;
+ WiData newData;
+ for (auto idx : permutation) {
+ newLayout.layout.push_back(layout.layout[idx]);
+ newData.layout.push_back(data.layout[idx]);
+ }
+ return SGMap(newLayout, newData);
+}
+
+struct SGMapLattice : public Lattice<SGMap> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SGMapLattice)
using Lattice::Lattice;
};
-static SGMapLatticeValue getSGMapForDPASOperand(Type operandTy,
- unsigned operandNum) {
+/// Helper Functions
+///
+
+static SGMap getSGMapForDPASOperand(Type operandTy, unsigned operandNum) {
int packingFactorForB = packedBSizeInBits / operandTy.getIntOrFloatBitWidth();
int packingFactorForA =
operandTy.getIntOrFloatBitWidth() < packedBSizeInBits
? packedASizeInBits / operandTy.getIntOrFloatBitWidth()
: 1;
- return SGMapLatticeValue(WiLayout(1, subgroupSize),
- WiData(operandNum == 1 ? packingFactorForB : 1,
- operandNum == 0 ? packingFactorForA : 1));
+ return SGMap(WiLayout({1, subgroupSize}),
+ WiData({operandNum == 1 ? packingFactorForB : 1,
+ operandNum == 0 ? packingFactorForA : 1}));
+}
+
+static SGMap getDefaultSgMap(Type ty) {
+ int packingFactor = 1;
+ if (ty.getIntOrFloatBitWidth() < packedASizeInBits)
+ packingFactor = packedBSizeInBits / ty.getIntOrFloatBitWidth();
+ return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
}
-class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMap> {
+class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
+private:
+ void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results);
+
+ void visitStoreNdOp(xegpu::StoreNdOp store, ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results);
+
+ void visitLoadNdOp(xegpu::LoadNdOp load, ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results);
+
+ void visitTransposeOp(vector::TransposeOp transpose,
+ ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results);
+
public:
SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
: SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
- LogicalResult visitOperation(Operation *op, ArrayRef<SGMap *> operands,
- ArrayRef<const SGMap *> results) override;
+ LogicalResult visitOperation(Operation *op, ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results) override;
void visitBranchOperand(OpOperand &operand) override{};
void visitCallOperand(OpOperand &operand) override{};
- void visitExternalCall(CallOpInterface call, ArrayRef<SGMap *> operands,
- ArrayRef<const SGMap *> results) override{};
+ void visitExternalCall(CallOpInterface call,
+ ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results) override{};
- void setToExitState(SGMap *lattice) override {
- (void)lattice->meet(SGMapLatticeValue());
+ void setToExitState(SGMapLattice *lattice) override {
+ (void)lattice->meet(SGMap());
}
};
} // namespace
LogicalResult
-SGMapPropagation::visitOperation(Operation *op, ArrayRef<SGMap *> operands,
- ArrayRef<const SGMap *> results) {
- /// Handle dpas
- if (auto dpas = dyn_cast<xegpu::DpasOp>(op)) {
- auto aTy = dpas.getLhsType().getElementType();
- auto bTy = dpas.getRhsType().getElementType();
- propagateIfChanged(operands[0],
- operands[0]->meet(getSGMapForDPASOperand(aTy, 0)));
- propagateIfChanged(operands[1],
- operands[1]->meet(getSGMapForDPASOperand(bTy, 1)));
- if (operands.size() > 2) {
- auto cTy = dpas.getAccType().getElementType();
- propagateIfChanged(operands[2],
- operands[2]->meet(getSGMapForDPASOperand(cTy, 2)));
- }
- return success();
- }
- for (const SGMap *r : results) {
- // For each operand assume a default layout.
- for (SGMap *operand : operands) {
- meet(operand, *r);
+SGMapPropagation::visitOperation(Operation *op,
+ ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results) {
+ if (auto dpas = dyn_cast<xegpu::DpasOp>(op))
+ visitDpasOp(dpas, operands, results);
+ else if (auto store = dyn_cast<xegpu::StoreNdOp>(op))
+ visitStoreNdOp(store, operands, results);
+ else if (auto load = dyn_cast<xegpu::LoadNdOp>(op))
+ visitLoadNdOp(load, operands, results);
+ else if (auto transpose = dyn_cast<vector::TransposeOp>(op))
+ visitTransposeOp(transpose, operands, results);
+ /// All other ops
+ else {
+ for (const SGMapLattice *r : results) {
+ for (SGMapLattice *operand : operands) {
+ if (r->getValue().isAssigned())
+ meet(operand, *r);
+ }
+ addDependency(const_cast<SGMapLattice *>(r), getProgramPointAfter(op));
}
- addDependency(const_cast<SGMap *>(r), getProgramPointAfter(op));
}
return success();
}
-void xegpu::populateXeGPUSubgroupDistributePatterns(
- RewritePatternSet &patterns) {}
+void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
+ ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results) {
+ auto aTy = dpas.getLhsType().getElementType();
+ auto bTy = dpas.getRhsType().getElementType();
+ propagateIfChanged(operands[0],
+ operands[0]->meet(getSGMapForDPASOperand(aTy, 0)));
+ propagateIfChanged(operands[1],
+ operands[1]->meet(getSGMapForDPASOperand(bTy, 1)));
+ if (operands.size() > 2) {
+ auto cTy = dpas.getAccType().getElementType();
+ propagateIfChanged(operands[2],
+ operands[2]->meet(getSGMapForDPASOperand(cTy, 2)));
+ }
+};
+
+void SGMapPropagation::visitStoreNdOp(xegpu::StoreNdOp store,
+ ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results) {
+ auto storeLayout =
+ getDefaultSgMap(store.getTensorDescType().getElementType());
+ /// Both operands should have the same layout
+ for (SGMapLattice *operand : operands) {
+ propagateIfChanged(operand, operand->meet(storeLayout));
+ }
+}
+
+void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
+ ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results) {
+ auto valueLayout = results[0]->getValue();
+ /// Need the layout of the value to propagate to the tensor descriptor.
+ if (!valueLayout.isAssigned())
+ return;
+ SGMap tensorDescLayout = valueLayout;
+ if (auto transpose = load.getTranspose())
+ tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
+ /// Propagate the new layout to the tensor descriptor operand.
+ propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
+}
+
+void SGMapPropagation::visitTransposeOp(
+ vector::TransposeOp transpose, ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results) {
+ /// Need the layout of transpose result to propagate to the operands.
+ auto operandLayout = results[0]->getValue();
+ if (!operandLayout.isAssigned())
+ return;
+ auto newLayout =
+ operandLayout.getTransposedLayout(transpose.getPermutation());
+ /// Propagate the new layout to the vector operand.
+ propagateIfChanged(operands[0], operands[0]->meet(newLayout));
+}
namespace {
class RunSGMapPropagation {
public:
- RunSGMapPropagation(Operation *op) {
+ RunSGMapPropagation(Operation *op) : target(op) {
SymbolTableCollection symbolTable;
-
solver.load<DeadCodeAnalysis>();
solver.load<SparseConstantPropagation>();
solver.load<SGMapPropagation>(symbolTable);
(void)solver.initializeAndRun(op);
}
- SGMapLatticeValue getSGMap(Value val) {
- auto *state = solver.lookupState<SGMap>(val);
- if (!state)
- return {};
- return state->getValue();
- }
+ SGMap getSGMap(Value val);
+
+ void printAnalysisResult(llvm::raw_ostream &os);
private:
DataFlowSolver solver;
+ const Operation *target;
};
-
-struct XeGPUSubgroupDistributePass final
- : public xegpu::impl::XeGPUSubgroupDistributeBase<
- XeGPUSubgroupDistributePass> {
- void runOnOperation() override;
-};
-
} // namespace
-void XeGPUSubgroupDistributePass::runOnOperation() {
- Operation *op = getOperation();
-
- RunSGMapPropagation solver(op);
+SGMap RunSGMapPropagation::getSGMap(Value val) {
+ auto *state = solver.lookupState<SGMapLattice>(val);
+ if (!state)
+ return {};
+ return state->getValue();
+}
- // Print analysis results
- auto &os = llvm::outs();
- // check if op is a function
- // llvm::errs() << op->getName() << "\n";
- if (auto modOp = dyn_cast<ModuleOp>(op)) {
+void RunSGMapPropagation::printAnalysisResult(llvm::raw_ostream &os) {
+ if (auto modOp = dyn_cast<ModuleOp>(target)) {
for (auto funcOp : modOp.getOps<func::FuncOp>()) {
- os << "SGMap for " << funcOp.getName() << ":\n";
+ os << "sg_map for " << funcOp.getName() << ":\n";
// Function args
for (auto arg : funcOp.getArguments()) {
- auto layouts = solver.getSGMap(arg);
- os << "SGMap for " << arg << ": ";
+ auto layouts = getSGMap(arg);
+ os << "sg_map for " << arg << ": ";
layouts.print(os);
os << "\n";
}
@@ -256,11 +306,32 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
funcOp.walk([&](Operation *op) {
if (op->getResults().empty())
return;
- auto layouts = solver.getSGMap(op->getResult(0));
- os << "SGMap for " << op->getName() << ": ";
+ auto layouts = getSGMap(op->getResult(0));
+ os << "sg_map for " << op->getName() << ": ";
layouts.print(os);
os << "\n";
});
}
}
}
+
+namespace {
+struct XeGPUSubgroupDistributePass final
+ : public xegpu::impl::XeGPUSubgroupDistributeBase<
+ XeGPUSubgroupDistributePass> {
+ void runOnOperation() override;
+};
+} // namespace
+
+void XeGPUSubgroupDistributePass::runOnOperation() {
+ Operation *op = getOperation();
+
+ RunSGMapPropagation solver(op);
+
+ // Print analysis results
+ auto &os = llvm::outs();
+ solver.printAnalysisResult(os);
+}
+
+void xegpu::populateXeGPUSubgroupDistributePatterns(
+ RewritePatternSet &patterns) {}
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
new file mode 100644
index 0000000000000..217d510973a4f
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -0,0 +1,186 @@
+func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense<0.0> : vector<8x16xf32>
+ %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+ %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %4 = xegpu.dpas %2, %3, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ return
+}
+
+// -----
+func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense<0.0> : vector<8x16xf32>
+ %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+ %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %3 = xegpu.load_nd %1 {transpose = array<i64: 1, 0>} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %4 = xegpu.dpas %2, %3, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ return
+}
+
+
+// -----
+func.func @test(%arg0: vector<8x16xf16>, %arg1: vector<16x16xf16>) -> vector<8x16xf32> {
+ %0 = xegpu.dpas %arg0, %arg1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ return %0 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: vector<8x32xi8>, %arg1: vector<32x32xi8>) -> vector<8x32xi32> {
+ %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x32xi8> -> vector<8x32xi32>
+ return %0 : vector<8x32xi32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+ %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ return %2 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> (vector<8x16xf32>, vector<8x16xf32>) {
+ %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ %3 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ return %2, %3 : vector<8x16xf32>, vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> (vector<8x16xf32>, vector<16x16xf16>) {
+ %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ return %2, %1 : vector<8x16xf32>, vector<16x16xf16>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+ %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32>
+ %3 = arith.truncf %2 : vector<16x16xf32> to vector<16x16xf16>
+ %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ return %4 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+ %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %2 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %3 = arith.addf %1, %2 : vector<16x16xf16>
+ %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ return %4 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+ %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %cst = arith.constant dense<1.000000e+00> : vector<16x16xf16>
+ %2 = arith.addf %1, %cst : vector<16x16xf16>
+ %3 = xegpu.dpas %0, %2 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ return %3 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> (vector<8x16xf32>, vector<16x16xf16>) {
+ %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %2 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %3 = arith.addf %1, %2 : vector<16x16xf16>
+ %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ return %4, %2 : vector<8x16xf32>, vector<16x16xf16>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: index) -> vector<8x16xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ %3:2 = scf.for %arg3 = %c0 to %arg2 step %c1 iter_args(%arg4 = %1, %arg5 = %2) -> (vector<16x16xf16>, vector<8x16xf32>) {
+ %4 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %5 = arith.addf %arg4, %4 : vector<16x16xf16>
+ %6 = xegpu.dpas %0, %5 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ scf.yield %5, %6 : vector<16x16xf16>, vector<8x16xf32>
+ }
+ return %3#1 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1) -> vector<8x16xf32> {
+ %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %1 = scf.if %arg2 -> (vector<16x16xf16>) {
+ %3 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ scf.yield %3 : vector<16x16xf16>
+ } else {
+ %3 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ scf.yield %3 : vector<16x16xf16>
+ }
+ %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ return %2 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1) -> (vector<8x16xf32>, vector<16x16xf16>) {
+ %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %1 = scf.if %arg2 -> (vector<16x16xf16>) {
+ %3 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ scf.yield %3 : vector<16x16xf16>
+ } else {
+ %3 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ scf.yield %3 : vector<16x16xf16>
+ }
+ %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ return %2, %1 : vector<8x16xf32>, vector<16x16xf16>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1) -> vector<8x16xf32> {
+ %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %1 = scf.if %arg3 -> (vector<16x16xf16>) {
+ %3 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ scf.yield %3 : vector<16x16xf16>
+ } else {
+ scf.yield %arg2 : vector<16x16xf16>
+ }
+ %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ return %2 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xi16>, %arg1: !xegpu.tensor_desc<16x16xi16>) -> vector<8x16xf32> {
+ %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
+ %1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xi16> -> vector<16x16xi16>
+ %2 = arith.bitcast %0 : vector<8x16xi16> to vector<8x16xf16>
+ %3 = arith.bitcast %1 : vector<16x16xi16> to vector<16x16xf16>
+ %4 = xegpu.dpas %2, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ return %4 : vector<8x16xf32>
+}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1) -> vector<8x16xf32> {
+ %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %1 = scf.if %arg3 -> (vector<8x16xf32>) {
+ %2 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %3 = arith.addf %2, %arg2 : vector<16x16xf16>
+ %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ scf.yield %4 : vector<8x16xf32>
+ } else {
+ %2 = xegpu.dpas %0, %arg2 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ scf.yield %2 : vector<8x16xf32>
+ }
+ return %1 : vector<8x16xf32>
+}
>From 2890d5031eab8c07e8b1e3526b1caa13f2574e0d Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 4 Mar 2025 22:58:28 +0000
Subject: [PATCH 04/27] save work
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 51 ++++++++++++-
.../XeGPU/subgroup-map-propagation.mlir | 75 +++++++++++--------
2 files changed, 95 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index c80ba2ca5f3d1..4f6091d7bebfb 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -48,6 +48,7 @@ struct Layout {
Layout(std::initializer_list<int64_t> list) : layout(list) {}
void print(llvm::raw_ostream &os) const;
size_t size() const { return layout.size(); }
+ int64_t operator[](size_t idx) const { return layout[idx]; }
};
void Layout::print(llvm::raw_ostream &os) const {
@@ -85,6 +86,9 @@ struct SGMap {
bool isAssigned() const { return layout.size() > 0 && data.size() > 0; }
SGMap getTransposedLayout(ArrayRef<int64_t> permutation) const;
+
+ const WiLayout &getLayout() const { return layout; }
+ const WiData &getData() const { return data; }
};
void SGMap::print(raw_ostream &os) const {
@@ -160,6 +164,9 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
void visitTransposeOp(vector::TransposeOp transpose,
ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results);
+ void visitVectorBitcastOp(vector::BitCastOp bitcast,
+ ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results);
public:
SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
@@ -195,16 +202,23 @@ SGMapPropagation::visitOperation(Operation *op,
visitLoadNdOp(load, operands, results);
else if (auto transpose = dyn_cast<vector::TransposeOp>(op))
visitTransposeOp(transpose, operands, results);
+ else if (auto bitcast = dyn_cast<vector::BitCastOp>(op))
+ visitVectorBitcastOp(bitcast, operands, results);
/// All other ops
else {
for (const SGMapLattice *r : results) {
for (SGMapLattice *operand : operands) {
+ /// Propagate the layout of the result to the operand.
if (r->getValue().isAssigned())
meet(operand, *r);
}
- addDependency(const_cast<SGMapLattice *>(r), getProgramPointAfter(op));
}
}
+ /// Add a dependency from each result to the program point after the operation.
+ /// NOTE: not sure if this is required, but all other passes do this.
+ for (const SGMapLattice *r : results) {
+ addDependency(const_cast<SGMapLattice *>(r), getProgramPointAfter(op));
+ }
return success();
}
@@ -262,6 +276,41 @@ void SGMapPropagation::visitTransposeOp(
propagateIfChanged(operands[0], operands[0]->meet(newLayout));
}
+void SGMapPropagation::visitVectorBitcastOp(
+ vector::BitCastOp bitcast, ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results) {
+ /// Need the layout of bitcast result to propagate to the operands.
+ auto resultLayout = results[0]->getValue();
+ if (!resultLayout.isAssigned())
+ return;
+ auto inElemTyBitWidth =
+ bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
+ auto outElemTyBitWidth =
+ bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
+
+ /// WiLayout does not change.
+ WiLayout newWiLayout = resultLayout.getLayout();
+ WiData newWiData;
+ /// It's a widening bitcast
+ if (inElemTyBitWidth < outElemTyBitWidth) {
+ auto ratio = outElemTyBitWidth / inElemTyBitWidth;
+ const auto &currData = resultLayout.getData();
+ newWiData = resultLayout.getData()[0] == 1
+ ? WiData({1, currData[1] * ratio})
+ : WiData({currData[0] * ratio, 1});
+ } else {
+ /// It's a narrowing bitcast
+ auto ratio = inElemTyBitWidth / outElemTyBitWidth;
+ const auto &currData = resultLayout.getData();
+ newWiData = resultLayout.getData()[0] == 1
+ ? WiData({1, currData[1] / ratio})
+ : WiData({currData[0] / ratio, 1});
+ }
+
+ propagateIfChanged(operands[0],
+ operands[0]->meet(SGMap(newWiLayout, newWiData)));
+}
+
namespace {
class RunSGMapPropagation {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 217d510973a4f..a806dfd9f0a6c 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -25,52 +25,67 @@ func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf
return
}
-
// -----
-func.func @test(%arg0: vector<8x16xf16>, %arg1: vector<16x16xf16>) -> vector<8x16xf32> {
- %0 = xegpu.dpas %arg0, %arg1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- return %0 : vector<8x16xf32>
+func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense<0.0> : vector<8x16xf32>
+ %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+ %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %6 = vector.transpose %3, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
+ %4 = xegpu.dpas %2, %6, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ return
}
// -----
-func.func @test(%arg0: vector<8x32xi8>, %arg1: vector<32x32xi8>) -> vector<8x32xi32> {
- %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x32xi8> -> vector<8x32xi32>
- return %0 : vector<8x32xi32>
+func.func @test(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %c : memref<8x16xi32>) {
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+ %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+ xegpu.store_nd %0, %1 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
+ return
}
// -----
func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
- %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- return %2 : vector<8x16xf32>
-}
-
-// -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> (vector<8x16xf32>, vector<8x16xf32>) {
- %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
- %1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
- %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- %3 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- return %2, %3 : vector<8x16xf32>, vector<8x16xf32>
+ %2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32>
+ %3 = arith.truncf %2 : vector<16x16xf32> to vector<16x16xf16>
+ %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ return %4 : vector<8x16xf32>
}
// -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> (vector<8x16xf32>, vector<16x16xf16>) {
- %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
- %1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
- %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- return %2, %1 : vector<8x16xf32>, vector<16x16xf16>
+func.func @test_bitcast(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : memref<8x16xi32>) {
+ %c0 = arith.constant 0 : index
+ %2 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
+ %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
+ %4 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
+ %5 = xegpu.load_nd %3 : !xegpu.tensor_desc<32x16xi8> -> vector<32x16xi8>
+ %6 = vector.bitcast %4 : vector<8x16xi16> to vector<8x32xi8>
+ %0 = xegpu.dpas %6, %5 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+ %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+ xegpu.store_nd %0, %1 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
+ return
}
// -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
- %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
- %1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
- %2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32>
- %3 = arith.truncf %2 : vector<16x16xf32> to vector<16x16xf16>
- %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- return %4 : vector<8x16xf32>
+func.func @test_bitcast(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %c : memref<8x16xf32>) {
+ %c0 = arith.constant 0 : index
+ %2 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
+ %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
+ %4 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x32xi8> -> vector<8x32xi8>
+ %5 = xegpu.load_nd %3 : !xegpu.tensor_desc<16x32xi8> -> vector<16x32xi8>
+ %6 = vector.bitcast %4 : vector<8x32xi8> to vector<8x16xf16>
+ %7 = vector.bitcast %5 : vector<16x32xi8> to vector<16x16xf16>
+ %0 = xegpu.dpas %6, %7 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ xegpu.store_nd %0, %1 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ return
}
// -----
>From 38a388e013ff0f98fa1d24f11c6bf24b6baa649a Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 5 Mar 2025 20:40:20 +0000
Subject: [PATCH 05/27] save work
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 76 ++++++++++++++++++-
.../XeGPU/subgroup-map-propagation.mlir | 15 ++++
2 files changed, 87 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 4f6091d7bebfb..eaf87087c6fe2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -161,13 +161,26 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
void visitLoadNdOp(xegpu::LoadNdOp load, ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results);
+ void visitLoadGatherOp(xegpu::LoadGatherOp load,
+ ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results);
+
void visitTransposeOp(vector::TransposeOp transpose,
ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results);
+
void visitVectorBitcastOp(vector::BitCastOp bitcast,
ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results);
+ void visitCreateNdDescOp(xegpu::CreateNdDescOp createNdDesc,
+ ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results);
+
+ void visitCreateDescOp(xegpu::CreateDescOp createDesc,
+ ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results);
+
public:
SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
: SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
@@ -204,6 +217,12 @@ SGMapPropagation::visitOperation(Operation *op,
visitTransposeOp(transpose, operands, results);
else if (auto bitcast = dyn_cast<vector::BitCastOp>(op))
visitVectorBitcastOp(bitcast, operands, results);
+ else if (auto loadGather = dyn_cast<xegpu::LoadGatherOp>(op))
+ visitLoadGatherOp(loadGather, operands, results);
+ else if (auto createNdDesc = dyn_cast<xegpu::CreateNdDescOp>(op))
+ visitCreateNdDescOp(createNdDesc, operands, results);
+ else if (auto createDesc = dyn_cast<xegpu::CreateDescOp>(op))
+ visitCreateDescOp(createDesc, operands, results);
/// All other ops
else {
for (const SGMapLattice *r : results) {
@@ -215,7 +234,8 @@ SGMapPropagation::visitOperation(Operation *op,
}
}
/// Add a dependency from each result to the program point after the operation.
- /// NOTE: not sure if this is required, but all other passes do this.
+ /// NOTE: not sure if this is required, but all other similar analyses do
+ /// this.
for (const SGMapLattice *r : results) {
addDependency(const_cast<SGMapLattice *>(r), getProgramPointAfter(op));
}
@@ -289,19 +309,18 @@ void SGMapPropagation::visitVectorBitcastOp(
bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
/// WiLayout does not change.
- WiLayout newWiLayout = resultLayout.getLayout();
+ const WiLayout &newWiLayout = resultLayout.getLayout();
+ const WiData &currData = resultLayout.getData();
WiData newWiData;
/// It's a widening bitcast
if (inElemTyBitWidth < outElemTyBitWidth) {
auto ratio = outElemTyBitWidth / inElemTyBitWidth;
- const auto &currData = resultLayout.getData();
newWiData = resultLayout.getData()[0] == 1
? WiData({1, currData[1] * ratio})
: WiData({currData[0] * ratio, 1});
} else {
/// It's a narrowing bitcast
auto ratio = inElemTyBitWidth / outElemTyBitWidth;
- const auto &currData = resultLayout.getData();
newWiData = resultLayout.getData()[0] == 1
? WiData({1, currData[1] / ratio})
: WiData({currData[0] / ratio, 1});
@@ -311,6 +330,55 @@ void SGMapPropagation::visitVectorBitcastOp(
operands[0]->meet(SGMap(newWiLayout, newWiData)));
}
+void SGMapPropagation::visitLoadGatherOp(
+ xegpu::LoadGatherOp load, ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results) {
+ auto valueLayout = results[0]->getValue();
+ /// Need the layout of the value to propagate to the tensor descriptor.
+ if (!valueLayout.isAssigned())
+ return;
+ /// LoadGatherOp has the transpose effect, so propagate the transposed layout
+ /// to the tensor descriptor.
+ SGMap tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
+ /// Mask operand should have the same layout as the value but with wi_data =
+ /// [1, 1]
+ SGMap maskLayout = SGMap(valueLayout.getLayout(), WiData({1, 1}));
+ /// Propagate the new layout to the tensor descriptor operand.
+ propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
+ /// Propagate the new layout to the mask operand.
+ propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
+}
+
+void SGMapPropagation::visitCreateNdDescOp(
+ xegpu::CreateNdDescOp createNdDesc, ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results) {
+ auto descLayout = results[0]->getValue();
+ /// Need the layout of the descriptor to propagate to the operands.
+ if (!descLayout.isAssigned())
+ return;
+ /// Propagate the layout to the source operand.
+ propagateIfChanged(operands[0], operands[0]->meet(descLayout));
+ /// For all other operands propagate the same layout with wi_data = [1, 1]
+ SGMap layout = SGMap(descLayout.getLayout(), WiData({1, 1}));
+ for (size_t i = 1; i < operands.size(); ++i) {
+ propagateIfChanged(operands[i], operands[i]->meet(layout));
+ }
+}
+
+void SGMapPropagation::visitCreateDescOp(
+ xegpu::CreateDescOp createDesc, ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results) {
+ auto descLayout = results[0]->getValue();
+ /// Need the layout of the descriptor to propagate to the operands.
+ if (!descLayout.isAssigned())
+ return;
+ /// Propagate the layout to the source operand.
+ propagateIfChanged(operands[0], operands[0]->meet(descLayout));
+ /// For offset operand propagate the same layout with wi_data = [1, 1]
+ SGMap layout = SGMap(descLayout.getLayout(), WiData({1, 1}));
+ propagateIfChanged(operands[1], operands[1]->meet(layout));
+}
+
namespace {
class RunSGMapPropagation {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index a806dfd9f0a6c..dfe9c75c36ed6 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -59,6 +59,21 @@ func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<1
return %4 : vector<8x16xf32>
}
+// -----
+func.func @test_load_gather(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : memref<8x16xf32>) {
+ %c0 = arith.constant 0 : index
+ %6 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %7 = xegpu.load_nd %6 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %0 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+ %1 = arith.constant dense<1>: vector<16xi1>
+ %2 = xegpu.create_tdesc %b, %0 : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16>>
+ %3 = xegpu.load %2, %1 <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16>>, vector<16xi1> -> vector<16x16xf16>
+ %4 = xegpu.dpas %7, %3 : vector<8x16xf16>, vector<16x16xf16>-> vector<8x16xf32>
+ %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ return
+}
+
// -----
func.func @test_bitcast(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : memref<8x16xi32>) {
%c0 = arith.constant 0 : index
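For the test_load_gather case added above, the propagation works out as follows (illustration only): the loaded f16 value is the DPAS B operand, so it gets wi_layout [1, 16] and wi_data [2, 1]; visitLoadGatherOp then assigns the scatter tensor descriptor the transposed wi_layout [16, 1], and the mask vector<16xi1> gets the value's wi_layout with wi_data [1, 1].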
>From 8996f11f39931c9540572a630f9e726d810cd5fa Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 5 Mar 2025 22:07:34 +0000
Subject: [PATCH 06/27] save work
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 56 ++++++++++++++++---
.../XeGPU/subgroup-map-propagation.mlir | 36 ++++++++----
2 files changed, 72 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index eaf87087c6fe2..871138b8dbc16 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -158,6 +158,10 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
void visitStoreNdOp(xegpu::StoreNdOp store, ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results);
+ void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
+ ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results);
+
void visitLoadNdOp(xegpu::LoadNdOp load, ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results);
@@ -223,6 +227,8 @@ SGMapPropagation::visitOperation(Operation *op,
visitCreateNdDescOp(createNdDesc, operands, results);
else if (auto createDesc = dyn_cast<xegpu::CreateDescOp>(op))
visitCreateDescOp(createDesc, operands, results);
+ else if (auto storeScatter = dyn_cast<xegpu::StoreScatterOp>(op))
+ visitStoreScatterOp(storeScatter, operands, results);
/// All other ops
else {
for (const SGMapLattice *r : results) {
@@ -379,6 +385,22 @@ void SGMapPropagation::visitCreateDescOp(
propagateIfChanged(operands[1], operands[1]->meet(layout));
}
+void SGMapPropagation::visitStoreScatterOp(
+ xegpu::StoreScatterOp storeScatter, ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results) {
+ /// StoreScatterOp has the transpose effect. Value has the regular layout,
+ /// while the tensor descriptor has the transposed layout.
+ auto valueLayout =
+ getDefaultSgMap(storeScatter.getTensorDescType().getElementType());
+ auto storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
+ /// Propagate the value layout.
+ propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
+ /// Propagate the tensor descriptor layout.
+ propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
+ /// Propagate the mask layout.
+ propagateIfChanged(operands[2], operands[2]->meet(valueLayout));
+}
+
namespace {
class RunSGMapPropagation {
@@ -411,20 +433,25 @@ SGMap RunSGMapPropagation::getSGMap(Value val) {
void RunSGMapPropagation::printAnalysisResult(llvm::raw_ostream &os) {
if (auto modOp = dyn_cast<ModuleOp>(target)) {
for (auto funcOp : modOp.getOps<func::FuncOp>()) {
- os << "sg_map for " << funcOp.getName() << ":\n";
- // Function args
+ os << "function: " << funcOp.getName() << ":\n";
+ // Function arguments
for (auto arg : funcOp.getArguments()) {
- auto layouts = getSGMap(arg);
- os << "sg_map for " << arg << ": ";
- layouts.print(os);
+ auto layout = getSGMap(arg);
+ os << "argument: " << arg << "\n";
+ os << "sg_map : ";
+ layout.print(os);
os << "\n";
}
// Function ops
funcOp.walk([&](Operation *op) {
+ // Skip ops that do not have results
if (op->getResults().empty())
return;
auto layouts = getSGMap(op->getResult(0));
- os << "sg_map for " << op->getName() << ": ";
+ os << "op : ";
+ op->print(os);
+ os << "\n";
+ os << "sg_map: ";
layouts.print(os);
os << "\n";
});
@@ -436,7 +463,18 @@ namespace {
struct XeGPUSubgroupDistributePass final
: public xegpu::impl::XeGPUSubgroupDistributeBase<
XeGPUSubgroupDistributePass> {
+ XeGPUSubgroupDistributePass() = default;
+ XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other)
+ : xegpu::impl::XeGPUSubgroupDistributeBase<XeGPUSubgroupDistributePass>(
+ other) {
+ this->printOnly = other.printOnly;
+ }
void runOnOperation() override;
+ /// Print sg map propagation analysis result and exit for testing purposes.
+ Option<bool> printOnly{
+ *this, "print-only", llvm::cl::init(false),
+ llvm::cl::desc(
+ "Print the result of the subgroup map propagation analysis")};
};
} // namespace
@@ -446,8 +484,10 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
RunSGMapPropagation solver(op);
// Print analysis results
- auto &os = llvm::outs();
- solver.printAnalysisResult(os);
+ if (printOnly) {
+ auto &os = llvm::outs();
+ solver.printAnalysisResult(os);
+ }
}
void xegpu::populateXeGPUSubgroupDistributePatterns(
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index dfe9c75c36ed6..76a44fe16440b 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -1,4 +1,4 @@
-func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+func.func @test_dpas_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0.0> : vector<8x16xf32>
%0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -12,7 +12,16 @@ func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf
}
// -----
-func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+func.func @test_dpas_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %c : memref<8x16xi32>) {
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+ %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+ xegpu.store_nd %0, %1 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
+ return
+}
+
+// -----
+func.func @test_transpose_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0.0> : vector<8x16xf32>
%0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -26,7 +35,7 @@ func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf
}
// -----
-func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+func.func @test_transpose_2(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0.0> : vector<8x16xf32>
%0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -40,17 +49,10 @@ func.func @test(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf
return
}
-// -----
-func.func @test(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %c : memref<8x16xi32>) {
- %c0 = arith.constant 0 : index
- %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
- %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
- xegpu.store_nd %0, %1 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
- return
-}
+
// -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+func.func @test_extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
%2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32>
@@ -74,6 +76,16 @@ func.func @test_load_gather(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : mem
return
}
+// -----
+func.func @test_store_scatter(%c: memref<128xf32>){
+ %cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
+ %1 = arith.constant dense<1> : vector<16xi1>
+ %2 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+ %3 = xegpu.create_tdesc %c, %2 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
+ xegpu.store %cst, %3, %1 <{transpose}> : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1>
+ return
+}
+
// -----
func.func @test_bitcast(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : memref<8x16xi32>) {
%c0 = arith.constant 0 : index
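
The visitStoreScatterOp hunk above gives the stored value the default layout and the scattered tensor descriptor the {1, 0}-transposed layout. Here is a small standalone sketch of that permutation step, assuming 2-D wi_layouts; Layout2D and applyPermutation are illustrative names, not the patch's API.

#include <array>
#include <cassert>
#include <cstddef>

using Layout2D = std::array<int, 2>;

// Apply a permutation to a 2-D layout; {1, 0} is the transpose applied to the
// wi_layout attached to the scattered tensor descriptor.
Layout2D applyPermutation(const Layout2D &layout,
                          const std::array<std::size_t, 2> &perm) {
  return {layout[perm[0]], layout[perm[1]]};
}

int main() {
  Layout2D valueWiLayout = {1, 16}; // default layout of the stored value
  Layout2D descWiLayout = applyPermutation(valueWiLayout, {1, 0});
  assert(descWiLayout[0] == 16 && descWiLayout[1] == 1);
  return 0;
}
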
>From 47888451211e237cb1f0c2b443552921a73805cc Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 6 Mar 2025 05:20:50 +0000
Subject: [PATCH 07/27] save work
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 23 +++--
.../XeGPU/subgroup-map-propagation.mlir | 97 ++++++++++---------
2 files changed, 64 insertions(+), 56 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 871138b8dbc16..4182f2176f04b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -93,12 +93,12 @@ struct SGMap {
void SGMap::print(raw_ostream &os) const {
if (isAssigned()) {
- os << "Layout: ";
+ os << "wi_layout: ";
layout.print(os);
- os << ", Data: ";
+ os << ", wi_data: ";
data.print(os);
} else
- os << "Not initialized";
+ os << "Not assigned.";
}
SGMap SGMap::meet(const SGMap &lhs, const SGMap &rhs) {
@@ -447,13 +447,20 @@ void RunSGMapPropagation::printAnalysisResult(llvm::raw_ostream &os) {
// Skip ops that do not have results
if (op->getResults().empty())
return;
- auto layouts = getSGMap(op->getResult(0));
os << "op : ";
- op->print(os);
- os << "\n";
- os << "sg_map: ";
- layouts.print(os);
+ /// For control-flow ops, print the op name only.
+ if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
+ os << op->getName();
+ else
+ op->print(os);
os << "\n";
+ /// Print the sg_map for each result.
+ for (auto [i, r] : llvm::enumerate(op->getResults())) {
+ auto layout = getSGMap(r);
+ os << "sg_map for result #" << i << ": ";
+ layout.print(os);
+ os << "\n";
+ }
});
}
}
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 76a44fe16440b..81c4244091e87 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -1,4 +1,4 @@
-func.func @test_dpas_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+func.func @test_dpas_op_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0.0> : vector<8x16xf32>
%0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -12,7 +12,7 @@ func.func @test_dpas_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref
}
// -----
-func.func @test_dpas_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %c : memref<8x16xi32>) {
+func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %c : memref<8x16xi32>) {
%c0 = arith.constant 0 : index
%0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
%1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
@@ -21,7 +21,7 @@ func.func @test_dpas_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %c : mem
}
// -----
-func.func @test_transpose_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+func.func @test_transpose_op_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0.0> : vector<8x16xf32>
%0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -35,7 +35,7 @@ func.func @test_transpose_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : m
}
// -----
-func.func @test_transpose_2(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+func.func @test_transpose_op_2(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0.0> : vector<8x16xf32>
%0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -52,7 +52,7 @@ func.func @test_transpose_2(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : m
// -----
-func.func @test_extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
%2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32>
@@ -62,7 +62,7 @@ func.func @test_extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
}
// -----
-func.func @test_load_gather(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : memref<8x16xf32>) {
+func.func @test_load_gather_op(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%6 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%7 = xegpu.load_nd %6 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
@@ -77,7 +77,7 @@ func.func @test_load_gather(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : mem
}
// -----
-func.func @test_store_scatter(%c: memref<128xf32>){
+func.func @test_store_scatter_op(%c: memref<128xf32>){
%cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
%1 = arith.constant dense<1> : vector<16xi1>
%2 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
@@ -87,7 +87,7 @@ func.func @test_store_scatter(%c: memref<128xf32>){
}
// -----
-func.func @test_bitcast(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : memref<8x16xi32>) {
+func.func @test_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : memref<8x16xi32>) {
%c0 = arith.constant 0 : index
%2 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
%3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
@@ -101,7 +101,7 @@ func.func @test_bitcast(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : m
}
// -----
-func.func @test_bitcast(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %c : memref<8x16xf32>) {
+func.func @test_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %c : memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%2 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
%3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
@@ -116,53 +116,51 @@ func.func @test_bitcast(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %c : me
}
// -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+func.func @test_binary_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>) {
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
%2 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
%3 = arith.addf %1, %2 : vector<16x16xf16>
%4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- return %4 : vector<8x16xf32>
+ xegpu.store_nd %4, %arg2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ return
}
// -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>, %arg3: !xegpu.tensor_desc<16x16xf16>) {
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
%cst = arith.constant dense<1.000000e+00> : vector<16x16xf16>
%2 = arith.addf %1, %cst : vector<16x16xf16>
%3 = xegpu.dpas %0, %2 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- return %3 : vector<8x16xf32>
-}
-
-// -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> (vector<8x16xf32>, vector<16x16xf16>) {
- %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
- %1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
- %2 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
- %3 = arith.addf %1, %2 : vector<16x16xf16>
- %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- return %4, %2 : vector<8x16xf32>, vector<16x16xf16>
+ xegpu.store_nd %3, %arg2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ xegpu.store_nd %2, %arg3 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+ return
}
// -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: index) -> vector<8x16xf32> {
+func.func @test_for_op(%a: memref<8x128xf16>, %b : memref<128x16xf16>, %c : memref<8x16xf32>) {
%c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
- %1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
- %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- %3:2 = scf.for %arg3 = %c0 to %arg2 step %c1 iter_args(%arg4 = %1, %arg5 = %2) -> (vector<16x16xf16>, vector<8x16xf32>) {
- %4 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
- %5 = arith.addf %arg4, %4 : vector<16x16xf16>
- %6 = xegpu.dpas %0, %5 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- scf.yield %5, %6 : vector<16x16xf16>, vector<8x16xf32>
+ %c128 = arith.constant 128 : index
+ %c16 = arith.constant 16 : index
+ %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+ %2 = arith.constant dense<0.0> : vector<8x16xf32>
+ %3:3 = scf.for %k = %c0 to %c128 step %c16 iter_args(%arg0 = %0, %arg1 = %1, %arg2 = %2) -> (!xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>) {
+ %4 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %5 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %6 = xegpu.dpas %4, %5, %arg2 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %7 = xegpu.update_nd_offset %arg0, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16>
+ %8 = xegpu.update_nd_offset %arg1, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16>
+ scf.yield %7, %8, %6 : !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>
}
- return %3#1 : vector<8x16xf32>
+ %9 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ xegpu.store_nd %3#2, %9 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ return
}
// -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1) -> vector<8x16xf32> {
+func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>){
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = scf.if %arg2 -> (vector<16x16xf16>) {
%3 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -172,11 +170,12 @@ func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<1
scf.yield %3 : vector<16x16xf16>
}
%2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- return %2 : vector<8x16xf32>
+ xegpu.store_nd %2, %arg3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ return
}
// -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1) -> (vector<8x16xf32>, vector<16x16xf16>) {
+func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>){
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = scf.if %arg2 -> (vector<16x16xf16>) {
%3 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -186,7 +185,9 @@ func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<1
scf.yield %3 : vector<16x16xf16>
}
%2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- return %2, %1 : vector<8x16xf32>, vector<16x16xf16>
+ xegpu.store_nd %2, %arg3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ xegpu.store_nd %1, %arg4 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+ return
}
// -----
@@ -202,16 +203,6 @@ func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<1
return %2 : vector<8x16xf32>
}
-// -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xi16>, %arg1: !xegpu.tensor_desc<16x16xi16>) -> vector<8x16xf32> {
- %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
- %1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xi16> -> vector<16x16xi16>
- %2 = arith.bitcast %0 : vector<8x16xi16> to vector<8x16xf16>
- %3 = arith.bitcast %1 : vector<16x16xi16> to vector<16x16xf16>
- %4 = xegpu.dpas %2, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- return %4 : vector<8x16xf32>
-}
-
// -----
func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1) -> vector<8x16xf32> {
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
@@ -226,3 +217,13 @@ func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<1
}
return %1 : vector<8x16xf32>
}
+
+// -----
+func.func @test(%arg0: !xegpu.tensor_desc<8x16xi16>, %arg1: !xegpu.tensor_desc<16x16xi16>) -> vector<8x16xf32> {
+ %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
+ %1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xi16> -> vector<16x16xi16>
+ %2 = arith.bitcast %0 : vector<8x16xi16> to vector<8x16xf16>
+ %3 = arith.bitcast %1 : vector<16x16xi16> to vector<16x16xf16>
+ %4 = xegpu.dpas %2, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ return %4 : vector<8x16xf32>
+}
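
The printAnalysisResult changes above settle on one "op : ..." line followed by one "sg_map for result #N: ..." line per result, which is the shape the FileCheck expectations added later in this series match against. Below is a rough standalone sketch of that output format, using plain C++ streams in place of Operation::print; printOpReport and ResultLayout are illustrative names.

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

struct ResultLayout {
  std::string wiLayout; // e.g. "[1, 16]"
  std::string wiData;   // e.g. "[1, 1]"
};

// One "op : ..." line, then one "sg_map for result #N: ..." line per result.
void printOpReport(const std::string &opText,
                   const std::vector<ResultLayout> &results) {
  std::cout << "op : " << opText << "\n";
  for (std::size_t i = 0; i < results.size(); ++i)
    std::cout << "sg_map for result #" << i << ": wi_layout: "
              << results[i].wiLayout << ", wi_data: " << results[i].wiData
              << "\n";
}

int main() {
  printOpReport("%2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> "
                "vector<8x16xf16>",
                {{"[1, 16]", "[1, 1]"}});
  return 0;
}
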
>From c7b60d156c6a91e10fe84146f56ae2efa8a344b5 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 6 Mar 2025 22:29:27 +0000
Subject: [PATCH 08/27] save work
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 169 +++++++++++-------
.../XeGPU/subgroup-map-propagation.mlir | 61 ++++---
2 files changed, 142 insertions(+), 88 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 4182f2176f04b..13030d5c74ea0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -15,13 +15,7 @@
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/Interfaces/DataLayoutInterfaces.h"
-#include "mlir/Support/LLVM.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
-#include <algorithm>
namespace mlir {
namespace xegpu {
@@ -41,6 +35,13 @@ constexpr unsigned packedASizeInBits = 16;
constexpr unsigned packedBSizeInBits = 32;
namespace {
+
+///===----------------------------------------------------------------------===///
+/// Layout
+///===----------------------------------------------------------------------===///
+
+/// Helper class to store the ND layout of work items within a subgroup and data
+/// owned by each work item.
struct Layout {
SmallVector<int64_t, 3> layout;
Layout() = default;
@@ -57,9 +58,30 @@ void Layout::print(llvm::raw_ostream &os) const {
os << "]";
}
+/// WiLayout represents the layout of work items within a subgroup when it
+/// accesses some value. WiData represents the layout of data owned by each work
+/// item.
using WiLayout = Layout;
using WiData = Layout;
+///===----------------------------------------------------------------------===///
+/// SGMap
+///===----------------------------------------------------------------------===///
+
+/// Helper class for tracking the analysis state of a value. For SGMapPropagation,
+/// the analysis state is simply the wi_layout and wi_data of each value.
+/// Purpose of this analysis is to propagate some unique layout for each value in
+/// the program starting from some known values (like DPAS, StoreNd, etc.).
+///
+/// Given this, SGMap satisfies the following properties:
+/// 1) SGMap is a lattice with two states - assigned and not assigned.
+/// 2) Two SGMap values are equal if they are both assigned or both not
+/// assigned. The concrete value of assigned state does not matter.
+/// 3) The meet operator works as follows:
+/// - If the current state is assigned, return the current state (a unique
+/// layout is already assigned; don't change it).
+/// - Otherwise, return the other state.
+
struct SGMap {
private:
WiLayout layout;
@@ -71,8 +93,8 @@ struct SGMap {
SGMap(const WiLayout &layout, const WiData &data)
: layout(layout), data(data) {}
- // Two lattice values are equal if they have `some` layout. The actual
- // content of the layout does not matter.
+ /// Two lattice values are equal if they have `some` layout. The actual
+ /// content of the layout does not matter.
bool operator==(const SGMap &other) const {
return this->isAssigned() == other.isAssigned();
}
@@ -107,9 +129,9 @@ SGMap SGMap::meet(const SGMap &lhs, const SGMap &rhs) {
return lhs;
}
+/// Since this is a backward analysis, the join method is not used.
SGMap SGMap::join(const SGMap &lhs, const SGMap &rhs) {
- // Should not be triggered by this analysis, but required by `Lattice<T>`
- llvm_unreachable("Join should not be triggered by this test");
+ llvm_unreachable("Join should not be triggered by SGMapPropagation.");
}
SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
@@ -124,14 +146,17 @@ SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
return SGMap(newLayout, data);
}
+///===----------------------------------------------------------------------===///
+/// SGMapLattice
+///===----------------------------------------------------------------------===///
+
+/// Lattice holding the SGMap for each value.
struct SGMapLattice : public Lattice<SGMap> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SGMapLattice)
using Lattice::Lattice;
};
-/// Helper Functions
-///
-
+/// Helper Function to get the expected layouts for DPAS operands.
static SGMap getSGMapForDPASOperand(Type operandTy, unsigned operandNum) {
int packingFactorForB = packedBSizeInBits / operandTy.getIntOrFloatBitWidth();
int packingFactorForA =
@@ -143,6 +168,11 @@ static SGMap getSGMapForDPASOperand(Type operandTy, unsigned operandNum) {
operandNum == 0 ? packingFactorForA : 1}));
}
+/// Helper Function to get the default layout for a given type. Usually this is
+/// wi_layout = [1, subgroupSize] and wi_data = [1, 1].
+/// However, the minimum granularity of data access per work item is 16-bits.
+/// So, if the bitwidth of the type is less than 16, we need to pack the data to
+/// 16-bits.
static SGMap getDefaultSgMap(Type ty) {
int packingFactor = 1;
if (ty.getIntOrFloatBitWidth() < packedASizeInBits)
@@ -150,6 +180,15 @@ static SGMap getDefaultSgMap(Type ty) {
return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
}
+///===----------------------------------------------------------------------===///
+/// SGMapPropagation
+///===----------------------------------------------------------------------===///
+
+/// Backward data flow analysis to propagate the wi_layout and wi_data of each
+/// value in the program. Currently, the layouts for the operands of DPAS,
+/// StoreNd, and StoreScatter are fixed (known before propagation). The purpose of this analysis
+/// is to propagate those known layouts to all their producers and (other)
+/// consumers.
class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
private:
void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<SGMapLattice *> operands,
@@ -177,14 +216,6 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results);
- void visitCreateNdDescOp(xegpu::CreateNdDescOp createNdDesc,
- ArrayRef<SGMapLattice *> operands,
- ArrayRef<const SGMapLattice *> results);
-
- void visitCreateDescOp(xegpu::CreateDescOp createDesc,
- ArrayRef<SGMapLattice *> operands,
- ArrayRef<const SGMapLattice *> results);
-
public:
SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
: SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
@@ -223,10 +254,6 @@ SGMapPropagation::visitOperation(Operation *op,
visitVectorBitcastOp(bitcast, operands, results);
else if (auto loadGather = dyn_cast<xegpu::LoadGatherOp>(op))
visitLoadGatherOp(loadGather, operands, results);
- else if (auto createNdDesc = dyn_cast<xegpu::CreateNdDescOp>(op))
- visitCreateNdDescOp(createNdDesc, operands, results);
- else if (auto createDesc = dyn_cast<xegpu::CreateDescOp>(op))
- visitCreateDescOp(createDesc, operands, results);
else if (auto storeScatter = dyn_cast<xegpu::StoreScatterOp>(op))
visitStoreScatterOp(storeScatter, operands, results);
/// All other ops
@@ -248,6 +275,7 @@ SGMapPropagation::visitOperation(Operation *op,
return success();
}
+/// Set the layouts for DPAS A, B, and C operands.
void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results) {
@@ -264,6 +292,7 @@ void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
}
};
+/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
void SGMapPropagation::visitStoreNdOp(xegpu::StoreNdOp store,
ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results) {
@@ -275,6 +304,8 @@ void SGMapPropagation::visitStoreNdOp(xegpu::StoreNdOp store,
}
}
+/// Propagate the layout of the value to the tensor descriptor operand in
+/// LoadNdOp.
void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results) {
@@ -283,12 +314,20 @@ void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
if (!valueLayout.isAssigned())
return;
SGMap tensorDescLayout = valueLayout;
- if (auto transpose = load.getTranspose())
+ /// LoadNdOp has the transpose effect. However, at the stage of this analysis
+ /// this effect is not expected and should be abstracted away. Emit a warning.
+ /// TODO: Handle this case properly when `order` is introduced in the sg_map.
+ if (auto transpose = load.getTranspose()) {
+ emitWarning(load.getLoc())
+ << "Transpose effect is not expected for LoadNdOp";
tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
+ }
/// Propagate the new layout to the tensor descriptor operand.
propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
}
+/// For vector::TransposeOp, the layout of the result is transposed and
+/// propagated to the operand.
void SGMapPropagation::visitTransposeOp(
vector::TransposeOp transpose, ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results) {
@@ -302,6 +341,8 @@ void SGMapPropagation::visitTransposeOp(
propagateIfChanged(operands[0], operands[0]->meet(newLayout));
}
+/// For vector::BitCastOp, the wi_data of the source layout is changed based on
+/// the bit width of the source and result types.
void SGMapPropagation::visitVectorBitcastOp(
vector::BitCastOp bitcast, ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results) {
@@ -336,6 +377,8 @@ void SGMapPropagation::visitVectorBitcastOp(
operands[0]->meet(SGMap(newWiLayout, newWiData)));
}
+/// Propagate the layout of the result to the tensor descriptor and mask
+/// operands in LoadGatherOp.
void SGMapPropagation::visitLoadGatherOp(
xegpu::LoadGatherOp load, ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results) {
@@ -343,56 +386,47 @@ void SGMapPropagation::visitLoadGatherOp(
/// Need the layout of the value to propagate to the tensor descriptor.
if (!valueLayout.isAssigned())
return;
- /// LoadGatherOp has the transpose effect, so propagate the transposed layout
- /// to the tensor descriptor.
- SGMap tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
+
+ SGMap tensorDescLayout;
+ if (load.getTranspose()) {
+ /// LoadGatherOp has the transpose effect. However, at the stage of this
+ /// analysis this effect is not expected and should be abstracted away. Emit
+ /// a warning.
+ /// TODO: Handle this case properly when `order` is introduced in the
+ /// sg_map.
+ emitWarning(load.getLoc())
+ << "Transpose effect is not expected for LoadGatherOp";
+ tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
+ } else
+ tensorDescLayout = valueLayout;
/// Mask operand should have the same layout as the value but with wi_data =
/// [1, 1]
- SGMap maskLayout = SGMap(valueLayout.getLayout(), WiData({1, 1}));
+ SGMap maskLayout = getDefaultSgMap(load.getTensorDescType().getElementType());
/// Propagate the new layout to the tensor descriptor operand.
propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
/// Propagate the new layout to the mask operand.
propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
}
-void SGMapPropagation::visitCreateNdDescOp(
- xegpu::CreateNdDescOp createNdDesc, ArrayRef<SGMapLattice *> operands,
- ArrayRef<const SGMapLattice *> results) {
- auto descLayout = results[0]->getValue();
- /// Need the layout of the descriptor to propagate to the operands.
- if (!descLayout.isAssigned())
- return;
- /// Propagate the layout to the source operand.
- propagateIfChanged(operands[0], operands[0]->meet(descLayout));
- /// For all other operands propagate the same layout with wi_data = [1, 1]
- SGMap layout = SGMap(descLayout.getLayout(), WiData({1, 1}));
- for (size_t i = 1; i < operands.size(); ++i) {
- propagateIfChanged(operands[i], operands[i]->meet(layout));
- }
-}
-
-void SGMapPropagation::visitCreateDescOp(
- xegpu::CreateDescOp createDesc, ArrayRef<SGMapLattice *> operands,
- ArrayRef<const SGMapLattice *> results) {
- auto descLayout = results[0]->getValue();
- /// Need the layout of the descriptor to propagate to the operands.
- if (!descLayout.isAssigned())
- return;
- /// Propagate the layout to the source operand.
- propagateIfChanged(operands[0], operands[0]->meet(descLayout));
- /// For offset operand propagate the same layout with wi_data = [1, 1]
- SGMap layout = SGMap(descLayout.getLayout(), WiData({1, 1}));
- propagateIfChanged(operands[1], operands[1]->meet(layout));
-}
-
+/// Set the layout for the value, tensor descriptor, and mask operands in
+/// StoreScatterOp.
void SGMapPropagation::visitStoreScatterOp(
xegpu::StoreScatterOp storeScatter, ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results) {
- /// StoreScatterOp has the transpose effect. Value has the regular layout,
- /// while the tensor descriptor has the transposed layout.
auto valueLayout =
getDefaultSgMap(storeScatter.getTensorDescType().getElementType());
- auto storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
+ SGMap storeScatterLayout;
+ if (storeScatter.getTranspose()) {
+ /// StoreScatterOp allows the transpose effect. However, at the stage of this
+ /// analysis this effect is not expected and should be abstracted away. Emit
+ /// a warning.
+ /// TODO: Handle this case properly when `order` is introduced in the
+ /// sg_map.
+ emitWarning(storeScatter.getLoc())
+ << "Transpose effect is not expected for StoreScatterOp";
+ storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
+ } else
+ storeScatterLayout = valueLayout;
/// Propagate the value layout.
propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
/// Propagate the tensor descriptor layout.
@@ -403,6 +437,11 @@ void SGMapPropagation::visitStoreScatterOp(
namespace {
+///===----------------------------------------------------------------------===///
+/// RunSGMapPropagation
+///===----------------------------------------------------------------------===///
+
+/// Driver class for running the SGMapPropagation analysis.
class RunSGMapPropagation {
public:
RunSGMapPropagation(Operation *op) : target(op) {
@@ -487,13 +526,13 @@ struct XeGPUSubgroupDistributePass final
void XeGPUSubgroupDistributePass::runOnOperation() {
Operation *op = getOperation();
-
RunSGMapPropagation solver(op);
- // Print analysis results
+ // Print the analysis result and exit.
if (printOnly) {
auto &os = llvm::outs();
solver.printAnalysisResult(os);
+ return;
}
}
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 81c4244091e87..871e6e2f9139e 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -62,7 +62,7 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
}
// -----
-func.func @test_load_gather_op(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : memref<8x16xf32>) {
+func.func @test_load_gather_op_1(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%6 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%7 = xegpu.load_nd %6 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
@@ -77,7 +77,17 @@ func.func @test_load_gather_op(%a: memref<8x16xf16>, %b : memref<256xf16>, %c :
}
// -----
-func.func @test_store_scatter_op(%c: memref<128xf32>){
+func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>){
+ %0 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+ %1 = arith.constant dense<1>: vector<16xi1>
+ %2 = xegpu.create_tdesc %arg0, %0 : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>
+ %3 = xegpu.load %2, %1 : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>, vector<16xi1> -> vector<16xf32>
+ xegpu.store_nd %3, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+ return
+}
+
+// -----
+func.func @test_store_scatter_op_1(%c: memref<128xf32>){
%cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
%1 = arith.constant dense<1> : vector<16xi1>
%2 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
@@ -86,6 +96,15 @@ func.func @test_store_scatter_op(%c: memref<128xf32>){
return
}
+// -----
+func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>){
+ %0 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+ %1 = arith.constant dense<1>: vector<16xi1>
+ %2 = xegpu.create_tdesc %arg1, %0 : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>
+ xegpu.store %arg0, %2, %1 : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>, vector<16xi1>
+ return
+}
+
// -----
func.func @test_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : memref<8x16xi32>) {
%c0 = arith.constant 0 : index
@@ -191,7 +210,7 @@ func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
}
// -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1) -> vector<8x16xf32> {
+func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1, %arg4: !xegpu.tensor_desc<8x16xf32>){
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = scf.if %arg3 -> (vector<16x16xf16>) {
%3 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -200,30 +219,26 @@ func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<1
scf.yield %arg2 : vector<16x16xf16>
}
%2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- return %2 : vector<8x16xf32>
+ xegpu.store_nd %2, %arg4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ return
}
// -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1) -> vector<8x16xf32> {
- %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
- %1 = scf.if %arg3 -> (vector<8x16xf32>) {
- %2 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
- %3 = arith.addf %2, %arg2 : vector<16x16xf16>
- %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- scf.yield %4 : vector<8x16xf32>
- } else {
- %2 = xegpu.dpas %0, %arg2 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- scf.yield %2 : vector<8x16xf32>
- }
- return %1 : vector<8x16xf32>
+
+func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+ %0 = arith.constant dense<0.000000e+00> : vector<16xf32>
+ %1 = vector.multi_reduction <add>, %arg0, %0 [0] : vector<16x16xf32> to vector<16xf32>
+ xegpu.store_nd %1, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+ return
}
+
+
// -----
-func.func @test(%arg0: !xegpu.tensor_desc<8x16xi16>, %arg1: !xegpu.tensor_desc<16x16xi16>) -> vector<8x16xf32> {
- %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
- %1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xi16> -> vector<16x16xi16>
- %2 = arith.bitcast %0 : vector<8x16xi16> to vector<8x16xf16>
- %3 = arith.bitcast %1 : vector<16x16xi16> to vector<16x16xf16>
- %4 = xegpu.dpas %2, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- return %4 : vector<8x16xf32>
+
+func.func @test_vector_reduction_2(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+ %0 = arith.constant dense<0.000000e+00> : vector<16xf32>
+ %1 = vector.multi_reduction <add>, %arg0, %0 [1] : vector<16x16xf32> to vector<16xf32>
+ xegpu.store_nd %1, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+ return
}
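
The getSGMapForDPASOperand and getDefaultSgMap comments in this patch describe the packing rule: wi_layout stays [1, subgroupSize], while wi_data packs elements so that DPAS operand A reaches 16 bits per work item along the inner dimension and operand B reaches 32 bits along the outer dimension. Here is a standalone sketch of the resulting wi_data values, assuming the packing factors are plain bit-width ratios (consistent with the [1, 2] and [4, 1] wi_data expectations for the i8 DPAS test added later in this series).

#include <cassert>

// Constants mirrored from the pass: packedASizeInBits and packedBSizeInBits.
constexpr unsigned kPackedASizeInBits = 16;
constexpr unsigned kPackedBSizeInBits = 32;

struct WiData {
  unsigned d0, d1;
};

// Expected wi_data for a DPAS operand: A (operandNum == 0) packs along the
// inner dimension up to 16 bits, B (operandNum == 1) packs along the outer
// dimension up to 32 bits.
WiData dpasOperandWiData(unsigned elemBitWidth, unsigned operandNum) {
  unsigned packA = elemBitWidth < kPackedASizeInBits
                       ? kPackedASizeInBits / elemBitWidth
                       : 1;
  unsigned packB = kPackedBSizeInBits / elemBitWidth;
  return operandNum == 1 ? WiData{packB, 1} : WiData{1, packA};
}

int main() {
  // f16 DPAS: A -> wi_data [1, 1], B -> wi_data [2, 1].
  assert(dpasOperandWiData(16, 0).d1 == 1 && dpasOperandWiData(16, 1).d0 == 2);
  // i8 DPAS: A -> wi_data [1, 2], B -> wi_data [4, 1].
  assert(dpasOperandWiData(8, 0).d1 == 2 && dpasOperandWiData(8, 1).d0 == 4);
  return 0;
}
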
>From 462b3c09d15254f9d6dbdbe4c11aef9761e56aea Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 7 Mar 2025 04:00:35 +0000
Subject: [PATCH 09/27] save work
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 77 +++-
.../XeGPU/subgroup-map-propagation.mlir | 341 ++++++++++++------
2 files changed, 296 insertions(+), 122 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 13030d5c74ea0..de65f285d75b4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -134,6 +134,7 @@ SGMap SGMap::join(const SGMap &lhs, const SGMap &rhs) {
llvm_unreachable("Join should not be triggered by SGMapPropagation.");
}
+/// Get the transposed layout according to the given permutation.
SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
if (!isAssigned())
return {};
@@ -180,6 +181,11 @@ static SGMap getDefaultSgMap(Type ty) {
return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
}
+/// Helper Function to get the default layout representing constants.
+static SGMap getDefaultSgMap() {
+ return SGMap(WiLayout({1, subgroupSize}), WiData({1, 1}));
+}
+
///===----------------------------------------------------------------------===///
/// SGMapPropagation
///===----------------------------------------------------------------------===///
@@ -216,6 +222,14 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results);
+ void visitCreateNdDescOp(xegpu::CreateNdDescOp createNdDesc,
+ ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results);
+
+ void visitCreateDescOp(xegpu::CreateDescOp createDesc,
+ ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results);
+
public:
SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
: SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
@@ -254,6 +268,10 @@ SGMapPropagation::visitOperation(Operation *op,
visitVectorBitcastOp(bitcast, operands, results);
else if (auto loadGather = dyn_cast<xegpu::LoadGatherOp>(op))
visitLoadGatherOp(loadGather, operands, results);
+ else if (auto createNdDesc = dyn_cast<xegpu::CreateNdDescOp>(op))
+ visitCreateNdDescOp(createNdDesc, operands, results);
+ else if (auto createDesc = dyn_cast<xegpu::CreateDescOp>(op))
+ visitCreateDescOp(createDesc, operands, results);
else if (auto storeScatter = dyn_cast<xegpu::StoreScatterOp>(op))
visitStoreScatterOp(storeScatter, operands, results);
/// All other ops
@@ -318,8 +336,7 @@ void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
/// this effect is not expected and should be abstracted away. Emit a warning.
/// TODO: Handle this case properly when `order` is introduced in the sg_map.
if (auto transpose = load.getTranspose()) {
- emitWarning(load.getLoc())
- << "Transpose effect is not expected for LoadNdOp";
+ load.emitWarning("Transpose effect is not expected for LoadNdOp");
tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
}
/// Propagate the new layout to the tensor descriptor operand.
@@ -394,21 +411,53 @@ void SGMapPropagation::visitLoadGatherOp(
/// a warning.
/// TODO: Handle this case properly when `order` is introduced in the
/// sg_map.
- emitWarning(load.getLoc())
- << "Transpose effect is not expected for LoadGatherOp";
+ load.emitWarning("Transpose effect is not expected for LoadGatherOp");
tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
} else
tensorDescLayout = valueLayout;
/// Mask operand should have the same layout as the value but with wi_data =
/// [1, 1]
- SGMap maskLayout = getDefaultSgMap(load.getTensorDescType().getElementType());
+ SGMap maskLayout = SGMap(valueLayout.getLayout(), WiData({1, 1}));
/// Propagate the new layout to the tensor descriptor operand.
propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
/// Propagate the new layout to the mask operand.
propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
}
-/// Set the layout for the value, tensor descriptor, and mask operands in
+/// Propagate the layout of the descriptor to the operands in CreateNdDescOp.
+void SGMapPropagation::visitCreateNdDescOp(
+ xegpu::CreateNdDescOp createNdDesc, ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results) {
+ auto descLayout = results[0]->getValue();
+ /// Need the layout of the descriptor to propagate to the operands.
+ if (!descLayout.isAssigned())
+ return;
+ /// Propagate the layout to the source operand.
+ propagateIfChanged(operands[0], operands[0]->meet(descLayout));
+ /// For all other operands propagate the default layout.
+ SGMap layout = getDefaultSgMap();
+ for (size_t i = 1; i < operands.size(); ++i) {
+ propagateIfChanged(operands[i], operands[i]->meet(layout));
+ }
+}
+
+/// Propagate the layout of the descriptor to the source and offset operands in
+/// CreateDescOp.
+void SGMapPropagation::visitCreateDescOp(
+ xegpu::CreateDescOp createDesc, ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results) {
+ auto descLayout = results[0]->getValue();
+ /// Need the layout of the descriptor to propagate to the operands.
+ if (!descLayout.isAssigned())
+ return;
+ /// Propagate the layout to the source operand.
+ propagateIfChanged(operands[0], operands[0]->meet(descLayout));
+ /// For offset operand propagate the default layout.
+ SGMap layout = getDefaultSgMap();
+ propagateIfChanged(operands[1], operands[1]->meet(layout));
+}
+
+/// Set the layout for the value, tensor descriptor, and mask operands in the
/// StoreScatterOp.
void SGMapPropagation::visitStoreScatterOp(
xegpu::StoreScatterOp storeScatter, ArrayRef<SGMapLattice *> operands,
@@ -422,8 +471,8 @@ void SGMapPropagation::visitStoreScatterOp(
/// a warning.
/// TODO: Handle this case properly when `order` is introduced in the
/// sg_map.
- emitWarning(storeScatter.getLoc())
- << "Transpose effect is not expected for StoreScatterOp";
+ storeScatter.emitWarning(
+ "Transpose effect is not expected for StoreScatterOp");
storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
} else
storeScatterLayout = valueLayout;
@@ -431,8 +480,9 @@ void SGMapPropagation::visitStoreScatterOp(
propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
/// Propagate the tensor descriptor layout.
propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
- /// Propagate the mask layout.
- propagateIfChanged(operands[2], operands[2]->meet(valueLayout));
+ /// Use default layout for mask operand.
+ auto maskLayout = getDefaultSgMap();
+ propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
}
namespace {
@@ -517,10 +567,9 @@ struct XeGPUSubgroupDistributePass final
}
void runOnOperation() override;
/// Print sg map propagation analysis result and exit for testing purposes.
- Option<bool> printOnly{
- *this, "print-only", llvm::cl::init(false),
- llvm::cl::desc(
- "Print the result of the subgroup map propagation analysis")};
+ Option<bool> printOnly{*this, "print-analysis-only", llvm::cl::init(false),
+ llvm::cl::desc("Print the result of the subgroup map "
+ "propagation analysis and exit.")};
};
} // namespace
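
The re-added CreateNdDescOp and CreateDescOp visitors above propagate the descriptor's layout only to the source operand and hand getDefaultSgMap() to the offsets (and, in StoreScatterOp, to the mask). Below is a trivial standalone sketch of that default, assuming subgroupSize is 16 as the [1, 16] wi_layouts in the test expectations suggest.

#include <array>
#include <cassert>

constexpr int kSubgroupSize = 16; // assumed; matches the [1, 16] wi_layouts

struct SgMapSketch {
  std::array<int, 2> wiLayout;
  std::array<int, 2> wiData;
};

// Default layout handed to constants, offset vectors, and masks:
// wi_layout = [1, subgroupSize], wi_data = [1, 1].
SgMapSketch defaultSgMap() { return {{1, kSubgroupSize}, {1, 1}}; }

int main() {
  SgMapSketch m = defaultSgMap();
  assert(m.wiLayout[0] == 1 && m.wiLayout[1] == kSubgroupSize);
  assert(m.wiData[0] == 1 && m.wiData[1] == 1);
  return 0;
}
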
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 871e6e2f9139e..c75931cd6a37b 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -1,57 +1,149 @@
-func.func @test_dpas_op_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+// RUN: mlir-opt -xegpu-subgroup-distribute='print-analysis-only=true' -split-input-file %s | FileCheck %s
+
+// CHECK: function: test_dpas_op_1:
+// CHECK: op : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %{{.*}} = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T2:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T3:.*]] = xegpu.load_nd %[[T1]] : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %{{.*}} = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_dpas_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
- %cst = arith.constant dense<0.0> : vector<8x16xf32>
- %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
- %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
- %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
- %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+ %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+ %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
%4 = xegpu.dpas %2, %3, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
- %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
- xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
return
}
+
// -----
-func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %c : memref<8x16xi32>) {
+// CHECK: function: test_dpas_op_2:
+// CHECK-NEXT: argument: <block argument> of type 'vector<8x32xi8>' at index: 0
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 2]
+// CHECK-NEXT: argument: <block argument> of type 'vector<32x16xi8>' at index: 1
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.dpas %{{.*}} : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
%c0 = arith.constant 0 : index
%0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
- %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
- xegpu.store_nd %0, %1 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
+ %1 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+ xegpu.store_nd %0, %1 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
return
}
// -----
-func.func @test_transpose_op_1(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+// CHECK: function: test_transpose_op_1:
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map : wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T2:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T3:.*]] = xegpu.load_nd %[[T1]] <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T5:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
- %cst = arith.constant dense<0.0> : vector<8x16xf32>
- %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
- %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
- %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
- %3 = xegpu.load_nd %1 {transpose = array<i64: 1, 0>} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+ %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+ %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %3 = xegpu.load_nd %1 <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
%4 = xegpu.dpas %2, %3, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
- %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
- xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
return
}
// -----
-func.func @test_transpose_op_2(%a: memref<8x16xf16>, %b : memref<16x16xf16>, %c : memref<8x16xf32>) {
+// CHECK: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map : wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T2:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T3:.*]] = xegpu.load_nd %[[T1]] : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T4:.*]] = vector.transpose %[[T3]], [1, 0] : vector<16x16xf16> to vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T5:.*]] = xegpu.dpas %[[T2]], %[[T4]], %[[CST]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_transpose_op_2(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
- %cst = arith.constant dense<0.0> : vector<8x16xf32>
- %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
- %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
- %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
- %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
- %6 = vector.transpose %3, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
- %4 = xegpu.dpas %2, %6, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
- %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
- xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+ %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+ %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %4 = vector.transpose %3, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
+ %5 = xegpu.dpas %2, %4, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %6 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ xegpu.store_nd %5, %6 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
return
}
-
-
// -----
+// CHECK: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T1:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T2:.*]] = arith.extf %[[T1]] : vector<16x16xf16> to vector<16x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T3:.*]] = arith.truncf %[[T2]] : vector<16x16xf32> to vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T4:.*]] = xegpu.dpas %[[T0]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: Not assigned.
func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -62,75 +154,112 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
}
// -----
-func.func @test_load_gather_op_1(%a: memref<8x16xf16>, %b : memref<256xf16>, %c : memref<8x16xf32>) {
+// CHECK: function: test_load_gather_op_1:
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<256xf16>' at index: 1
+// CHECK-NEXT: sg_map : wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T1:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T2:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T3:.*]] = xegpu.load %[[T2]], %[[CST0]] <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T4:.*]] = xegpu.dpas %[[T1]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T5:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_load_gather_op_1(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
- %6 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
- %7 = xegpu.load_nd %6 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
- %0 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
- %1 = arith.constant dense<1>: vector<16xi1>
- %2 = xegpu.create_tdesc %b, %0 : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16>>
- %3 = xegpu.load %2, %1 <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16>>, vector<16xi1> -> vector<16x16xf16>
- %4 = xegpu.dpas %7, %3 : vector<8x16xf16>, vector<16x16xf16>-> vector<8x16xf32>
- %5 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
- xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+ %cst_0 = arith.constant dense<true> : vector<16xi1>
+ %2 = xegpu.create_tdesc %arg1, %cst : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
+ %3 = xegpu.load %2, %cst_0 <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
+ %4 = xegpu.dpas %1, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
return
}
// -----
-func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>){
- %0 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
- %1 = arith.constant dense<1>: vector<16xi1>
- %2 = xegpu.create_tdesc %arg0, %0 : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>
- %3 = xegpu.load %2, %1 : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>, vector<16xi1> -> vector<16xf32>
- xegpu.store_nd %3, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+// CHECK: argument: <block argument> of type 'memref<256xf32>' at index: 0
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T1:.*]] = xegpu.load %[[T0]], %[[CST0]] : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+ %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+ %cst_0 = arith.constant dense<true> : vector<16xi1>
+ %0 = xegpu.create_tdesc %arg0, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ %1 = xegpu.load %0, %cst_0 : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+ xegpu.store_nd %1, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
return
}
// -----
-func.func @test_store_scatter_op_1(%c: memref<128xf32>){
+func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
%cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
- %1 = arith.constant dense<1> : vector<16xi1>
- %2 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
- %3 = xegpu.create_tdesc %c, %2 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
- xegpu.store %cst, %3, %1 <{transpose}> : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1>
+ %cst_0 = arith.constant dense<true> : vector<16xi1>
+ %cst_1 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+ %0 = xegpu.create_tdesc %arg0, %cst_1 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
+ xegpu.store %cst, %0, %cst_0 <{transpose}> : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<16xi1>
return
}
// -----
-func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>){
- %0 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
- %1 = arith.constant dense<1>: vector<16xi1>
- %2 = xegpu.create_tdesc %arg1, %0 : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>
- xegpu.store %arg0, %2, %1 : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 1>>, vector<16xi1>
+func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
+ %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+ %cst_0 = arith.constant dense<true> : vector<16xi1>
+ %0 = xegpu.create_tdesc %arg1, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ xegpu.store %arg0, %0, %cst_0 : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
return
}
// -----
-func.func @test_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %c : memref<8x16xi32>) {
+func.func @test_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %arg2: memref<8x16xi32>) {
%c0 = arith.constant 0 : index
- %2 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
- %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
- %4 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
- %5 = xegpu.load_nd %3 : !xegpu.tensor_desc<32x16xi8> -> vector<32x16xi8>
- %6 = vector.bitcast %4 : vector<8x16xi16> to vector<8x32xi8>
- %0 = xegpu.dpas %6, %5 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
- %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
- xegpu.store_nd %0, %1 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
+ %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
+ %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
+ %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
+ %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<32x16xi8> -> vector<32x16xi8>
+ %4 = vector.bitcast %2 : vector<8x16xi16> to vector<8x32xi8>
+ %5 = xegpu.dpas %4, %3 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+ %6 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+ xegpu.store_nd %5, %6 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
return
}
// -----
-func.func @test_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %c : memref<8x16xf32>) {
+func.func @test_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
- %2 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
- %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
- %4 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x32xi8> -> vector<8x32xi8>
- %5 = xegpu.load_nd %3 : !xegpu.tensor_desc<16x32xi8> -> vector<16x32xi8>
- %6 = vector.bitcast %4 : vector<8x32xi8> to vector<8x16xf16>
- %7 = vector.bitcast %5 : vector<16x32xi8> to vector<16x16xf16>
- %0 = xegpu.dpas %6, %7 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- %1 = xegpu.create_nd_tdesc %c [%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
- xegpu.store_nd %0, %1 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
+ %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
+ %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x32xi8> -> vector<8x32xi8>
+ %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<16x32xi8> -> vector<16x32xi8>
+ %4 = vector.bitcast %2 : vector<8x32xi8> to vector<8x16xf16>
+ %5 = vector.bitcast %3 : vector<16x32xi8> to vector<16x16xf16>
+ %6 = xegpu.dpas %4, %5 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ %7 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ xegpu.store_nd %6, %7 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
return
}
@@ -141,7 +270,7 @@ func.func @test_binary_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
%2 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
%3 = arith.addf %1, %2 : vector<16x16xf16>
%4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- xegpu.store_nd %4, %arg2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ xegpu.store_nd %4, %arg2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
return
}
@@ -152,34 +281,34 @@ func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
%cst = arith.constant dense<1.000000e+00> : vector<16x16xf16>
%2 = arith.addf %1, %cst : vector<16x16xf16>
%3 = xegpu.dpas %0, %2 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- xegpu.store_nd %3, %arg2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
- xegpu.store_nd %2, %arg3 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+ xegpu.store_nd %3, %arg2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ xegpu.store_nd %2, %arg3 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
return
}
// -----
-func.func @test_for_op(%a: memref<8x128xf16>, %b : memref<128x16xf16>, %c : memref<8x16xf32>) {
+func.func @test_for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%c16 = arith.constant 16 : index
- %0 = xegpu.create_nd_tdesc %a[%c0, %c0] : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
- %1 = xegpu.create_nd_tdesc %b[%c0, %c0] : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
- %2 = arith.constant dense<0.0> : vector<8x16xf32>
- %3:3 = scf.for %k = %c0 to %c128 step %c16 iter_args(%arg0 = %0, %arg1 = %1, %arg2 = %2) -> (!xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>) {
- %4 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
- %5 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
- %6 = xegpu.dpas %4, %5, %arg2 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
- %7 = xegpu.update_nd_offset %arg0, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16>
- %8 = xegpu.update_nd_offset %arg1, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16>
+ %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+ %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+ %2:3 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %0, %arg5 = %1, %arg6 = %cst) -> (!xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>) {
+ %4 = xegpu.load_nd %arg4 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %5 = xegpu.load_nd %arg5 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %6 = xegpu.dpas %4, %5, %arg6 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %7 = xegpu.update_nd_offset %arg4, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16>
+ %8 = xegpu.update_nd_offset %arg5, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16>
scf.yield %7, %8, %6 : !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>
}
- %9 = xegpu.create_nd_tdesc %c[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
- xegpu.store_nd %3#2, %9 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ %3 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ xegpu.store_nd %2#2, %3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
return
}
// -----
-func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>){
+func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>) {
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = scf.if %arg2 -> (vector<16x16xf16>) {
%3 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -189,12 +318,12 @@ func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
scf.yield %3 : vector<16x16xf16>
}
%2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- xegpu.store_nd %2, %arg3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ xegpu.store_nd %2, %arg3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
return
}
// -----
-func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>){
+func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>) {
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = scf.if %arg2 -> (vector<16x16xf16>) {
%3 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -204,13 +333,13 @@ func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
scf.yield %3 : vector<16x16xf16>
}
%2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- xegpu.store_nd %2, %arg3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
- xegpu.store_nd %1, %arg4 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+ xegpu.store_nd %2, %arg3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ xegpu.store_nd %1, %arg4 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
return
}
// -----
-func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1, %arg4: !xegpu.tensor_desc<8x16xf32>){
+func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1, %arg4: !xegpu.tensor_desc<8x16xf32>) {
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = scf.if %arg3 -> (vector<16x16xf16>) {
%3 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -219,26 +348,22 @@ func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
scf.yield %arg2 : vector<16x16xf16>
}
%2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- xegpu.store_nd %2, %arg4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ xegpu.store_nd %2, %arg4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
return
}
// -----
-
func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
- %0 = arith.constant dense<0.000000e+00> : vector<16xf32>
- %1 = vector.multi_reduction <add>, %arg0, %0 [0] : vector<16x16xf32> to vector<16xf32>
- xegpu.store_nd %1, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+ %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
+ %0 = vector.multi_reduction <add>, %arg0, %cst [0] : vector<16x16xf32> to vector<16xf32>
+ xegpu.store_nd %0, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
return
}
-
-
// -----
-
func.func @test_vector_reduction_2(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
- %0 = arith.constant dense<0.000000e+00> : vector<16xf32>
- %1 = vector.multi_reduction <add>, %arg0, %0 [1] : vector<16x16xf32> to vector<16xf32>
- xegpu.store_nd %1, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+ %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
+ %0 = vector.multi_reduction <add>, %arg0, %cst [1] : vector<16x16xf32> to vector<16xf32>
+ xegpu.store_nd %0, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
return
}
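[Aside for readers of the tests above, not part of the patch: the `wi_layout`/`wi_data` pairs printed by the analysis describe how a subgroup-level 2-D vector is split across the 16 work items — each lane owns `shape[i] / wi_layout[i]` elements along dimension `i`, accessed in contiguous chunks of `wi_data[i]`. A minimal, compilable C++ sketch of that bookkeeping follows; the helper name `perWorkItemShape` and the even-divisibility assumption are illustrative and not part of the pass.]

// Minimal sketch (not from the patch): per-work-item fragment shape implied
// by an sg_map, assuming the vector shape divides the layout evenly.
#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>

using Shape2D = std::array<int64_t, 2>;

// wiLayout: how the 16 work items are arranged over the 2-D vector.
// wiData:   how many contiguous elements each work item accesses per dim.
static Shape2D perWorkItemShape(Shape2D vecShape, Shape2D wiLayout,
                                Shape2D wiData) {
  Shape2D frag;
  for (int i = 0; i < 2; ++i) {
    assert(vecShape[i] % wiLayout[i] == 0 && "shape must divide wi_layout");
    frag[i] = vecShape[i] / wiLayout[i];
    assert(frag[i] % wiData[i] == 0 && "fragment must be a multiple of wi_data");
  }
  return frag;
}

int main() {
  // vector<8x16xf32> with wi_layout [1, 16], wi_data [1, 1] -> 8x1 per lane.
  Shape2D a = perWorkItemShape({8, 16}, {1, 16}, {1, 1});
  // vector<16x16xf16> with wi_layout [1, 16], wi_data [2, 1] -> 16x1 per lane.
  Shape2D b = perWorkItemShape({16, 16}, {1, 16}, {2, 1});
  std::printf("%lldx%lld and %lldx%lld\n", (long long)a[0], (long long)a[1],
              (long long)b[0], (long long)b[1]);
  return 0;
}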
>From 27f011fbedc582bb85f3555bf38b00247906a6bf Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 7 Mar 2025 04:35:52 +0000
Subject: [PATCH 10/27] save work
---
.../XeGPU/subgroup-map-propagation.mlir | 138 +++++++++++++++++-
1 file changed, 136 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index c75931cd6a37b..d9d1ac80169ea 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -216,6 +216,16 @@ func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_de
}
// -----
+// CHECK: argument: <block argument> of type 'memref<128xf32>' at index: 0
+// CHECK-NEXT: sg_map : wi_layout: [16, 1], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[CST1:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST1]] : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 1]
func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
%cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
%cst_0 = arith.constant dense<true> : vector<16xi1>
@@ -226,6 +236,16 @@ func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
}
// -----
+// CHECK: argument: <block argument> of type 'vector<16xf32>' at index: 0
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<256xf32>' at index: 1
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[CST1:.*]] = arith.constant dense<true> : vector<16xi1>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
%cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
%cst_0 = arith.constant dense<true> : vector<16xi1>
@@ -235,7 +255,29 @@ func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>
}
// -----
-func.func @test_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %arg2: memref<8x16xi32>) {
+// CHECK: argument: <block argument> of type 'memref<8x16xi16>' at index: 0
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<32x16xi8>' at index: 1
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: op : %[[T2:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T3:.*]] = xegpu.load_nd %[[T1]] : !xegpu.tensor_desc<32x16xi8> -> vector<32x16xi8>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: op : %[[T4:.*]] = vector.bitcast %[[T2]] : vector<8x16xi16> to vector<8x32xi8>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 2]
+// CHECK-NEXT: op : %[[T5:.*]] = xegpu.dpas %[[T4]], %[[T3]] : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_vector_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %arg2: memref<8x16xi32>) {
%c0 = arith.constant 0 : index
%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
%1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
@@ -249,7 +291,31 @@ func.func @test_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %
}
// -----
-func.func @test_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %arg2: memref<8x16xf32>) {
+// CHECK: argument: <block argument> of type 'memref<8x32xi8>' at index: 0
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 2]
+// CHECK-NEXT: argument: <block argument> of type 'memref<16x32xi8>' at index: 1
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 2]
+// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: op : %[[T2:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<8x32xi8> -> vector<8x32xi8>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 2]
+// CHECK-NEXT: op : %[[T3:.*]] = xegpu.load_nd %[[T1]] : !xegpu.tensor_desc<16x32xi8> -> vector<16x32xi8>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: op : %[[T4:.*]] = vector.bitcast %[[T2]] : vector<8x32xi8> to vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T5:.*]] = vector.bitcast %[[T3]] : vector<16x32xi8> to vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T7:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_vector_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
%1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
@@ -264,6 +330,22 @@ func.func @test_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %a
}
// -----
+// CHECK: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T1:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T2:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T3:.*]] = arith.addf %[[T1]], %[[T2]] : vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T4:.*]] = xegpu.dpas %[[T0]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
func.func @test_binary_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>) {
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -275,6 +357,24 @@ func.func @test_binary_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
}
// -----
+// CHECK: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 3
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T1:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T2:.*]] = arith.addf %[[T1]], %[[CST]] : vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T3:.*]] = xegpu.dpas %[[T0]], %[[T2]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>, %arg3: !xegpu.tensor_desc<16x16xf16>) {
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -287,6 +387,40 @@ func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
}
// -----
+// CHECK: argument: <block argument> of type 'memref<8x128xf16>' at index: 0
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<128x16xf16>' at index: 1
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %{{.*}} = arith.constant 128 : index
+// CHECK-NEXT: sg_map for result #0: Not assigned.
+// CHECK-NEXT: op : %{{.*}} = arith.constant 16 : index
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T4:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T5:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T7:.*]] = xegpu.update_nd_offset %{{.*}} : !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T8:.*]] = xegpu.update_nd_offset %{{.8}} : !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : scf.for
+// CHECK-NEXT: sg_map for result #0: Not assigned.
+// CHECK-NEXT: sg_map for result #1: Not assigned.
+// CHECK-NEXT: sg_map for result #2: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T3:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
func.func @test_for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
>From 898b6195b8676a0e3db141cc5e6077c22d0c1ec3 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 7 Mar 2025 04:45:17 +0000
Subject: [PATCH 11/27] save work
---
.../XeGPU/subgroup-map-propagation.mlir | 78 ++++++++++++++++++-
1 file changed, 77 insertions(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index d9d1ac80169ea..83a17dd9d1a7e 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -413,7 +413,7 @@ func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T7:.*]] = xegpu.update_nd_offset %{{.*}} : !xegpu.tensor_desc<8x16xf16>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op : %[[T8:.*]] = xegpu.update_nd_offset %{{.8}} : !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: op : %[[T8:.*]] = xegpu.update_nd_offset %{{.*}} : !xegpu.tensor_desc<16x16xf16>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
// CHECK-NEXT: op : scf.for
// CHECK-NEXT: sg_map for result #0: Not assigned.
@@ -442,6 +442,25 @@ func.func @test_for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg
}
// -----
+// CHECK: function: test_if_op_1:
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type 'i1' at index: 2
+// CHECK-NEXT: sg_map : Not assigned.
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 3
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T3:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T4:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : scf.if
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>) {
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = scf.if %arg2 -> (vector<16x16xf16>) {
@@ -457,6 +476,27 @@ func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
}
// -----
+// CHECK: function: test_if_op_2:
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type 'i1' at index: 2
+// CHECK-NEXT: sg_map : Not assigned.
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 3
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 4
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T3:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T4:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : scf.if
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>) {
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = scf.if %arg2 -> (vector<16x16xf16>) {
@@ -473,6 +513,25 @@ func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
}
// -----
+// CHECK: function: test_if_op_3:
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf16>' at index: 2
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: argument: <block argument> of type 'i1' at index: 3
+// CHECK-NEXT: sg_map : Not assigned.
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 4
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T3:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : scf.if
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1, %arg4: !xegpu.tensor_desc<8x16xf32>) {
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = scf.if %arg3 -> (vector<16x16xf16>) {
@@ -487,6 +546,15 @@ func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
}
// -----
+// CHECK: function: test_vector_reduction_1:
+// CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [0] : vector<16x16xf32> to vector<16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
%cst = arith.constant dense<0.000000e+00> : vector<16xf32>
%0 = vector.multi_reduction <add>, %arg0, %cst [0] : vector<16x16xf32> to vector<16xf32>
@@ -495,6 +563,14 @@ func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tenso
}
// -----
+// CHECK: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [1] : vector<16x16xf32> to vector<16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
func.func @test_vector_reduction_2(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
%cst = arith.constant dense<0.000000e+00> : vector<16xf32>
%0 = vector.multi_reduction <add>, %arg0, %cst [1] : vector<16x16xf32> to vector<16xf32>
>From 83404020a7459ca2c99cfb8e15027c2f31968213 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 7 Mar 2025 19:55:40 +0000
Subject: [PATCH 12/27] save work
---
.../Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index de65f285d75b4..aba0586f30cfd 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -238,13 +238,13 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
LogicalResult visitOperation(Operation *op, ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results) override;
- void visitBranchOperand(OpOperand &operand) override{};
+ void visitBranchOperand(OpOperand &operand) override {};
- void visitCallOperand(OpOperand &operand) override{};
+ void visitCallOperand(OpOperand &operand) override {};
void visitExternalCall(CallOpInterface call,
ArrayRef<SGMapLattice *> operands,
- ArrayRef<const SGMapLattice *> results) override{};
+ ArrayRef<const SGMapLattice *> results) override {};
void setToExitState(SGMapLattice *lattice) override {
(void)lattice->meet(SGMap());
>From 549e8781d8899f7056a19548a3c6848639be8587 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 7 Mar 2025 20:07:39 +0000
Subject: [PATCH 13/27] save work
---
mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index aba0586f30cfd..d4618bf0e24d6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -144,7 +144,7 @@ SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
newLayout.layout.push_back(layout.layout[idx]);
newData.layout.push_back(data.layout[idx]);
}
- return SGMap(newLayout, data);
+ return SGMap(newLayout, newData);
}
///===----------------------------------------------------------------------===///
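[Aside, not part of the patch: the one-line change above is the important part of PATCH 13 — getTransposedLayout must permute the wi_data array together with the wi_layout array, otherwise a transposed map keeps the untransposed data shape. A standalone sketch of the corrected behavior follows; `SgMapSketch` and `getTransposed` are illustrative stand-ins, not the pass's actual class or API.]

// Illustrative sketch of the corrected behavior, not the pass's real SGMap.
#include <array>
#include <cstdint>

struct SgMapSketch {
  std::array<int64_t, 2> layout; // wi_layout
  std::array<int64_t, 2> data;   // wi_data
};

static SgMapSketch getTransposed(const SgMapSketch &m,
                                 std::array<int, 2> permutation) {
  SgMapSketch t{};
  for (int i = 0; i < 2; ++i) {
    t.layout[i] = m.layout[permutation[i]];
    // Permuting only `layout` (the pre-fix behavior) would leave `data`
    // describing the old orientation; both arrays must be permuted.
    t.data[i] = m.data[permutation[i]];
  }
  return t;
}

// e.g. {wi_layout [1, 16], wi_data [2, 1]} transposed by [1, 0]
//      becomes {wi_layout [16, 1], wi_data [1, 2]}.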
>From abca365b5bf0851b8a31885395ce576df15b1bf3 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 7 Mar 2025 20:12:56 +0000
Subject: [PATCH 14/27] save work
---
.../Dialect/XeGPU/subgroup-map-propagation.mlir | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 83a17dd9d1a7e..92e37a20bf555 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -58,7 +58,7 @@ func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map : wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: sg_map : wi_layout: [16, 1], wi_data: [1, 2]
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
@@ -68,7 +68,7 @@ func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2]
// CHECK-NEXT: op : %[[T2:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T3:.*]] = xegpu.load_nd %[[T1]] <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -94,7 +94,7 @@ func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
// CHECK: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map : wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: sg_map : wi_layout: [16, 1], wi_data: [1, 2]
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
@@ -104,11 +104,11 @@ func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2]
// CHECK-NEXT: op : %[[T2:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T3:.*]] = xegpu.load_nd %[[T1]] : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2]
// CHECK-NEXT: op : %[[T4:.*]] = vector.transpose %[[T3]], [1, 0] : vector<16x16xf16> to vector<16x16xf16>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
// CHECK-NEXT: op : %[[T5:.*]] = xegpu.dpas %[[T2]], %[[T4]], %[[CST]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
@@ -158,7 +158,7 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: argument: <block argument> of type 'memref<256xf16>' at index: 1
-// CHECK-NEXT: sg_map : wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: sg_map : wi_layout: [16, 1], wi_data: [1, 2]
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
@@ -172,7 +172,7 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
// CHECK-NEXT: op : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T2:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [2, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2]
// CHECK-NEXT: op : %[[T3:.*]] = xegpu.load %[[T2]], %[[CST0]] <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
// CHECK-NEXT: op : %[[T4:.*]] = xegpu.dpas %[[T1]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
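A note on reading these expectations: under the usual sg_map interpretation (assumed here, not restated in this hunk), wi_layout describes how the 16 work items of a subgroup tile the vector or descriptor shape and wi_data gives the contiguous elements each work item touches per access, with the pattern repeating until the whole shape is covered. A small standalone sketch of that arithmetic for the transposed B operand above (16x16xf16 with wi_layout [16, 1], wi_data [1, 2]):

  // Standalone sketch of the per-work-item ownership implied by an sg_map
  // (assumed semantics; names are illustrative, not taken from the patch).
  #include <array>
  #include <cstdio>

  int main() {
    std::array<int, 2> shape = {16, 16};    // tensor_desc<16x16xf16>
    std::array<int, 2> wiLayout = {16, 1};  // 16 work items along dim 0
    std::array<int, 2> wiData = {1, 2};     // each reads a 1x2 (32-bit) chunk

    // One distribution step covers wiLayout * wiData elements per dimension;
    // the shape is tiled by that unit until fully covered.
    int unitRows = wiLayout[0] * wiData[0];                         // 16
    int unitCols = wiLayout[1] * wiData[1];                         // 2
    int stepsPerWi = (shape[0] / unitRows) * (shape[1] / unitCols); // 8
    int elemsPerWi = stepsPerWi * wiData[0] * wiData[1];            // 16

    printf("accesses per work item: %d, f16 elements per work item: %d\n",
           stepsPerWi, elemsPerWi);
    return 0;
  }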
>From 5bd5db50d8bdd6c3802e270741b9e0e1d744bd34 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 7 Mar 2025 23:14:23 +0000
Subject: [PATCH 15/27] save work
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 113 ++++++++++++++----
.../XeGPU/subgroup-map-propagation.mlir | 60 +++++-----
2 files changed, 117 insertions(+), 56 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d4618bf0e24d6..7136993d0e968 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/IR/Builders.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
@@ -157,6 +158,8 @@ struct SGMapLattice : public Lattice<SGMap> {
using Lattice::Lattice;
};
+/// Helper Functions
+
/// Helper Function to get the expected layouts for DPAS operands.
static SGMap getSGMapForDPASOperand(Type operandTy, unsigned operandNum) {
int packingFactorForB = packedBSizeInBits / operandTy.getIntOrFloatBitWidth();
@@ -174,16 +177,33 @@ static SGMap getSGMapForDPASOperand(Type operandTy, unsigned operandNum) {
/// However, the minimum granularity of data access per work item is 16 bits.
/// So, if the bitwidth of the type is less than 16, we need to pack the data to
/// 16 bits.
-static SGMap getDefaultSgMap(Type ty) {
- int packingFactor = 1;
- if (ty.getIntOrFloatBitWidth() < packedASizeInBits)
- packingFactor = packedBSizeInBits / ty.getIntOrFloatBitWidth();
- return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
+// static SGMap getDefaultSgMap(Type ty, unsigned rank) {
+// int packingFactor = 1;
+// if (ty.getIntOrFloatBitWidth() < packedASizeInBits)
+// packingFactor = packedBSizeInBits / ty.getIntOrFloatBitWidth();
+// return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
+// }
+
+/// Helper Function to get the default layout for uniform values like constants.
+static SGMap getDefaultSgMap(unsigned rank) {
+ assert((rank == 1 || rank == 2) && "Expected 1D or 2D layout.");
+ if (rank == 1)
+ return SGMap(WiLayout({subgroupSize}), WiData({1}));
+ return SGMap(WiLayout({1, subgroupSize}), WiData({1, 1}));
}
-/// Helper Function to get the default layout representing constants.
-static SGMap getDefaultSgMap() {
- return SGMap(WiLayout({1, subgroupSize}), WiData({1, 1}));
+static SGMap getDefaultSgMap(VectorType vectorTy) {
+ /// Expecting a 1D or 2D vector.
+ assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
+ "Expected 1D or 2D vector.");
+ /// If the rank is 1, then return default layout for 1D vector.
+ if (vectorTy.getRank() == 1)
+ return getDefaultSgMap(1);
+ int packingFactor = 1;
+ auto bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
+ if (bitwidth < packedASizeInBits)
+ packingFactor = packedBSizeInBits / bitwidth;
+ return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
}
///===----------------------------------------------------------------------===///
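To make the defaults above concrete: a 1D value gets wi_layout [16] with wi_data [1], and a 2D value gets wi_layout [1, 16] with wi_data packed only when the element type is narrower than 16 bits. The sketch below mirrors that selection with assumed constants (subgroupSize = 16, packedASizeInBits = 16, packedBSizeInBits = 32, which is what the DPAS packing factors in the tests suggest):

  // Standalone sketch of the default sg_map selection (assumed constants;
  // the real helpers live in XeGPUSubgroupDistribute.cpp).
  #include <cstdio>
  #include <vector>

  constexpr int subgroupSize = 16;
  constexpr int packedASizeInBits = 16; // minimum per-work-item access width
  constexpr int packedBSizeInBits = 32;

  struct SGMapSketch {
    std::vector<int> wiLayout, wiData;
  };

  SGMapSketch defaultSgMap(int rank, int elemBitwidth) {
    if (rank == 1)
      return {{subgroupSize}, {1}};
    // Sub-16-bit element types are packed, mirroring the helper above.
    int packing = elemBitwidth < packedASizeInBits
                      ? packedBSizeInBits / elemBitwidth
                      : 1;
    return {{1, subgroupSize}, {1, packing}};
  }

  int main() {
    auto m1 = defaultSgMap(1, 32); // e.g. vector<16xf32>   -> [16] / [1]
    auto m2 = defaultSgMap(2, 16); // e.g. vector<8x16xf16> -> [1, 16] / [1, 1]
    printf("[%d] / [%d]\n", m1.wiLayout[0], m1.wiData[0]);
    printf("[%d, %d] / [%d, %d]\n", m2.wiLayout[0], m2.wiLayout[1],
           m2.wiData[0], m2.wiData[1]);
    return 0;
  }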
@@ -230,6 +250,14 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results);
+ void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
+ ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results);
+
+ void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
+ ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results);
+
public:
SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
: SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
@@ -274,6 +302,10 @@ SGMapPropagation::visitOperation(Operation *op,
visitCreateDescOp(createDesc, operands, results);
else if (auto storeScatter = dyn_cast<xegpu::StoreScatterOp>(op))
visitStoreScatterOp(storeScatter, operands, results);
+ else if (auto updateNdOffset = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
+ visitUpdateNdOffsetOp(updateNdOffset, operands, results);
+ else if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op))
+ visitVectorMultiReductionOp(reduction, operands, results);
/// All other ops
else {
for (const SGMapLattice *r : results) {
@@ -293,6 +325,43 @@ SGMapPropagation::visitOperation(Operation *op,
return success();
}
+void SGMapPropagation::visitVectorMultiReductionOp(
+ vector::MultiDimReductionOp reduction, ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results) {
+ /// The layout of the result must be present.
+ auto resultLayout = results[0]->getValue();
+ if (!resultLayout.isAssigned())
+ return;
+ /// We only consider 2D -> 1D reductions at this point.
+ assert(resultLayout.getLayout().size() == 1 &&
+ "Expected 1D layout for reduction result.");
+ /// Given that the result is 1D, the layout of the operand should be 2D with
+ /// default layout.
+ auto operandLayout = getDefaultSgMap(2);
+ operandLayout.print(llvm::outs());
+ propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
+ /// Accumulator should have the same layout as the result.
+ propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
+}
+
+/// Propagate the layout of the result tensor descriptor to the source tensor
+/// descriptor in UpdateNdOffsetOp.
+void SGMapPropagation::visitUpdateNdOffsetOp(
+ xegpu::UpdateNdOffsetOp updateNdOffset, ArrayRef<SGMapLattice *> operands,
+ ArrayRef<const SGMapLattice *> results) {
+ /// The layout of the result must be present.
+ auto resultLayout = results[0]->getValue();
+ if (!resultLayout.isAssigned())
+ return;
+ /// Propagate the layout to the source operand.
+ propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
+ /// For all other operands use 1D default layout.
+ SGMap layout = getDefaultSgMap(1);
+ for (size_t i = 1; i < operands.size(); ++i) {
+ propagateIfChanged(operands[i], operands[i]->meet(layout));
+ }
+}
+
/// Set the layouts for DPAS A, B, and C operands.
void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
ArrayRef<SGMapLattice *> operands,
@@ -314,8 +383,7 @@ void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
void SGMapPropagation::visitStoreNdOp(xegpu::StoreNdOp store,
ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results) {
- auto storeLayout =
- getDefaultSgMap(store.getTensorDescType().getElementType());
+ auto storeLayout = getDefaultSgMap(store.getValueType());
/// Both operands should have the same layout
for (SGMapLattice *operand : operands) {
propagateIfChanged(operand, operand->meet(storeLayout));
@@ -334,7 +402,6 @@ void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
SGMap tensorDescLayout = valueLayout;
/// LoadNdOp has the transpose effect. However, at the stage of this analysis
/// this effect is not expected and should be abstracted away. Emit a warning.
- /// TODO: Handle this case properly when `order` is introduced in the sg_map.
if (auto transpose = load.getTranspose()) {
load.emitWarning("Transpose effect is not expected for LoadNdOp");
tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
@@ -409,15 +476,12 @@ void SGMapPropagation::visitLoadGatherOp(
/// LoadGatherOp has the transpose effect. However, at the stage of this
/// analysis this effect is not expected and should be abstracted away. Emit
/// a warning.
- /// TODO: Handle this case properly when `order` is introduced in the
- /// sg_map.
load.emitWarning("Transpose effect is not expected for LoadGatherOp");
tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
} else
tensorDescLayout = valueLayout;
- /// Mask operand should have the same layout as the value but with wi_data =
- /// [1, 1]
- SGMap maskLayout = SGMap(valueLayout.getLayout(), WiData({1, 1}));
+ /// Mask operand should have 1D default layout.
+ auto maskLayout = getDefaultSgMap(1);
/// Propagate the new layout to the tensor descriptor operand.
propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
/// Propagate the new layout to the mask operand.
@@ -434,8 +498,8 @@ void SGMapPropagation::visitCreateNdDescOp(
return;
/// Propagate the layout to the source operand.
propagateIfChanged(operands[0], operands[0]->meet(descLayout));
- /// For all other operands propagate the descriptor layout.
- SGMap layout = getDefaultSgMap();
+ /// For all other operands use 1D default layout.
+ SGMap layout = getDefaultSgMap(1);
for (size_t i = 1; i < operands.size(); ++i) {
propagateIfChanged(operands[i], operands[i]->meet(layout));
}
@@ -452,8 +516,8 @@ void SGMapPropagation::visitCreateDescOp(
return;
/// Propagate the layout to the source operand.
propagateIfChanged(operands[0], operands[0]->meet(descLayout));
- /// For offset operand propagate the default layout.
- SGMap layout = getDefaultSgMap();
+ /// For offset operand propagate 1D default layout.
+ SGMap layout = getDefaultSgMap(1);
propagateIfChanged(operands[1], operands[1]->meet(layout));
}
@@ -462,15 +526,12 @@ void SGMapPropagation::visitCreateDescOp(
void SGMapPropagation::visitStoreScatterOp(
xegpu::StoreScatterOp storeScatter, ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results) {
- auto valueLayout =
- getDefaultSgMap(storeScatter.getTensorDescType().getElementType());
+ auto valueLayout = getDefaultSgMap(storeScatter.getValueType());
SGMap storeScatterLayout;
if (storeScatter.getTranspose()) {
/// StoreScatterOp allows transpose effect. However, at the stage of this
/// analysis this effect is not expected and should be abstracted away. Emit
/// a warning.
- /// TODO: Handle this case properly when `order` is introduced in the
- /// sg_map.
storeScatter.emitWarning(
"Transpose effect is not expected for StoreScatterOp");
storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
@@ -480,8 +541,8 @@ void SGMapPropagation::visitStoreScatterOp(
propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
/// Propagate the tensor descriptor layout.
propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
- /// Use default layout for mask operand.
- auto maskLayout = getDefaultSgMap();
+ /// Use default 1D layout for mask operand.
+ auto maskLayout = getDefaultSgMap(1);
propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
}
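For readers unfamiliar with how such an analysis is driven, the sketch below shows the standard MLIR dataflow setup that a pass like this would likely use (an assumption about the driver, which lives in an earlier part of the patch): load the backward sparse analysis into a DataFlowSolver next to the dead-code and sparse constant-propagation analyses, run it, then query the lattices.

  // Sketch of the usual MLIR dataflow driver pattern (assumption: this is how
  // the pass wires up SGMapPropagation; the actual driver is elsewhere in the
  // patch, and SGMapPropagation itself is defined in this file).
  #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
  #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
  #include "mlir/Analysis/DataFlowFramework.h"
  #include "mlir/IR/Operation.h"
  #include "mlir/IR/SymbolTable.h"

  using namespace mlir;

  static LogicalResult runSGMapPropagation(Operation *op,
                                           SymbolTableCollection &symbolTables) {
    DataFlowSolver solver;
    // Required companions for sparse backward analyses.
    solver.load<dataflow::DeadCodeAnalysis>();
    solver.load<dataflow::SparseConstantPropagation>();
    // The analysis defined in this patch.
    solver.load<SGMapPropagation>(symbolTables);
    return solver.initializeAndRun(op);
  }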
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 92e37a20bf555..7d8bd20391463 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -2,7 +2,7 @@
// CHECK: function: test_dpas_op_1:
// CHECK: op : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %{{.*}} = arith.constant dense<0.000000e+00> : vector<8x16xf32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -40,7 +40,7 @@ func.func @test_dpas_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %ar
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.dpas %{{.*}} : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
@@ -62,7 +62,7 @@ func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -98,7 +98,7 @@ func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -162,15 +162,15 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T1:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[T2:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2]
// CHECK-NEXT: op : %[[T3:.*]] = xegpu.load %[[T2]], %[[CST0]] <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
@@ -195,17 +195,17 @@ func.func @test_load_gather_op_1(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>
// -----
// CHECK: argument: <block argument> of type 'memref<256xf32>' at index: 0
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map : wi_layout: [16], wi_data: [1]
// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map : wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[T1]] = xegpu.load %[[T0]], %[[CST0]] : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
%cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
%cst_0 = arith.constant dense<true> : vector<16xi1>
@@ -221,9 +221,9 @@ func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_de
// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8x16xf32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[CST1:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST1]] : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 1]
func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
@@ -237,15 +237,15 @@ func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
// -----
// CHECK: argument: <block argument> of type 'vector<16xf32>' at index: 0
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map : wi_layout: [16], wi_data: [1]
// CHECK-NEXT: argument: <block argument> of type 'memref<256xf32>' at index: 1
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map : wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[CST1:.*]] = arith.constant dense<true> : vector<16xi1>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
%cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
%cst_0 = arith.constant dense<true> : vector<16xi1>
@@ -262,7 +262,7 @@ func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
@@ -298,7 +298,7 @@ func.func @test_vector_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 2]
// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
@@ -394,11 +394,11 @@ func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %{{.*}} = arith.constant 128 : index
// CHECK-NEXT: sg_map for result #0: Not assigned.
// CHECK-NEXT: op : %{{.*}} = arith.constant 16 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
@@ -550,11 +550,11 @@ func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
// CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map : wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [0] : vector<16x16xf32> to vector<16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
%cst = arith.constant dense<0.000000e+00> : vector<16xf32>
%0 = vector.multi_reduction <add>, %arg0, %cst [0] : vector<16x16xf32> to vector<16xf32>
@@ -566,11 +566,11 @@ func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tenso
// CHECK: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map : wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [1] : vector<16x16xf32> to vector<16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
func.func @test_vector_reduction_2(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
%cst = arith.constant dense<0.000000e+00> : vector<16xf32>
%0 = vector.multi_reduction <add>, %arg0, %cst [1] : vector<16x16xf32> to vector<16xf32>
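The two reduction tests above exercise visitVectorMultiReductionOp end to end: the store into the 16xf32 descriptor fixes the result layout to wi_layout [16], wi_data [1]; walking backwards, the analysis then gives the 2D source the 2D default layout and the accumulator the result's layout, independent of which dimension is reduced. A small isolated sketch of that rule (illustrative structs only, not the patch's lattice types):

  // Isolated sketch of the backward rule for a 2D -> 1D vector.multi_reduction.
  #include <cstdio>
  #include <vector>

  constexpr int subgroupSize = 16;

  struct Layout {
    std::vector<int> wiLayout, wiData;
    bool assigned = false;
  };

  void visitMultiReduction(const Layout &result, Layout &source, Layout &acc) {
    if (!result.assigned)
      return; // nothing to propagate yet
    // The source is 2D, so it gets the 2D default layout.
    source = {{1, subgroupSize}, {1, 1}, true};
    // The accumulator has the result's shape and inherits its layout.
    acc = result;
  }

  int main() {
    Layout result{{subgroupSize}, {1}, true}; // fixed by the consuming store
    Layout source, acc;
    visitMultiReduction(result, source, acc);
    printf("source: [%d, %d] / [%d, %d], acc: [%d] / [%d]\n",
           source.wiLayout[0], source.wiLayout[1], source.wiData[0],
           source.wiData[1], acc.wiLayout[0], acc.wiData[0]);
    return 0;
  }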
>From 5820c15c393a5b29e5aa8533627de51e480b6275 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 11 Mar 2025 18:54:34 +0000
Subject: [PATCH 16/27] save work
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 56 +++++--------
.../XeGPU/subgroup-map-propagation.mlir | 78 ++++++++++---------
2 files changed, 63 insertions(+), 71 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 7136993d0e968..14a9f5954edcf 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/IR/Builders.h"
+#include "mlir/Support/LLVM.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
@@ -242,10 +243,6 @@ class SGMapPropagation : public SparseBackwardDataFlowAnalysis<SGMapLattice> {
ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results);
- void visitCreateNdDescOp(xegpu::CreateNdDescOp createNdDesc,
- ArrayRef<SGMapLattice *> operands,
- ArrayRef<const SGMapLattice *> results);
-
void visitCreateDescOp(xegpu::CreateDescOp createDesc,
ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results);
@@ -296,8 +293,6 @@ SGMapPropagation::visitOperation(Operation *op,
visitVectorBitcastOp(bitcast, operands, results);
else if (auto loadGather = dyn_cast<xegpu::LoadGatherOp>(op))
visitLoadGatherOp(loadGather, operands, results);
- else if (auto createNdDesc = dyn_cast<xegpu::CreateNdDescOp>(op))
- visitCreateNdDescOp(createNdDesc, operands, results);
else if (auto createDesc = dyn_cast<xegpu::CreateDescOp>(op))
visitCreateDescOp(createDesc, operands, results);
else if (auto storeScatter = dyn_cast<xegpu::StoreScatterOp>(op))
@@ -306,6 +301,10 @@ SGMapPropagation::visitOperation(Operation *op,
visitUpdateNdOffsetOp(updateNdOffset, operands, results);
else if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op))
visitVectorMultiReductionOp(reduction, operands, results);
+ /// No need to propagate the layout to the operands in CreateNdDescOp because
+ /// they are scalars (offsets, sizes, etc.) or the memref source.
+ else if (auto createNdDesc = dyn_cast<xegpu::CreateNdDescOp>(op))
+ return success();
/// All other ops
else {
for (const SGMapLattice *r : results) {
@@ -355,11 +354,6 @@ void SGMapPropagation::visitUpdateNdOffsetOp(
return;
/// Propagate the layout to the source operand.
propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
- /// For all other operands use 1D default layout.
- SGMap layout = getDefaultSgMap(1);
- for (size_t i = 1; i < operands.size(); ++i) {
- propagateIfChanged(operands[i], operands[i]->meet(layout));
- }
}
/// Set the layouts for DPAS A, B, and C operands.
@@ -403,7 +397,8 @@ void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
/// LoadNdOp has the transpose effect. However, at the stage of this analysis
/// this effect is not expected and should be abstracted away. Emit a warning.
if (auto transpose = load.getTranspose()) {
- load.emitWarning("Transpose effect is not expected for LoadNdOp");
+ load.emitWarning("Transpose effect is not expected for LoadNdOp at "
+ "SGMapPropagation stage.");
tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
}
/// Propagate the new layout to the tensor descriptor operand.
@@ -476,7 +471,8 @@ void SGMapPropagation::visitLoadGatherOp(
/// LoadGatherOp has the transpose effect. However, at the stage of this
/// analysis this effect is not expected and should be abstracted away. Emit
/// a warning.
- load.emitWarning("Transpose effect is not expected for LoadGatherOp");
+ load.emitWarning("Transpose effect is not expected for LoadGatherOp at "
+ "SGMapPropagation stage.");
tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
} else
tensorDescLayout = valueLayout;
@@ -488,24 +484,7 @@ void SGMapPropagation::visitLoadGatherOp(
propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
}
-/// Propagate the layout of the descriptor to the operands in CreateNdDescOp.
-void SGMapPropagation::visitCreateNdDescOp(
- xegpu::CreateNdDescOp createNdDesc, ArrayRef<SGMapLattice *> operands,
- ArrayRef<const SGMapLattice *> results) {
- auto descLayout = results[0]->getValue();
- /// Need the layout of the descriptor to propagate to the operands.
- if (!descLayout.isAssigned())
- return;
- /// Propagate the layout to the source operand.
- propagateIfChanged(operands[0], operands[0]->meet(descLayout));
- /// For all other operands use 1D default layout.
- SGMap layout = getDefaultSgMap(1);
- for (size_t i = 1; i < operands.size(); ++i) {
- propagateIfChanged(operands[i], operands[i]->meet(layout));
- }
-}
-
-/// Propagate the layout of the descriptor to the source and offset operands in
+/// Propagate the layout of the descriptor to the vector offset operand in
/// CreateDescOp.
void SGMapPropagation::visitCreateDescOp(
xegpu::CreateDescOp createDesc, ArrayRef<SGMapLattice *> operands,
@@ -514,8 +493,6 @@ void SGMapPropagation::visitCreateDescOp(
/// Need the layout of the descriptor to propagate to the operands.
if (!descLayout.isAssigned())
return;
- /// Propagate the layout to the source operand.
- propagateIfChanged(operands[0], operands[0]->meet(descLayout));
/// For offset operand propagate 1D default layout.
SGMap layout = getDefaultSgMap(1);
propagateIfChanged(operands[1], operands[1]->meet(layout));
@@ -526,14 +503,23 @@ void SGMapPropagation::visitCreateDescOp(
void SGMapPropagation::visitStoreScatterOp(
xegpu::StoreScatterOp storeScatter, ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results) {
+ /// Currently, for 2D StoreScatterOp we expect that the height dimension of
+ /// the tensor descriptor is evenly divisible by the subgroup size.
+ /// TODO: Add support for other 2D shapes.
+ auto tdescShape = storeScatter.getTensorDescType().getShape();
+ if (tdescShape.size() > 1 && tdescShape[0] % subgroupSize != 0) {
+ storeScatter.emitError("Height dimension of the tensor descriptor should "
+ "be evenly divisible by the subgroup size.");
+ return;
+ }
auto valueLayout = getDefaultSgMap(storeScatter.getValueType());
SGMap storeScatterLayout;
if (storeScatter.getTranspose()) {
/// StoreScatterOp allows transpose effect. However, at the stage of this
/// analysis this effect is not expected and should be abstracted away. Emit
/// a warning.
- storeScatter.emitWarning(
- "Transpose effect is not expected for StoreScatterOp");
+ storeScatter.emitWarning("Transpose effect is not expected for "
+ "StoreScatterOp at SGMapPropagation stage.");
storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
} else
storeScatterLayout = valueLayout;
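The new guard above rejects 2D scattered stores whose outer (height) dimension does not split evenly across the subgroup, since that is the dimension distributed over work items. A minimal standalone sketch of the same rule, assuming a subgroup size of 16:

  // Standalone sketch of the shape check added for 2D StoreScatterOp
  // (assumption: subgroupSize == 16, as used throughout the tests).
  #include <cstdio>
  #include <vector>

  constexpr int subgroupSize = 16;

  bool isDistributableScatterShape(const std::vector<long> &tdescShape) {
    // 1D descriptors are always fine; for 2D the height must split evenly.
    return tdescShape.size() <= 1 || tdescShape[0] % subgroupSize == 0;
  }

  int main() {
    printf("%d\n", isDistributableScatterShape({16, 8})); // 1: 16 rows / 16 WIs
    printf("%d\n", isDistributableScatterShape({10, 8})); // 0: emits an error
    return 0;
  }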
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 7d8bd20391463..1f16c0f738af9 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -1,8 +1,14 @@
// RUN: mlir-opt -xegpu-subgroup-distribute='print-analysis-only=true' -split-input-file %s | FileCheck %s
-// CHECK: function: test_dpas_op_1:
-// CHECK: op : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK: function: test_dpas_f16:
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map : Not assigned.
+// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map : Not assigned.
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
+// CHECK-NEXT: sg_map : Not assigned.
+// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
+// CHECK-NEXT: sg_map for result #0: Not assigned.
// CHECK-NEXT: op : %{{.*}} = arith.constant dense<0.000000e+00> : vector<8x16xf32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -17,7 +23,7 @@
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %{{.*}} = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_dpas_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+func.func @test_dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -32,20 +38,20 @@ func.func @test_dpas_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %ar
// -----
-// CHECK: function: test_dpas_op_2:
+// CHECK: function: test_dpas_i8:
// CHECK-NEXT: argument: <block argument> of type 'vector<8x32xi8>' at index: 0
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 2]
// CHECK-NEXT: argument: <block argument> of type 'vector<32x16xi8>' at index: 1
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [4, 1]
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.dpas %{{.*}} : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
+func.func @test_dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
%c0 = arith.constant 0 : index
%0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
%1 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
@@ -56,13 +62,13 @@ func.func @test_dpas_op_2(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2
// -----
// CHECK: function: test_transpose_op_1:
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map : wi_layout: [16, 1], wi_data: [1, 2]
+// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -92,13 +98,13 @@ func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
// -----
// CHECK: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map : wi_layout: [16, 1], wi_data: [1, 2]
+// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -156,13 +162,13 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
// -----
// CHECK: function: test_load_gather_op_1:
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: argument: <block argument> of type 'memref<256xf16>' at index: 1
-// CHECK-NEXT: sg_map : wi_layout: [16, 1], wi_data: [1, 2]
+// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T1:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
@@ -195,7 +201,7 @@ func.func @test_load_gather_op_1(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>
// -----
// CHECK: argument: <block argument> of type 'memref<256xf32>' at index: 0
-// CHECK-NEXT: sg_map : wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
// CHECK-NEXT: sg_map : wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
@@ -217,7 +223,7 @@ func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_de
// -----
// CHECK: argument: <block argument> of type 'memref<128xf32>' at index: 0
-// CHECK-NEXT: sg_map : wi_layout: [16, 1], wi_data: [1, 1]
+// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8x16xf32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[CST0:.*]] = arith.constant dense<true> : vector<16xi1>
@@ -239,7 +245,7 @@ func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
// CHECK: argument: <block argument> of type 'vector<16xf32>' at index: 0
// CHECK-NEXT: sg_map : wi_layout: [16], wi_data: [1]
// CHECK-NEXT: argument: <block argument> of type 'memref<256xf32>' at index: 1
-// CHECK-NEXT: sg_map : wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[CST1:.*]] = arith.constant dense<true> : vector<16xi1>
@@ -256,13 +262,13 @@ func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>
// -----
// CHECK: argument: <block argument> of type 'memref<8x16xi16>' at index: 0
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: argument: <block argument> of type 'memref<32x16xi8>' at index: 1
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi32>' at index: 2
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
@@ -292,13 +298,13 @@ func.func @test_vector_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16
// -----
// CHECK: argument: <block argument> of type 'memref<8x32xi8>' at index: 0
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 2]
+// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: argument: <block argument> of type 'memref<16x32xi8>' at index: 1
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [4, 1]
+// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 2]
// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
@@ -388,17 +394,17 @@ func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
// -----
// CHECK: argument: <block argument> of type 'memref<8x128xf16>' at index: 0
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: argument: <block argument> of type 'memref<128x16xf16>' at index: 1
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf32>' at index: 2
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
// CHECK-NEXT: op : %{{.*}} = arith.constant 128 : index
// CHECK-NEXT: sg_map for result #0: Not assigned.
// CHECK-NEXT: op : %{{.*}} = arith.constant 16 : index
-// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: sg_map for result #0: Not assigned.
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
>From f7b7bdfc5fb2fbf64eb4db05c0918e912ce30825 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 11 Mar 2025 19:08:55 +0000
Subject: [PATCH 17/27] save work
---
.../XeGPU/subgroup-map-propagation.mlir | 104 +++++++-----------
1 file changed, 41 insertions(+), 63 deletions(-)
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index 1f16c0f738af9..1ae4348af33e6 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -60,7 +60,7 @@ func.func @test_dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2:
}
// -----
-// CHECK: function: test_transpose_op_1:
+// CHECK: function: test_load_with_transpose_effect:
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
@@ -83,7 +83,7 @@ func.func @test_dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2:
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T5:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+func.func @test_load_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -97,7 +97,8 @@ func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
}
// -----
-// CHECK: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
+// CHECK: function: test_vector_transpose:
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: argument: <block argument> of type 'memref<16x16xf16>' at index: 1
// CHECK-NEXT: sg_map : Not assigned.
@@ -121,7 +122,7 @@ func.func @test_transpose_op_1(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_transpose_op_2(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+func.func @test_vector_transpose(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -136,7 +137,8 @@ func.func @test_transpose_op_2(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
}
// -----
-// CHECK: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK: function: test_extf_truncf:
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [2, 1]
@@ -150,7 +152,7 @@ func.func @test_transpose_op_2(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
// CHECK-NEXT: op : %[[T4:.*]] = xegpu.dpas %[[T0]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
// CHECK-NEXT: sg_map for result #0: Not assigned.
-func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
+func.func @test_extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> {
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
%2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32>
@@ -160,7 +162,7 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
}
// -----
-// CHECK: function: test_load_gather_op_1:
+// CHECK: function: test_load_gather_with_transpose_effect:
// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xf16>' at index: 0
// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: argument: <block argument> of type 'memref<256xf16>' at index: 1
@@ -185,7 +187,7 @@ func.func @test_extf_truncf_op(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegp
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T5:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_load_gather_op_1(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
+func.func @test_load_gather_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%1 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
@@ -200,6 +202,7 @@ func.func @test_load_gather_op_1(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>
}
// -----
+// CHECK: function: test_load_gather_1d:
// CHECK: argument: <block argument> of type 'memref<256xf32>' at index: 0
// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
@@ -212,7 +215,7 @@ func.func @test_load_gather_op_1(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>
// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[T1]] = xegpu.load %[[T0]], %[[CST0]] : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+func.func @test_load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
%cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
%cst_0 = arith.constant dense<true> : vector<16xi1>
%0 = xegpu.create_tdesc %arg0, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
@@ -222,7 +225,8 @@ func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_de
}
// -----
-// CHECK: argument: <block argument> of type 'memref<128xf32>' at index: 0
+// CHECK: function: test_store_scatter_with_transpose_effect:
+// CHECK-NEXT: argument: <block argument> of type 'memref<128xf32>' at index: 0
// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8x16xf32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
@@ -232,7 +236,7 @@ func.func @test_load_gather_op_2(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_de
// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST1]] : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 1]
-func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
+func.func @test_store_scatter_with_transpose_effect(%arg0: memref<128xf32>) {
%cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
%cst_0 = arith.constant dense<true> : vector<16xi1>
%cst_1 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
@@ -242,7 +246,8 @@ func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
}
// -----
-// CHECK: argument: <block argument> of type 'vector<16xf32>' at index: 0
+// CHECK: function: test_store_scatter_1d:
+// CHECK-NEXT: argument: <block argument> of type 'vector<16xf32>' at index: 0
// CHECK-NEXT: sg_map : wi_layout: [16], wi_data: [1]
// CHECK-NEXT: argument: <block argument> of type 'memref<256xf32>' at index: 1
// CHECK-NEXT: sg_map : Not assigned.
@@ -252,7 +257,7 @@ func.func @test_store_scatter_op_1(%arg0: memref<128xf32>) {
// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
+func.func @test_store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
%cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
%cst_0 = arith.constant dense<true> : vector<16xi1>
%0 = xegpu.create_tdesc %arg1, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
@@ -261,7 +266,8 @@ func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>
}
// -----
-// CHECK: argument: <block argument> of type 'memref<8x16xi16>' at index: 0
+// CHECK: function: test_vector_bitcast_i16_to_i8:
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x16xi16>' at index: 0
// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: argument: <block argument> of type 'memref<32x16xi8>' at index: 1
// CHECK-NEXT: sg_map : Not assigned.
@@ -283,7 +289,7 @@ func.func @test_store_scatter_op_2(%arg0: vector<16xf32>, %arg1: memref<256xf32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_vector_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %arg2: memref<8x16xi32>) {
+func.func @test_vector_bitcast_i16_to_i8(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %arg2: memref<8x16xi32>) {
%c0 = arith.constant 0 : index
%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
%1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8>
@@ -297,7 +303,8 @@ func.func @test_vector_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16
}
// -----
-// CHECK: argument: <block argument> of type 'memref<8x32xi8>' at index: 0
+// CHECK: function: test_vector_bitcast_i8_to_f16:
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x32xi8>' at index: 0
// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: argument: <block argument> of type 'memref<16x32xi8>' at index: 1
// CHECK-NEXT: sg_map : Not assigned.
@@ -321,7 +328,7 @@ func.func @test_vector_bitcast_op_1(%arg0: memref<8x16xi16>, %arg1: memref<32x16
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T7:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_vector_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %arg2: memref<8x16xf32>) {
+func.func @test_vector_bitcast_i8_to_f16(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8>
%1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8>
@@ -336,7 +343,8 @@ func.func @test_vector_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32x
}
// -----
-// CHECK: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK: function: test_binary_op_one_use:
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [2, 1]
@@ -352,7 +360,7 @@ func.func @test_vector_bitcast_op_2(%arg0: memref<8x32xi8>, %arg1: memref<16x32x
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
// CHECK-NEXT: op : %[[T4:.*]] = xegpu.dpas %[[T0]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_binary_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>) {
+func.func @test_binary_op_one_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>) {
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
%2 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -363,7 +371,8 @@ func.func @test_binary_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
}
// -----
-// CHECK: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK: function: test_binary_op_multiple_uses:
+// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
@@ -381,7 +390,7 @@ func.func @test_binary_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T3:.*]] = xegpu.dpas %[[T0]], %[[T2]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>, %arg3: !xegpu.tensor_desc<16x16xf16>) {
+func.func @test_binary_op_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>, %arg3: !xegpu.tensor_desc<16x16xf16>) {
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
%cst = arith.constant dense<1.000000e+00> : vector<16x16xf16>
@@ -393,7 +402,8 @@ func.func @test_binary_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
}
// -----
-// CHECK: argument: <block argument> of type 'memref<8x128xf16>' at index: 0
+// CHECK: function: test_for_op:
+// CHECK-NEXT: argument: <block argument> of type 'memref<8x128xf16>' at index: 0
// CHECK-NEXT: sg_map : Not assigned.
// CHECK-NEXT: argument: <block argument> of type 'memref<128x16xf16>' at index: 1
// CHECK-NEXT: sg_map : Not assigned.
@@ -448,7 +458,7 @@ func.func @test_for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg
}
// -----
-// CHECK: function: test_if_op_1:
+// CHECK: function: test_if_single_use:
// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
@@ -467,7 +477,7 @@ func.func @test_for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
// CHECK-NEXT: op : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>) {
+func.func @test_if_single_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>) {
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = scf.if %arg2 -> (vector<16x16xf16>) {
%3 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -482,7 +492,7 @@ func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
}
// -----
-// CHECK: function: test_if_op_2:
+// CHECK: function: test_if_multiple_uses:
// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
@@ -503,7 +513,7 @@ func.func @test_if_op_1(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: op : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>) {
+func.func @test_if_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>) {
%0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%1 = scf.if %arg2 -> (vector<16x16xf16>) {
%3 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
@@ -519,40 +529,7 @@ func.func @test_if_op_2(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
}
// -----
-// CHECK: function: test_if_op_3:
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf16>' at index: 2
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: argument: <block argument> of type 'i1' at index: 3
-// CHECK-NEXT: sg_map : Not assigned.
-// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<8x16xf32>' at index: 4
-// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op : %[[T0:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-// CHECK-NEXT: op : %[[T3:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op : scf.if
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
-// CHECK-NEXT: op : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
-func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: vector<16x16xf16>, %arg3: i1, %arg4: !xegpu.tensor_desc<8x16xf32>) {
- %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
- %1 = scf.if %arg3 -> (vector<16x16xf16>) {
- %3 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
- scf.yield %3 : vector<16x16xf16>
- } else {
- scf.yield %arg2 : vector<16x16xf16>
- }
- %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
- xegpu.store_nd %2, %arg4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
- return
-}
-
-// -----
-// CHECK: function: test_vector_reduction_1:
+// CHECK: function: test_vector_outer_reduction:
// CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
@@ -561,7 +538,7 @@ func.func @test_if_op_3(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tenso
// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [0] : vector<16x16xf32> to vector<16xf32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+func.func @test_vector_outer_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
%cst = arith.constant dense<0.000000e+00> : vector<16xf32>
%0 = vector.multi_reduction <add>, %arg0, %cst [0] : vector<16x16xf32> to vector<16xf32>
xegpu.store_nd %0, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
@@ -569,7 +546,8 @@ func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tenso
}
// -----
-// CHECK: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
+// CHECK: function: test_vector_inner_reduction:
+// CHECK-NEXT: argument: <block argument> of type 'vector<16x16xf32>' at index: 0
// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
// CHECK-NEXT: argument: <block argument> of type '!xegpu.tensor_desc<16xf32>' at index: 1
// CHECK-NEXT: sg_map : wi_layout: [16], wi_data: [1]
@@ -577,7 +555,7 @@ func.func @test_vector_reduction_1(%arg0: vector<16x16xf32>, %arg1: !xegpu.tenso
// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
// CHECK-NEXT: op : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [1] : vector<16x16xf32> to vector<16xf32>
// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
-func.func @test_vector_reduction_2(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+func.func @test_vector_inner_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
%cst = arith.constant dense<0.000000e+00> : vector<16xf32>
%0 = vector.multi_reduction <add>, %arg0, %cst [1] : vector<16x16xf32> to vector<16xf32>
xegpu.store_nd %0, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
>From f64581a7dfdf9dc15e4d8e5d4305e89acae692fd Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 11 Mar 2025 20:34:22 +0000
Subject: [PATCH 18/27] save work
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 106 ++++++++++--------
1 file changed, 62 insertions(+), 44 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 14a9f5954edcf..d5e4f6532955e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -32,9 +32,15 @@ namespace xegpu {
using namespace mlir;
using namespace mlir::dataflow;
-constexpr unsigned subgroupSize = 16;
-constexpr unsigned packedASizeInBits = 16;
-constexpr unsigned packedBSizeInBits = 32;
+/// HW dependent constants.
+/// TODO: These constants should be queried from the uArch interface.
+constexpr unsigned subgroupSize = 16; // How many work items in a subgroup.
+/// If DPAS A or B operands have low precision element types they must be packed
+/// according to the following sizes.
+constexpr unsigned packedSizeInBitsForDefault =
+ 16; // Minimum packing size per register for DPAS A.
+constexpr unsigned packedSizeInBitsForDpasB =
+ 32; // Minimum packing size per register for DPAS B.
namespace {
@@ -51,7 +57,7 @@ struct Layout {
Layout(std::initializer_list<int64_t> list) : layout(list) {}
void print(llvm::raw_ostream &os) const;
size_t size() const { return layout.size(); }
- int64_t operator[](size_t idx) const { return layout[idx]; }
+ int64_t operator[](size_t idx) const;
};
void Layout::print(llvm::raw_ostream &os) const {
@@ -60,6 +66,11 @@ void Layout::print(llvm::raw_ostream &os) const {
os << "]";
}
+int64_t Layout::operator[](size_t idx) const {
+ assert(idx < layout.size() && "Index out of bounds.");
+ return layout[idx];
+}
+
/// WiLayout represents the layout of work items within a subgroup when it
/// accesses some value. WiData represents the layout of data owned by each work
/// item.
@@ -86,14 +97,14 @@ using WiData = Layout;
struct SGMap {
private:
- WiLayout layout;
- WiData data;
+ WiLayout wiLayout;
+ WiData wiData;
public:
SGMap() = default;
SGMap(const SGMap &other) = default;
SGMap(const WiLayout &layout, const WiData &data)
- : layout(layout), data(data) {}
+ : wiLayout(layout), wiData(data) {}
/// Two lattice values are equal if they have `some` layout. The actual
/// content of the layout does not matter.
@@ -107,20 +118,20 @@ struct SGMap {
void print(raw_ostream &os) const;
- bool isAssigned() const { return layout.size() > 0 && data.size() > 0; }
+ bool isAssigned() const { return wiLayout.size() > 0 && wiData.size() > 0; }
SGMap getTransposedLayout(ArrayRef<int64_t> permutation) const;
- const WiLayout &getLayout() const { return layout; }
- const WiData &getData() const { return data; }
+ const WiLayout &getLayout() const { return wiLayout; }
+ const WiData &getData() const { return wiData; }
};
void SGMap::print(raw_ostream &os) const {
if (isAssigned()) {
os << "wi_layout: ";
- layout.print(os);
+ wiLayout.print(os);
os << ", wi_data: ";
- data.print(os);
+ wiData.print(os);
} else
os << "Not assigned.";
}
@@ -143,8 +154,8 @@ SGMap SGMap::getTransposedLayout(ArrayRef<int64_t> permutation) const {
WiLayout newLayout;
WiData newData;
for (auto idx : permutation) {
- newLayout.layout.push_back(layout.layout[idx]);
- newData.layout.push_back(data.layout[idx]);
+ newLayout.layout.push_back(wiLayout.layout[idx]);
+ newData.layout.push_back(wiData.layout[idx]);
}
return SGMap(newLayout, newData);
}
@@ -159,33 +170,14 @@ struct SGMapLattice : public Lattice<SGMap> {
using Lattice::Lattice;
};
-/// Helper Functions
-
-/// Helper Function to get the expected layouts for DPAS operands.
-static SGMap getSGMapForDPASOperand(Type operandTy, unsigned operandNum) {
- int packingFactorForB = packedBSizeInBits / operandTy.getIntOrFloatBitWidth();
- int packingFactorForA =
- operandTy.getIntOrFloatBitWidth() < packedBSizeInBits
- ? packedASizeInBits / operandTy.getIntOrFloatBitWidth()
- : 1;
- return SGMap(WiLayout({1, subgroupSize}),
- WiData({operandNum == 1 ? packingFactorForB : 1,
- operandNum == 0 ? packingFactorForA : 1}));
-}
-
-/// Helper Function to get the default layout for a given type. Usually this is,
-/// wi_layout = [1, subgroupSize] and wi_data = [1, 1].
-/// However, the minimum granularity of data access per work item is 16-bits.
-/// So, if the bitwidth of the type is less than 16, we need to pack the data to
-/// 16-bits.
-// static SGMap getDefaultSgMap(Type ty, unsigned rank) {
-// int packingFactor = 1;
-// if (ty.getIntOrFloatBitWidth() < packedASizeInBits)
-// packingFactor = packedBSizeInBits / ty.getIntOrFloatBitWidth();
-// return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
-// }
+/// Helper Functions to get default layouts. A `default layout` is a layout that
+/// is assigned to a value when the layout is not fixed by some anchor operation
+/// (like DPAS). This is the natural layout work items are arranged in a
+/// subgroup.
/// Helper Function to get the default layout for uniform values like constants.
+/// For 1D vector, wi_layout is [subgroupSize] and wi_data is [1].
+/// For 2D vector, wi_layout is [1, subgroupSize] and wi_data is [1, 1].
static SGMap getDefaultSgMap(unsigned rank) {
assert((rank == 1 || rank == 2) && "Expected 0D or 1D vector.");
if (rank == 1)
@@ -193,20 +185,46 @@ static SGMap getDefaultSgMap(unsigned rank) {
return SGMap(WiLayout({1, subgroupSize}), WiData({1, 1}));
}
+/// Helper to get the default layout for a vector type.
static SGMap getDefaultSgMap(VectorType vectorTy) {
/// Expecting a 1D or 2D vector.
assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
"Expected 1D or 2D vector.");
+ /// Expecting int or float element type.
+ assert(vectorTy.getElementType().isIntOrFloat() &&
+ "Expected int or float element type.");
/// If the rank is 1, then return default layout for 1D vector.
if (vectorTy.getRank() == 1)
return getDefaultSgMap(1);
+ /// Packing factor is determined by the element type bitwidth.
int packingFactor = 1;
auto bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
- if (bitwidth < packedASizeInBits)
- packingFactor = packedBSizeInBits / bitwidth;
+ if (bitwidth < packedSizeInBitsForDefault)
+ packingFactor = packedSizeInBitsForDefault / bitwidth;
return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor}));
}
+/// Helper Function to get the expected layouts for DPAS operands. `wi_data` is
+/// set according to the following criteria:
+/// * For A operand, the data must be packed in minimum `packedDpasASizeInBits`
+/// * For B operand, the data must be packed in minimum `packedDpasBSizeInBits`
+static SGMap getSGMapForDPASOperand(VectorType vectorTy, unsigned operandNum) {
+ auto elementTy = vectorTy.getElementType();
+ assert(elementTy.isIntOrFloat() &&
+ "Expected int or float type in DPAS operands");
+ WiLayout layout({1, subgroupSize});
+ /// For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
+ /// must have the VNNI format.
+ if (operandNum == 1 &&
+ elementTy.getIntOrFloatBitWidth() < packedSizeInBitsForDpasB) {
+ WiData data(
+ {packedSizeInBitsForDpasB / elementTy.getIntOrFloatBitWidth(), 1});
+ return SGMap(layout, data);
+ }
+ /// Otherwise, return the default layout for the vector type.
+ return getDefaultSgMap(vectorTy);
+}
+
///===----------------------------------------------------------------------===///
/// SGMapPropagation
///===----------------------------------------------------------------------===///
@@ -360,14 +378,14 @@ void SGMapPropagation::visitUpdateNdOffsetOp(
void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas,
ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results) {
- auto aTy = dpas.getLhsType().getElementType();
- auto bTy = dpas.getRhsType().getElementType();
+ auto aTy = dpas.getLhsType();
+ auto bTy = dpas.getRhsType();
propagateIfChanged(operands[0],
operands[0]->meet(getSGMapForDPASOperand(aTy, 0)));
propagateIfChanged(operands[1],
operands[1]->meet(getSGMapForDPASOperand(bTy, 1)));
if (operands.size() > 2) {
- auto cTy = dpas.getAccType().getElementType();
+ auto cTy = dpas.getAccType();
propagateIfChanged(operands[2],
operands[2]->meet(getSGMapForDPASOperand(cTy, 2)));
}
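
For reference, a minimal standalone sketch (not part of the patch; the program and variable names are hypothetical) of the packing arithmetic that the constants and getSGMapForDPASOperand above encode, assuming 16-bit (f16/bf16) element types:

#include <cstdio>

constexpr unsigned subgroupSize = 16;
constexpr unsigned packedSizeInBitsForDefault = 16; // DPAS A / default packing
constexpr unsigned packedSizeInBitsForDpasB = 32;   // DPAS B packing

int main() {
  unsigned elemBits = 16; // e.g. an f16 element type
  // B operand: pack rows up to packedSizeInBitsForDpasB bits per work item.
  unsigned bRowPack = elemBits < packedSizeInBitsForDpasB
                          ? packedSizeInBitsForDpasB / elemBits
                          : 1;
  // Default layout (A operand, plain loads): pack columns up to 16 bits.
  unsigned defColPack = elemBits < packedSizeInBitsForDefault
                            ? packedSizeInBitsForDefault / elemBits
                            : 1;
  // For f16 this prints wi_data [2, 1] for B and [1, 1] for A, matching the
  // "wi_data: [2, 1]" CHECK lines for the 16x16xf16 DPAS B operands above.
  std::printf("B: wi_layout [1, %u], wi_data [%u, 1]\n", subgroupSize, bRowPack);
  std::printf("A: wi_layout [1, %u], wi_data [1, %u]\n", subgroupSize, defColPack);
  return 0;
}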
>From 562b5c7f7fc16dddc95ebb3ea914ab9314233e3c Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 11 Mar 2025 21:34:18 +0000
Subject: [PATCH 19/27] save work
---
.../mlir/Dialect/XeGPU/Transforms/Passes.td | 5 +++++
.../XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 14 ++++----------
2 files changed, 9 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index cb9d403566645..3e81f2d0ed786 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -31,6 +31,11 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
let dependentDialects = [
"memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect"
];
+ let options = [
+ Option<"printOnly", "print-analysis-only", "bool",
+ /*default=*/"false",
+ "Print the result of the subgroup map propagation analysis and exit.">
+ ];
}
#endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d5e4f6532955e..6213a6b7d1a6e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -355,7 +355,6 @@ void SGMapPropagation::visitVectorMultiReductionOp(
/// Given that the result is 1D, the layout of the operand should be 2D with
/// default layout.
auto operandLayout = getDefaultSgMap(2);
- operandLayout.print(llvm::outs());
propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
/// Accumulator should have the same layout as the result.
propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
@@ -625,16 +624,11 @@ struct XeGPUSubgroupDistributePass final
: public xegpu::impl::XeGPUSubgroupDistributeBase<
XeGPUSubgroupDistributePass> {
XeGPUSubgroupDistributePass() = default;
- XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other)
- : xegpu::impl::XeGPUSubgroupDistributeBase<XeGPUSubgroupDistributePass>(
- other) {
- this->printOnly = other.printOnly;
- }
+ XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) =
+ default;
+ XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options)
+ : XeGPUSubgroupDistributeBase(options) {}
void runOnOperation() override;
- /// Print sg map propagation analysis result and exit for testing purposes.
- Option<bool> printOnly{*this, "print-analysis-only", llvm::cl::init(false),
- llvm::cl::desc("Print the result of the subgroup map "
- "propagation analysis and exit.")};
};
} // namespace
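
With print-analysis-only now declared as a pass option, the analysis printer can presumably be driven from mlir-opt directly; the invocation below is an assumption based on the standard pass-option syntax and is shown for illustration only:

mlir-opt -xegpu-subgroup-distribute='print-analysis-only=true' input.mlir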
>From d8b74b34b3b9c781ee039ceb3adb6b05559d0049 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 11 Mar 2025 21:47:39 +0000
Subject: [PATCH 20/27] save work
---
.../Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 6213a6b7d1a6e..d6f39887af8e7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -179,7 +179,7 @@ struct SGMapLattice : public Lattice<SGMap> {
/// For 1D vector, wi_layout is [subgroupSize] and wi_data is [1].
/// For 2D vector, wi_layout is [1, subgroupSize] and wi_data is [1, 1].
static SGMap getDefaultSgMap(unsigned rank) {
- assert((rank == 1 || rank == 2) && "Expected 0D or 1D vector.");
+ assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
if (rank == 1)
return SGMap(WiLayout({subgroupSize}), WiData({1}));
return SGMap(WiLayout({1, subgroupSize}), WiData({1, 1}));
@@ -428,11 +428,10 @@ void SGMapPropagation::visitTransposeOp(
vector::TransposeOp transpose, ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results) {
/// Need the layout of transpose result to propagate to the operands.
- auto operandLayout = results[0]->getValue();
- if (!operandLayout.isAssigned())
+ auto resultLayout = results[0]->getValue();
+ if (!resultLayout.isAssigned())
return;
- auto newLayout =
- operandLayout.getTransposedLayout(transpose.getPermutation());
+ auto newLayout = resultLayout.getTransposedLayout(transpose.getPermutation());
/// Propagate the new layout to the vector operand.
propagateIfChanged(operands[0], operands[0]->meet(newLayout));
}
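
A small standalone sketch (hypothetical, not the patch's SGMap class) of what getTransposedLayout does with the {1, 0} permutation used for the transpose cases above: both wi_layout and wi_data are simply reindexed by the permutation.

#include <cstdio>
#include <vector>

int main() {
  std::vector<long> wiLayout = {1, 16};
  std::vector<long> wiData = {2, 1};
  std::vector<long> perm = {1, 0};
  std::vector<long> newLayout, newData;
  for (long idx : perm) {
    newLayout.push_back(wiLayout[idx]);
    newData.push_back(wiData[idx]);
  }
  // Yields wi_layout [16, 1] and wi_data [1, 2].
  std::printf("wi_layout: [%ld, %ld], wi_data: [%ld, %ld]\n", newLayout[0],
              newLayout[1], newData[0], newData[1]);
  return 0;
}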
>From 7196c9f5e5267fe7eaa9fc60092f4fe2544c303f Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 12 Mar 2025 19:42:08 +0000
Subject: [PATCH 21/27] save work
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 93 ++++++++++---------
1 file changed, 50 insertions(+), 43 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d6f39887af8e7..f8a27176ae66c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/IR/Builders.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
@@ -26,9 +27,6 @@ namespace xegpu {
} // namespace xegpu
} // namespace mlir
-#define DEBUG_TYPE "xegpu-subgroup-distribute"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-
using namespace mlir;
using namespace mlir::dataflow;
@@ -206,8 +204,10 @@ static SGMap getDefaultSgMap(VectorType vectorTy) {
/// Helper Function to get the expected layouts for DPAS operands. `wi_data` is
/// set according to the following criteria:
-/// * For A operand, the data must be packed in minimum `packedDpasASizeInBits`
-/// * For B operand, the data must be packed in minimum `packedDpasBSizeInBits`
+/// * For A operand, the data must be packed in minimum
+/// `packedSizeInBitsForDefault`
+/// * For B operand, the data must be packed in minimum
+/// `packedSizeInBitsForDpasB`
static SGMap getSGMapForDPASOperand(VectorType vectorTy, unsigned operandNum) {
auto elementTy = vectorTy.getElementType();
assert(elementTy.isIntOrFloat() &&
@@ -299,40 +299,48 @@ LogicalResult
SGMapPropagation::visitOperation(Operation *op,
ArrayRef<SGMapLattice *> operands,
ArrayRef<const SGMapLattice *> results) {
- if (auto dpas = dyn_cast<xegpu::DpasOp>(op))
- visitDpasOp(dpas, operands, results);
- else if (auto store = dyn_cast<xegpu::StoreNdOp>(op))
- visitStoreNdOp(store, operands, results);
- else if (auto load = dyn_cast<xegpu::LoadNdOp>(op))
- visitLoadNdOp(load, operands, results);
- else if (auto transpose = dyn_cast<vector::TransposeOp>(op))
- visitTransposeOp(transpose, operands, results);
- else if (auto bitcast = dyn_cast<vector::BitCastOp>(op))
- visitVectorBitcastOp(bitcast, operands, results);
- else if (auto loadGather = dyn_cast<xegpu::LoadGatherOp>(op))
- visitLoadGatherOp(loadGather, operands, results);
- else if (auto createDesc = dyn_cast<xegpu::CreateDescOp>(op))
- visitCreateDescOp(createDesc, operands, results);
- else if (auto storeScatter = dyn_cast<xegpu::StoreScatterOp>(op))
- visitStoreScatterOp(storeScatter, operands, results);
- else if (auto updateNdOffset = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
- visitUpdateNdOffsetOp(updateNdOffset, operands, results);
- else if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op))
- visitVectorMultiReductionOp(reduction, operands, results);
- /// No need to propagate the layout to operands in CreateNdDescOp because they
- /// are scalars (offsets, sizes, etc.).
- else if (auto createNdDesc = dyn_cast<xegpu::CreateNdDescOp>(op))
- return success();
- /// All other ops
- else {
- for (const SGMapLattice *r : results) {
- for (SGMapLattice *operand : operands) {
- /// Propagate the layout of the result to the operand.
- if (r->getValue().isAssigned())
- meet(operand, *r);
- }
- }
- }
+ TypeSwitch<Operation *>(op)
+ .Case<xegpu::DpasOp>(
+ [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
+ .Case<xegpu::StoreNdOp>(
+ [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
+ .Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
+ visitStoreScatterOp(storeScatterOp, operands, results);
+ })
+ .Case<xegpu::LoadNdOp>(
+ [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
+ .Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
+ visitLoadGatherOp(loadGatherOp, operands, results);
+ })
+ .Case<xegpu::CreateDescOp>([&](auto createDescOp) {
+ visitCreateDescOp(createDescOp, operands, results);
+ })
+ .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
+ visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
+ })
+ /// No need to propagate the layout to operands in CreateNdDescOp because
+ /// they are scalars (offsets, sizes, etc.).
+ .Case<xegpu::CreateNdDescOp>(
+ [&](auto createNdDescOp) { return success(); })
+ .Case<vector::TransposeOp>([&](auto transposeOp) {
+ visitTransposeOp(transposeOp, operands, results);
+ })
+ .Case<vector::BitCastOp>([&](auto bitcastOp) {
+ visitVectorBitcastOp(bitcastOp, operands, results);
+ })
+ .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
+ visitVectorMultiReductionOp(reductionOp, operands, results);
+ })
+ /// All other ops.
+ .Default([&](Operation *op) {
+ for (const SGMapLattice *r : results) {
+ for (SGMapLattice *operand : operands) {
+ /// Propagate the layout of the result to the operand.
+ if (r->getValue().isAssigned())
+ meet(operand, *r);
+ }
+ }
+ });
/// Add a dependency from each reult to program point after the operation.
/// NOTE: not sure if this is required, but all other similar analysis do
/// this.
@@ -411,7 +419,7 @@ void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load,
if (!valueLayout.isAssigned())
return;
SGMap tensorDescLayout = valueLayout;
- /// LoadNdOp has the transpose effect. However, at the stage of this analyis
+ /// LoadNdOp has the transpose effect. However, at the stage of this analysis
/// this effect is not expected and should be abstracted away. Emit a warning.
if (auto transpose = load.getTranspose()) {
load.emitWarning("Transpose effect is not expected for LoadNdOp at "
@@ -529,7 +537,7 @@ void SGMapPropagation::visitStoreScatterOp(
return;
}
auto valueLayout = getDefaultSgMap(storeScatter.getValueType());
- SGMap storeScatterLayout;
+ SGMap storeScatterLayout = valueLayout;
if (storeScatter.getTranspose()) {
/// StoreScatteOp allows transpose effect. However, at the stage of this
/// analyis this effect is not expected and should be abstracted away. Emit
@@ -537,8 +545,7 @@ void SGMapPropagation::visitStoreScatterOp(
storeScatter.emitWarning("Transpose effect is not expected for "
"StoreScatterOp at SGMapPropagation stage.");
storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
- } else
- storeScatterLayout = valueLayout;
+ }
/// Propagate the value layout.
propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
/// Propagate the tensor descriptor layout.
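
For readers unfamiliar with the dispatch style visitOperation switches to in the patch above, here is a minimal, self-contained llvm::TypeSwitch sketch (hypothetical types, not MLIR ops): each Case<> lambda receives the already-casted value, and Default handles everything else.

#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include <cstdio>

struct Shape {
  enum Kind { K_Circle, K_Square } kind;
  Shape(Kind k) : kind(k) {}
};
struct Circle : Shape {
  Circle() : Shape(K_Circle) {}
  static bool classof(const Shape *s) { return s->kind == K_Circle; }
};
struct Square : Shape {
  Square() : Shape(K_Square) {}
  static bool classof(const Shape *s) { return s->kind == K_Square; }
};

static void visit(Shape *s) {
  llvm::TypeSwitch<Shape *>(s)
      .Case<Circle>([](Circle *) { std::puts("circle"); })
      .Case<Square>([](Square *) { std::puts("square"); })
      .Default([](Shape *) { std::puts("other"); });
}

int main() {
  Circle c;
  Square q;
  visit(&c); // prints "circle"
  visit(&q); // prints "square"
  return 0;
}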
>From 4c9e641af62f8be956afdb2deacd548f25f1704c Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 12 Mar 2025 20:02:34 +0000
Subject: [PATCH 22/27] save work
---
mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h | 2 --
mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 3 ---
2 files changed, 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 86b95721df60c..63ea26df06937 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -16,8 +16,6 @@ namespace xegpu {
/// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
-/// Appends patterns for distributing XeGPU ops to work items into `patterns`.
-void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index f8a27176ae66c..13688e9cba62b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -649,6 +649,3 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
return;
}
}
-
-void xegpu::populateXeGPUSubgroupDistributePatterns(
- RewritePatternSet &patterns) {}
>From 861f9075bfa18410c4d6e39cc822f85bb6fad13c Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 14 Mar 2025 00:17:07 +0000
Subject: [PATCH 23/27] save work
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 78 +++++++++++--------
1 file changed, 47 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 13688e9cba62b..1fdd37f031385 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -10,12 +10,14 @@
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/IR/Builders.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
@@ -490,7 +492,7 @@ void SGMapPropagation::visitLoadGatherOp(
if (!valueLayout.isAssigned())
return;
- SGMap tensorDescLayout;
+ SGMap tensorDescLayout = valueLayout;
if (load.getTranspose()) {
/// LoadGatherOp has the transpose effect. However, at the stage of this
/// analyis this effect is not expected and should be abstracted away. Emit
@@ -498,8 +500,7 @@ void SGMapPropagation::visitLoadGatherOp(
load.emitWarning("Transpose effect is not expected for LoadGatherOp at "
"SGMapPropagation stage.");
tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
- } else
- tensorDescLayout = valueLayout;
+ }
/// Mask operand should have 1D default layout.
auto maskLayout = getDefaultSgMap(1);
/// Propagate the new layout to the tensor descriptor operand.
@@ -590,38 +591,53 @@ SGMap RunSGMapPropagation::getSGMap(Value val) {
}
void RunSGMapPropagation::printAnalysisResult(llvm::raw_ostream &os) {
- if (auto modOp = dyn_cast<ModuleOp>(target)) {
- for (auto funcOp : modOp.getOps<func::FuncOp>()) {
- os << "function: " << funcOp.getName() << ":\n";
- // Function arguments
- for (auto arg : funcOp.getArguments()) {
- auto layout = getSGMap(arg);
- os << "argument: " << arg << "\n";
- os << "sg_map : ";
+ auto printFunctionResult = [&](FunctionOpInterface funcOp) {
+ os << "function: " << funcOp.getName() << ":\n";
+ // Function arguments
+ for (auto arg : funcOp.getArguments()) {
+ auto layout = getSGMap(arg);
+ os << "argument: " << arg << "\n";
+ os << "sg_map : ";
+ layout.print(os);
+ os << "\n";
+ }
+ // Function ops
+ funcOp.walk([&](Operation *op) {
+ // Skip ops that do not have results
+ if (op->getResults().empty())
+ return;
+ os << "op : ";
+ /// For control-flow ops, print the op name only.
+ if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
+ os << op->getName();
+ else
+ op->print(os);
+ os << "\n";
+ /// Print the sg_map for each result.
+ for (auto [i, r] : llvm::enumerate(op->getResults())) {
+ auto layout = getSGMap(r);
+ os << "sg_map for result #" << i << ": ";
layout.print(os);
os << "\n";
}
- // Function ops
- funcOp.walk([&](Operation *op) {
- // Skip ops that do not have results
- if (op->getResults().empty())
- return;
- os << "op : ";
- /// For control-flow ops, print the op name only.
- if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
- os << op->getName();
- else
- op->print(os);
- os << "\n";
- /// Print the sg_map for each result.
- for (auto [i, r] : llvm::enumerate(op->getResults())) {
- auto layout = getSGMap(r);
- os << "sg_map for result #" << i << ": ";
- layout.print(os);
- os << "\n";
- }
- });
+ });
+ };
+
+ SmallVector<FunctionOpInterface> funcOps;
+ if (auto modOp = dyn_cast<ModuleOp>(target)) {
+ for (auto funcOp : modOp.getOps<FunctionOpInterface>()) {
+ funcOps.push_back(funcOp);
}
+ /// Collect all GpuFuncOps in the module.
+ for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
+ for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>()) {
+ funcOps.push_back(gpuFuncOp);
+ }
+ }
+ }
+ /// Print the analysis result for each function.
+ for (auto funcOp : funcOps) {
+ printFunctionResult(funcOp);
}
}
>From ac337a3659df75a581367d0f139afa0c85b75df8 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 14 Mar 2025 01:53:08 +0000
Subject: [PATCH 24/27] save work
---
mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 2 --
1 file changed, 2 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 1fdd37f031385..124e65b57521d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -9,7 +9,6 @@
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -18,7 +17,6 @@
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/IR/Builders.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
-#include "mlir/Support/LLVM.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
>From 7ac804c6e99c2cf54ade6618a5b09d958b1fc1fb Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 14 Mar 2025 16:38:07 +0000
Subject: [PATCH 25/27] save work
---
mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 124e65b57521d..2de68695fac2b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -31,7 +31,7 @@ using namespace mlir;
using namespace mlir::dataflow;
/// HW dependent constants.
-/// TODO: These constants should be queried from the uArch interface.
+/// TODO: These constants should be queried from the target information.
constexpr unsigned subgroupSize = 16; // How many work items in a subgroup.
/// If DPAS A or B operands have low precision element types they must be packed
/// according to the following sizes.
@@ -342,8 +342,6 @@ SGMapPropagation::visitOperation(Operation *op,
}
});
/// Add a dependency from each reult to program point after the operation.
- /// NOTE: not sure if this is required, but all other similar analysis do
- /// this.
for (const SGMapLattice *r : results) {
addDependency(const_cast<SGMapLattice *>(r), getProgramPointAfter(op));
}
>From 0ce71f87ce139e3380da4beef6a8a75503c30f8a Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 14 Mar 2025 16:38:45 +0000
Subject: [PATCH 26/27] save work
---
mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 2de68695fac2b..74fcb6b8c7613 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -341,7 +341,7 @@ SGMapPropagation::visitOperation(Operation *op,
}
}
});
- /// Add a dependency from each reult to program point after the operation.
+ /// Add a dependency from each result to program point after the operation.
for (const SGMapLattice *r : results) {
addDependency(const_cast<SGMapLattice *>(r), getProgramPointAfter(op));
}
>From 4949a1f286348fd023fcc40da3b9a7c80986e5a9 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 14 Mar 2025 16:46:44 +0000
Subject: [PATCH 27/27] save work
---
mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 74fcb6b8c7613..86e07697f437c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -320,8 +320,7 @@ SGMapPropagation::visitOperation(Operation *op,
})
/// No need to propagate the layout to operands in CreateNdDescOp because
/// they are scalars (offsets, sizes, etc.).
- .Case<xegpu::CreateNdDescOp>(
- [&](auto createNdDescOp) { return success(); })
+ .Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
.Case<vector::TransposeOp>([&](auto transposeOp) {
visitTransposeOp(transposeOp, operands, results);
})