[Mlir-commits] [mlir] [MLIR][XeGPU] Add unroll patterns for XeGPU (1/N) (PR #137010)
Chao Chen
llvmlistbot at llvm.org
Thu May 8 12:05:24 PDT 2025
================
@@ -0,0 +1,96 @@
+//===- TestXeGPUTransforms.cpp -- Test Vector transforms and lowerings ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::xegpu;
+
+namespace {
+
+struct TestXeGPUUnrollingPatterns
+ : public PassWrapper<TestXeGPUUnrollingPatterns,
+ OperationPass<gpu::GPUModuleOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUUnrollingPatterns)
+
+ StringRef getArgument() const final {
+ return "test-xegpu-unrolling-patterns";
+ }
+
+ StringRef getDescription() const final {
+ return "Test lowering patterns to unroll ops in the xegpu dialect";
+ }
+
+ void getDependentDialects(::mlir::DialectRegistry ®istry) const override {
+ registry.insert<memref::MemRefDialect>();
+ registry.insert<xegpu::XeGPUDialect>();
+ registry.insert<vector::VectorDialect>();
+ }
+
+ TestXeGPUUnrollingPatterns() = default;
+ TestXeGPUUnrollingPatterns(const TestXeGPUUnrollingPatterns &pass)
+ : PassWrapper(pass) {}
+
+ void runOnOperation() override {
+ xegpu::UnrollOptions options;
+ options.setNativeShapeFn(
+ [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
+ if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
+ xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
+ xegpu::TensorDescType tdescTy;
+ if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
+ tdescTy = createNdOp.getType();
+ } else if (auto updateNdOp =
+ dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
+ tdescTy = updateNdOp.getTensorDescType();
+ } else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
+ tdescTy = prefetchNdOp.getTensorDescType();
+ } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
+ tdescTy = loadNdOp.getTensorDescType();
+ } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
+ tdescTy = storeNdOp.getTensorDescType();
+ }
+
+ if (auto layout = tdescTy.getLayoutAttr()) {
+ auto inst_data = layout.getInstData();
+ if (inst_data && layout.isSgLayout())
+ return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
+ inst_data.asArrayRef().end());
+ }
+ }
+
+ if (isa<xegpu::DpasOp>(op)) {
----------------
chencha3 wrote:
thanks! fixed.
https://github.com/llvm/llvm-project/pull/137010
More information about the Mlir-commits
mailing list