[Mlir-commits] [mlir] f4c0c40 - [mlir][xegpu] XeGPU alias ops folder pass (#88886)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 19 07:41:40 PDT 2024
Author: Adam Siemieniuk
Date: 2024-04-19T09:41:37-05:00
New Revision: f4c0c40f388fff0975ecada4997683cef3cb1fae
URL: https://github.com/llvm/llvm-project/commit/f4c0c40f388fff0975ecada4997683cef3cb1fae
DIFF: https://github.com/llvm/llvm-project/commit/f4c0c40f388fff0975ecada4997683cef3cb1fae.diff
LOG: [mlir][xegpu] XeGPU alias ops folder pass (#88886)
Adds a pass that folds aliasing ops into XeGPU ops.
Added:
mlir/include/mlir/Dialect/XeGPU/Transforms/CMakeLists.txt
mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.h
mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp
mlir/test/Dialect/XeGPU/xegpu-fold-alias-ops.mlir
Modified:
mlir/docs/Passes.md
mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
mlir/include/mlir/InitAllPasses.h
mlir/lib/Dialect/XeGPU/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/docs/Passes.md b/mlir/docs/Passes.md
index 84e6664436d7b3..6a18e06593e8e9 100644
--- a/mlir/docs/Passes.md
+++ b/mlir/docs/Passes.md
@@ -119,3 +119,7 @@ This document describes the available MLIR passes and their contracts.
## TOSA Dialect Passes
[include "TosaPasses.md"]
+
+## XeGPU Dialect Passes
+
+[include "XeGPUPasses.md"]
diff --git a/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
index f33061b2d87cff..9f57627c321fb0 100644
--- a/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
@@ -1 +1,2 @@
add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/Transforms/CMakeLists.txt
new file mode 100644
index 00000000000000..9de7e87c7d3995
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -0,0 +1,6 @@
+# Generate the XeGPU pass declarations/definitions (Passes.h.inc) from
+# Passes.td and make them part of the public MLIR headers target.
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name XeGPU)
+add_public_tablegen_target(MLIRXeGPUPassIncGen)
+add_dependencies(mlir-headers MLIRXeGPUPassIncGen)
+
+# Generate the pass documentation page XeGPUPasses.md, which mlir/docs/Passes.md
+# pulls in via [include "XeGPUPasses.md"].
+add_mlir_doc(Passes XeGPUPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.h
new file mode 100644
index 00000000000000..bf55bde4b25b1f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.h
@@ -0,0 +1,35 @@
+//===- Passes.h - XeGPU Patterns and Passes ---------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+namespace xegpu {
+
+//===----------------------------------------------------------------------===//
+// Passes
+//===----------------------------------------------------------------------===//
+
+/// Declarations for the XeGPU passes defined in Passes.td (currently
+/// `xegpu-fold-alias-ops`), generated by `mlir-tblgen -gen-pass-decls`.
+#define GEN_PASS_DECL
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+/// Generated registration hooks, including `registerXeGPUPasses()`, which is
+/// called from `registerAllPasses()` in mlir/InitAllPasses.h.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+
+} // namespace xegpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
new file mode 100644
index 00000000000000..1ecd6ce95322bd
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -0,0 +1,26 @@
+//===-- Passes.td - XeGPU transformation definition file ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+
+#ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
+#define MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
+
+include "mlir/Pass/PassBase.td"
+
+def XeGPUFoldAliasOps : Pass<"xegpu-fold-alias-ops"> {
+  let summary = "Fold alias ops into XeGPU ops";
+  let description = [{
+    The pass folds aliasing ops into XeGPU ops so that the XeGPU ops operate
+    directly on the original source references.
+
+    Currently, a `memref.subview` that produces the source of an
+    `xegpu.create_nd_tdesc` is folded into the descriptor op. The descriptor
+    is recreated on the subview's source memref, and the subview offsets are
+    combined into the descriptor offsets.
+  }];
+  // Combining the subview and descriptor offsets materializes `affine.apply`
+  // ops. The affine dialect must therefore be loaded even when the input IR
+  // contains no affine ops.
+  let dependentDialects = [
+    "affine::AffineDialect", "memref::MemRefDialect", "xegpu::XeGPUDialect"
+  ];
+}
+
+#endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
new file mode 100644
index 00000000000000..63ea26df069372
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -0,0 +1,23 @@
+//===- Transforms.h - XeGPU Dialect transformations -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
+#define MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
+
+namespace mlir {
+// Forward declaration: only a reference to the pattern set is needed here.
+class RewritePatternSet;
+
+namespace xegpu {
+
+/// Appends patterns that fold aliasing ops into XeGPU ops to `patterns`.
+/// Currently this adds one pattern, which folds a `memref.subview` feeding an
+/// `xegpu.create_nd_tdesc` into the descriptor op. The `xegpu-fold-alias-ops`
+/// pass applies exactly these patterns.
+void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
+
+} // namespace xegpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 5d90c197a6cced..90406f555b0f47 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -45,6 +45,7 @@
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Transform/Transforms/Passes.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Transforms/Passes.h"
#include <cstdlib>
@@ -92,6 +93,7 @@ inline void registerAllPasses() {
arm_sme::registerArmSMEPasses();
arm_sve::registerArmSVEPasses();
emitc::registerEmitCPasses();
+ xegpu::registerXeGPUPasses();
// Dialect pipelines
bufferization::registerBufferizationPipelines();
diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
index f33061b2d87cff..9f57627c321fb0 100644
--- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
@@ -1 +1,2 @@
add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
new file mode 100644
index 00000000000000..7fb64d3b97b87d
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -0,0 +1,17 @@
+# XeGPU dialect transformations: currently the xegpu-fold-alias-ops pass and
+# its rewrite patterns.
+add_mlir_dialect_library(MLIRXeGPUTransforms
+ XeGPUFoldAliasOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
+
+ # The source includes the generated Passes.h.inc, so it must be built first.
+ DEPENDS
+ MLIRXeGPUPassIncGen
+
+ LINK_LIBS PUBLIC
+ # Affine view-like utilities (affine::resolveIndicesIntoOpWithOffsetsAndStrides).
+ MLIRAffineUtils
+ MLIRIR
+ MLIRMemRefDialect
+ MLIRXeGPUDialect
+ MLIRPass
+ # NOTE(review): needed for the greedy pattern rewrite driver; confirm whether
+ # the narrower MLIRTransformUtils library would suffice.
+ MLIRTransforms
+)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp
new file mode 100644
index 00000000000000..9307e8eb784b54
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp
@@ -0,0 +1,82 @@
+//===- XeGPUFoldAliasOps.cpp - XeGPU alias ops folders ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+// Emit the tablegen-generated pass base class
+// (xegpu::impl::XeGPUFoldAliasOpsBase) for the pass declared in Passes.td.
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUFOLDALIASOPS
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+// Debug-output tag and message prefix helper.
+// NOTE(review): neither DEBUG_TYPE nor DBGS() is referenced elsewhere in this
+// file; presumably kept for future LLVM_DEBUG tracing.
+#define DEBUG_TYPE "xegpu-fold-alias-ops"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+
+namespace {
+/// Rewrite pattern that folds a `memref.subview` feeding an
+/// `xegpu.create_nd_tdesc` into the descriptor op itself. The rewritten
+/// descriptor addresses the subview's source memref directly.
+struct XegpuCreateNdDescOpSubViewOpFolder final
+    : OpRewritePattern<xegpu::CreateNdDescOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp descOp,
+                                PatternRewriter &rewriter) const override;
+};
+} // namespace
+
+/// Folds a `memref.subview` producer into `xegpu.create_nd_tdesc`.
+///
+/// The descriptor is recreated on the subview's source memref. Its offsets are
+/// obtained by resolving the descriptor offsets through the subview's offsets
+/// and strides (`affine::resolveIndicesIntoOpWithOffsetsAndStrides`). The
+/// subview's dropped dims are passed along, so rank-reducing subviews are
+/// handled.
+///
+/// The pattern fails to match when:
+///   - the descriptor source is not produced by a subview,
+///   - the subview has non-unit strides, or
+///   - the subview's source memref is not statically shaped.
+LogicalResult XegpuCreateNdDescOpSubViewOpFolder::matchAndRewrite(
+    xegpu::CreateNdDescOp descOp, PatternRewriter &rewriter) const {
+  auto subViewOp = descOp.getSource().getDefiningOp<memref::SubViewOp>();
+
+  if (!subViewOp)
+    return rewriter.notifyMatchFailure(descOp, "not a subview producer");
+  if (!subViewOp.hasUnitStride())
+    return rewriter.notifyMatchFailure(descOp, "requires unit strides");
+
+  // The replacement is built with the offsets-only CreateNdDescOp builder.
+  // That builder passes no dynamic shape/stride operands and expects a
+  // statically shaped memref source. For dynamically shaped sources, bail out
+  // rather than create a descriptor that lacks the shape information it needs.
+  MemRefType sourceType = subViewOp.getSource().getType();
+  if (!sourceType.hasStaticShape())
+    return rewriter.notifyMatchFailure(descOp,
+                                       "requires statically shaped source");
+
+  SmallVector<Value> resolvedOffsets;
+  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+      rewriter, descOp.getLoc(), subViewOp.getMixedOffsets(),
+      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
+      descOp.getMixedOffsets(), resolvedOffsets);
+
+  // getAsOpFoldResult turns constant-valued offsets back into attributes, so
+  // offsets that are known statically stay static on the new op.
+  rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
+      descOp, descOp.getTensorDesc().getType(), subViewOp.getSource(),
+      getAsOpFoldResult(resolvedOffsets));
+
+  return success();
+}
+
+/// Registers the XeGPU alias-op folding patterns with `patterns`. Currently
+/// this is only the subview -> create_nd_tdesc folder.
+void xegpu::populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<XegpuCreateNdDescOpSubViewOpFolder>(context);
+}
+
+namespace {
+
+/// Pass wrapper that greedily applies the XeGPU alias-op folding patterns to
+/// the operation it is scheduled on.
+struct XeGPUFoldAliasOpsPass final
+    : xegpu::impl::XeGPUFoldAliasOpsBase<XeGPUFoldAliasOpsPass> {
+  void runOnOperation() override;
+};
+
+} // namespace
+
+/// Collects the alias-folding patterns and applies them greedily to the
+/// operation this pass runs on.
+void XeGPUFoldAliasOpsPass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+  RewritePatternSet foldPatterns(ctx);
+  xegpu::populateXeGPUFoldAliasOpsPatterns(foldPatterns);
+
+  // The greedy driver's LogicalResult is deliberately discarded, so this pass
+  // never signals failure.
+  Operation *root = getOperation();
+  (void)applyPatternsAndFoldGreedily(root, std::move(foldPatterns));
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-fold-alias-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-fold-alias-ops.mlir
new file mode 100644
index 00000000000000..d32954127fce61
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-fold-alias-ops.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt -xegpu-fold-alias-ops -split-input-file %s | FileCheck %s
+
+// A unit-stride memref.subview feeding xegpu.create_nd_tdesc is folded away.
+// The descriptor is rebuilt directly on %arg0, and each offset becomes the sum
+// of the subview offset and the original descriptor offset (an affine.apply of
+// s0 + s1).
+func.func @fold_subview_with_xegpu_create_nd_tdesc(%arg0 : memref<256x256xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) ->(!xegpu.tensor_desc<8x16xf32>) {
+ %subview = memref.subview %arg0[%arg1, %arg2] [32, 32] [1, 1] :
+ memref<256x256xf32> to memref<32x32xf32, strided<[256, 1], offset: ?>>
+ %0 = xegpu.create_nd_tdesc %subview[%arg3, %arg4] :
+ memref<32x32xf32, strided<[256, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+ return %0 : !xegpu.tensor_desc<8x16xf32>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: func @fold_subview_with_xegpu_create_nd_tdesc
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<256x256xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG3]]]
+// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK: xegpu.create_nd_tdesc %[[ARG0]][%[[IDX0]], %[[IDX1]]] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32>
More information about the Mlir-commits
mailing list