[Mlir-commits] [mlir] Introduce mlir `AMX` dialect extension (PR #107528)
Haixin Huang
llvmlistbot at llvm.org
Fri Sep 6 00:05:21 PDT 2024
https://github.com/huanghaixin008 created https://github.com/llvm/llvm-project/pull/107528
None
>From 454884ba8df56714bf45b122feab3ce180b27707 Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Fri, 14 Jun 2024 01:57:18 -0700
Subject: [PATCH 01/17] add AMX Ops extension for tmm register binding
---
mlir/include/mlir/Dialect/AMX/AMX.td | 139 ++++++++++++++++--
.../AMX/Transforms/LegalizeForLLVMExport.cpp | 7 +-
2 files changed, 132 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index fcc8d169eab5ac..1eeb5c074ca2e6 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -64,16 +64,36 @@ def AMX_Dialect : Dialect {
class AMX_Op<string mnemonic, list<Trait> traits = []> :
Op<AMX_Dialect, mnemonic, traits> {}
+class AMX_IntrOpBase<string mnemonic, int numResults,
+ list<int> immArgPositions = [],
+ list<string> immArgAttrNames = [],
+ list<Trait> traits = []>
+ : LLVM_IntrOpBase<
+ /*Dialect dialect=*/AMX_Dialect,
+ /*string opName=*/mnemonic,
+ /*string enumName=*/"x86_" # !subst(".", "_", mnemonic),
+ /*list<int> overloadedResults=*/[],
+ /*list<int> overloadedOperands=*/[],
+ /*list<Trait> traits=*/traits,
+ /*int numResults=*/numResults,
+ /*bit requiresAccessGroup=*/0,
+ /*bit requiresAliasAnalysis=*/0,
+ /*bit requiresFastmath=*/0,
+ /*list<int> immArgPositions=*/immArgPositions,
+ /*list<string> immArgAttrNames=*/immArgAttrNames>;
+
// The "internal" intrinsics are meant for compiler usage.
class AMX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
- LLVM_IntrOpBase<AMX_Dialect, mnemonic,
- "x86_" # !subst(".", "_", mnemonic) # "_internal",
- [], [], traits, numResults>;
+ AMX_IntrOp<mnemonic # "_internal", numResults, [], [], traits>;
//===----------------------------------------------------------------------===//
// AMX Op definitions (user facing).
//===----------------------------------------------------------------------===//
+
+def TileRegisterIndexAttr : OptionalAttr<
+    ConfinedAttr<I8Attr, [IntMinValue<0>, IntMaxValue<7>]>>; // tmm0..tmm7
+
//
// Tile reset.
//
@@ -91,6 +111,7 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
%0 = amx.tile_zero : vector<16x16xbf16>
```
}];
+ let arguments = (ins TileRegisterIndexAttr:$dstRegIndex);
let results = (outs
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
let extraClassDeclaration = [{
@@ -98,7 +119,7 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
return ::llvm::cast<VectorType>(getRes().getType());
}
}];
- let assemblyFormat = "attr-dict `:` type($res)";
+ let assemblyFormat = "(`[` $dstRegIndex^ `]`)? attr-dict `:` type($res)";
let hasVerifier = 1;
}
@@ -121,7 +142,7 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
```
}];
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
- Variadic<Index>:$indices);
+ Variadic<Index>:$indices, TileRegisterIndexAttr:$srcRegIndex);
let results = (outs
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
let extraClassDeclaration = [{
@@ -132,7 +153,7 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
return ::llvm::cast<VectorType>(getRes().getType());
}
}];
- let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
+ let assemblyFormat = "$base `[` $indices `]` attr-dict (`into` `[` $srcRegIndex^ `]`)? `:` "
"type($base) `into` type($res)";
let hasVerifier = 1;
}
@@ -153,7 +174,8 @@ def TileStoreOp : AMX_Op<"tile_store"> {
}];
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
Variadic<Index>:$indices,
- VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val);
+ VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val,
+ TileRegisterIndexAttr:$dstRegIndex);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -162,7 +184,7 @@ def TileStoreOp : AMX_Op<"tile_store"> {
return ::llvm::cast<VectorType>(getVal().getType());
}
}];
- let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
+ let assemblyFormat = "$base `[` $indices `]` `,` $val (`[` $dstRegIndex^ `]`)? attr-dict `:` "
"type($base) `,` type($val)";
let hasVerifier = 1;
}
@@ -189,7 +211,10 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
}];
let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]>:$lhs,
VectorOfRankAndType<[2], [F32, BF16]>:$rhs,
- VectorOfRankAndType<[2], [F32, BF16]>:$acc);
+ VectorOfRankAndType<[2], [F32, BF16]>:$acc,
+ TileRegisterIndexAttr:$lhsRegIndex,
+ TileRegisterIndexAttr:$rhsRegIndex,
+ TileRegisterIndexAttr:$accRegIndex);
let results = (outs VectorOfRankAndType<[2], [F32, BF16]>:$res);
let extraClassDeclaration = [{
VectorType getLhsVectorType() {
@@ -202,7 +227,9 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
return ::llvm::cast<VectorType>(getRes().getType());
}
}];
- let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
+ let assemblyFormat = "$lhs (`[` $lhsRegIndex^ `]`)? `,` "
+ "$rhs (`[` $rhsRegIndex^ `]`)? `,` "
+ "$acc (`[` $accRegIndex^ `]`)? attr-dict `:` "
"type($lhs) `,` type($rhs) `,` type($acc) ";
let hasVerifier = 1;
}
@@ -230,7 +257,10 @@ def TileMulIOp : AMX_Op<"tile_muli", [
VectorOfRankAndType<[2], [I32, I8]>:$rhs,
VectorOfRankAndType<[2], [I32, I8]>:$acc,
UnitAttr:$isZextLhs,
- UnitAttr:$isZextRhs
+ UnitAttr:$isZextRhs,
+ TileRegisterIndexAttr:$lhsRegIndex,
+ TileRegisterIndexAttr:$rhsRegIndex,
+ TileRegisterIndexAttr:$accRegIndex
);
let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
let extraClassDeclaration = [{
@@ -244,13 +274,15 @@ def TileMulIOp : AMX_Op<"tile_muli", [
return ::llvm::cast<VectorType>(getRes().getType());
}
}];
- let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
+ let assemblyFormat = "$lhs (`zext` $isZextLhs^)? (`[` $lhsRegIndex^ `]`)? `,` "
+ "$rhs (`zext` $isZextRhs^)? (`[` $rhsRegIndex^ `]`)? `,` "
+ "$acc (`[` $accRegIndex^ `]`)? attr-dict `:` "
"type($lhs) `,` type($rhs) `,` type($acc) ";
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
-// AMX IntrOp definitions (LLVM compiler facing).
+// AMX Internal IntrOp definitions (LLVM compiler facing).
//===----------------------------------------------------------------------===//
//
@@ -310,4 +342,85 @@ def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+//===----------------------------------------------------------------------===//
+// AMX IntrOp definitions (direct mapping to physical assembly).
+// The `dm` suffixes in Op names stand for `direct mapping`.
+//===----------------------------------------------------------------------===//
+
+//
+// Tile palette config operations. The parameter points to the palette buffer.
+//
+
+def LLVM_x86_amx_ldtilecfg_plain : AMX_IntrOpBase<"ldtilecfg", 0>,
+ Arguments<(int LLVM_AnyPointer)>;
+
+def LLVM_x86_amx_tilerelease_plain : AMX_IntrOpBase<"tilerelease", 0>;
+
+//
+// Tile reset. The parameter selects the tmm register to zero.
+//
+
+def LLVM_x86_amx_tilezero_plain : AMX_IntrOpBase<"tilezero", 0, [0], ["res_index"]>,
+ Arguments<(ins Arg<I8Attr, "Index of resulting tmm registers">:$res_index)>;
+
+//
+// Tile memory operations. Parameters define the tmm register index,
+// base address, and stride between consecutive rows for the
+// memory operation.
+//
+
+def LLVM_x86_amx_tileloadd64_plain : AMX_IntrOpBase<"tileloadd64", 0, [0], ["dst_index"]>,
+ Arguments<(ins Arg<I8Attr, "Index of dst tmm registers">:$dst_index,
+ LLVM_AnyPointer, AnyInteger)>;
+
+// Non-temporal load version
+def LLVM_x86_amx_tileloaddt164_plain : AMX_IntrOpBase<"tileloaddt164", 0, [0], ["dst_index"]>,
+ Arguments<(ins Arg<I8Attr, "Index of dst tmm registers">:$dst_index,
+ LLVM_AnyPointer, AnyInteger)>;
+
+def LLVM_x86_amx_tilestored64_plain : AMX_IntrOpBase<"tilestored64", 0, [0], ["src_index"]>,
+ Arguments<(ins Arg<I8Attr, "Index of src tmm registers">:$src_index,
+ LLVM_AnyPointer, AnyInteger)>;
+
+//
+// Tile multiplication operations (series of dot products). Parameters
+// select the destination and source tmm registers for the
+// operation. Note that the prefix "tdp" stands for tile dot product.
+//
+
+// Dot product of bf16 tiles into f32 tile.
+def LLVM_x86_amx_tdpbf16ps_plain : AMX_IntrOpBase<"tdpbf16ps", 0,
+ [0, 1, 2], ["dst_index", "lhs_index", "rhs_index"]>,
+ Arguments<(ins Arg<I8Attr, "Index of dst tmm registers">:$dst_index,
+ Arg<I8Attr, "Index of lhs tmm registers">:$lhs_index,
+ Arg<I8Attr, "Index of rhs tmm registers">:$rhs_index)>;
+
+// Dot product of i8 tiles into i32 tile (with sign/sign extension).
+def LLVM_x86_amx_tdpbssd_plain : AMX_IntrOpBase<"tdpbssd", 0,
+ [0, 1, 2], ["dst_index", "lhs_index", "rhs_index"]>,
+ Arguments<(ins Arg<I8Attr, "Index of dst tmm registers">:$dst_index,
+ Arg<I8Attr, "Index of lhs tmm registers">:$lhs_index,
+ Arg<I8Attr, "Index of rhs tmm registers">:$rhs_index)>;
+
+// Dot product of i8 tiles into i32 tile (with sign/zero extension).
+def LLVM_x86_amx_tdpbsud_plain : AMX_IntrOpBase<"tdpbsud", 0,
+    [0, 1, 2], ["dst_index", "lhs_index", "rhs_index"]>,
+  Arguments<(ins Arg<I8Attr, "Index of dst tmm registers">:$dst_index,
+                 Arg<I8Attr, "Index of lhs tmm registers">:$lhs_index,
+                 Arg<I8Attr, "Index of rhs tmm registers">:$rhs_index)>;
+
+// Dot product of i8 tiles into i32 tile (with zero/sign extension).
+def LLVM_x86_amx_tdpbusd_plain : AMX_IntrOpBase<"tdpbusd", 0,
+    [0, 1, 2], ["dst_index", "lhs_index", "rhs_index"]>,
+  Arguments<(ins Arg<I8Attr, "Index of dst tmm registers">:$dst_index,
+                 Arg<I8Attr, "Index of lhs tmm registers">:$lhs_index,
+                 Arg<I8Attr, "Index of rhs tmm registers">:$rhs_index)>;
+
+// Dot product of i8 tiles into i32 tile (with zero/zero extension).
+def LLVM_x86_amx_tdpbuud_plain : AMX_IntrOpBase<"tdpbuud", 0,
+    [0, 1, 2], ["dst_index", "lhs_index", "rhs_index"]>,
+  Arguments<(ins Arg<I8Attr, "Index of dst tmm registers">:$dst_index,
+                 Arg<I8Attr, "Index of lhs tmm registers">:$lhs_index,
+                 Arg<I8Attr, "Index of rhs tmm registers">:$rhs_index)>;
+
#endif // AMX
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index a8b10f63315d41..bfa3bbc83c80a5 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -211,7 +211,12 @@ void mlir::populateAMXLegalizeForLLVMExportPatterns(
void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
x86_amx_tdpbf16ps, x86_amx_tdpbssd, x86_amx_tdpbsud,
- x86_amx_tdpbusd, x86_amx_tdpbuud>();
+ x86_amx_tdpbusd, x86_amx_tdpbuud, x86_amx_ldtilecfg_plain,
+ x86_amx_tilerelease_plain, x86_amx_tilezero_plain,
+ x86_amx_tileloadd64_plain, x86_amx_tilestored64_plain,
+ x86_amx_tilestoreddt164_plain x86_amx_tdpbf16ps_plain,
+ x86_amx_tdpbssd_plain, x86_amx_tdpbsud_plain,
+ x86_amx_tdpbusd_plain, x86_amx_tdpbuud_plain>();
target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
TileMulFOp>();
}
>From 9b0e23a8c3e4b64d9096cf3c851e46d8b4c72575 Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Sun, 16 Jun 2024 22:57:12 -0700
Subject: [PATCH 02/17] add basic framework
---
mlir/include/mlir/Dialect/AMX/AMX.td | 6 +-
mlir/include/mlir/Dialect/AMX/CMakeLists.txt | 7 ++
mlir/include/mlir/Dialect/AMX/Passes.h | 37 ++++++++++
mlir/include/mlir/Dialect/AMX/Passes.td | 25 +++++++
.../lib/Dialect/AMX/Transforms/CMakeLists.txt | 1 +
.../AMX/Transforms/EnableAMXTileBinding.cpp | 68 +++++++++++++++++++
.../AMX/Transforms/LegalizeForLLVMExport.cpp | 4 +-
7 files changed, 143 insertions(+), 5 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/AMX/Passes.h
create mode 100644 mlir/include/mlir/Dialect/AMX/Passes.td
create mode 100644 mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index 1eeb5c074ca2e6..f5f242d687a0ab 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -84,7 +84,7 @@ class AMX_IntrOpBase<string mnemonic, int numResults,
// The "internal" intrinsics are meant for compiler usage.
class AMX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
- AMX_IntrOp<mnemonic # "_internal", numResults, [], [], traits>;
+ AMX_IntrOpBase<mnemonic # "_internal", numResults, [], [], traits>;
//===----------------------------------------------------------------------===//
// AMX Op definitions (user facing).
@@ -344,7 +344,7 @@ def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>,
//===----------------------------------------------------------------------===//
// AMX IntrOp definitions (direct mapping to physical assembly).
-// The `dm` suffixes in Op names stand for `direct mapping`.
+// The `plain` suffixes in Op names stand for plain direct mapping.
//===----------------------------------------------------------------------===//
//
@@ -352,7 +352,7 @@ def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>,
//
def LLVM_x86_amx_ldtilecfg_plain : AMX_IntrOpBase<"ldtilecfg", 0>,
- Arguments<(int LLVM_AnyPointer)>;
+ Arguments<(ins LLVM_AnyPointer)>;
def LLVM_x86_amx_tilerelease_plain : AMX_IntrOpBase<"tilerelease", 0>;
diff --git a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
index f3f1aff5a63609..d7a494caa22192 100644
--- a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
@@ -4,3 +4,10 @@ add_mlir_doc(AMX AMX Dialects/ -gen-dialect-doc -dialect=amx)
set(LLVM_TARGET_DEFINITIONS AMX.td)
mlir_tablegen(AMXConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRAMXConversionsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name AMX)
+add_public_tablegen_target(MLIRAMXTransformsIncGen)
+add_dependencies(mlir-headers MLIRAMXTransformsIncGen)
+
+add_mlir_doc(Passes AMXPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/AMX/Passes.h b/mlir/include/mlir/Dialect/AMX/Passes.h
new file mode 100644
index 00000000000000..9a3fd458eab380
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/Passes.h
@@ -0,0 +1,37 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AMX_PASSES_H
+#define MLIR_DIALECT_AMX_PASSES_H
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+class RewritePatternSet;
+
+namespace amx {
+//===----------------------------------------------------------------------===//
+// The EnableAMXTileBinding pass.
+//===----------------------------------------------------------------------===//
+#define GEN_PASS_DECL
+#include "mlir/Dialect/AMX/Passes.h.inc"
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/AMX/Passes.h.inc"
+
+} // namespace amx
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AMX_PASSES_H
diff --git a/mlir/include/mlir/Dialect/AMX/Passes.td b/mlir/include/mlir/Dialect/AMX/Passes.td
new file mode 100644
index 00000000000000..f3c7800b7d8730
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/Passes.td
@@ -0,0 +1,25 @@
+//===-- Passes.td - AMX pass definition file ------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AMX_PASSES_TD
+#define MLIR_DIALECT_AMX_PASSES_TD
+
+include "mlir/Pass/PassBase.td"
+
+def EnableAMXTileBinding
+ : Pass<"enable-amx-tile-binding", "mlir::func::FuncOp"> {
+ let summary = "Enable AMX tile register binding";
+ let description = [{
+ Enables AMX tile register binding for each func by propagating the
+ user-specified binding information and automatically configuring the
+ required hardware context.
+ }];
+ let dependentDialects = ["func::FuncDialect"];
+}
+
+#endif // MLIR_DIALECT_AMX_PASSES_TD
diff --git a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
index 29340d4f45dd1f..a0c1c10f3c1701 100644
--- a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRAMXTransforms
LegalizeForLLVMExport.cpp
+ EnableAMXTileBinding.cpp
DEPENDS
MLIRAMXConversionsIncGen
diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
new file mode 100644
index 00000000000000..74e25ca4550422
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -0,0 +1,68 @@
+//===- EnableAMXTileBinding.cpp - Enable tile binding for Intel AMX -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass enables tile register binding semantics for Intel® Advanced
+// Matrix Extensions (Intel® AMX). The pass analyzes the tile binding hints
+// set by users, legalizes them, and automatically configures the required
+// hardware context. The AMX tile register usage in the lowered intrinsics
+// strictly respects the given hints; this is enforced by the lowering pass
+// `--convert-vector-to-llvm`.
+//
+// Note that if this pass is not invoked prior to `--convert-vector-to-llvm`,
+// the AMX lowering ignores the binding info and falls back to the original
+// scheme.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/AMX/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+
+#define DEBUG_TYPE "enable-amx-tile-binding"
+
+namespace mlir {
+namespace amx {
+
+#define GEN_PASS_DEF_ENABLEAMXTILEBINDING
+#include "mlir/Dialect/AMX/Passes.h.inc"
+
+//===----------------------------------------------------------------------===//
+// Analysis
+//===----------------------------------------------------------------------===//
+
+/// A class for analyzing tile register binding for each tile vector.
+class TileBindingAnalysis {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileBindingAnalysis)
+ explicit TileBindingAnalysis(Operation *);
+};
+
+TileBindingAnalysis::TileBindingAnalysis(Operation *root) {}
+
+//===----------------------------------------------------------------------===//
+// Pass
+//===----------------------------------------------------------------------===//
+
+struct EnableAMXTileBindingPass
+ : public impl::EnableAMXTileBindingBase<EnableAMXTileBindingPass> {
+ void runOnOperation() override {
+ // 0. Get AnalyseInfo for each concerned Value (mixed used of tmul & normal
+ // vector operations?)
+ TileBindingAnalysis &analysis = getAnalysis<TileBindingAnalysis>();
+
+ // 1. Propagate binding info to AMX Ops
+ //
+ // 2. Analyse tile scopes & expand them maximally
+ //
+ // 3. insert tile config/release according to tile scopes
+ }
+};
+
+} // namespace amx
+} // namespace mlir
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index bfa3bbc83c80a5..5e2aa216cf4123 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -213,8 +213,8 @@ void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
x86_amx_tdpbf16ps, x86_amx_tdpbssd, x86_amx_tdpbsud,
x86_amx_tdpbusd, x86_amx_tdpbuud, x86_amx_ldtilecfg_plain,
x86_amx_tilerelease_plain, x86_amx_tilezero_plain,
- x86_amx_tileloadd64_plain, x86_amx_tilestored64_plain,
- x86_amx_tilestoreddt164_plain x86_amx_tdpbf16ps_plain,
+ x86_amx_tileloadd64_plain, x86_amx_tileloaddt164_plain,
+ x86_amx_tilestored64_plain, x86_amx_tdpbf16ps_plain,
x86_amx_tdpbssd_plain, x86_amx_tdpbsud_plain,
x86_amx_tdpbusd_plain, x86_amx_tdpbuud_plain>();
target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
>From 4554343bae6571d91ecd67c34c780714a8623f3a Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Mon, 24 Jun 2024 02:09:23 -0700
Subject: [PATCH 03/17] add binding info propagation
---
mlir/include/mlir/Dialect/AMX/AMX.td | 8 +-
.../AMX/Transforms/EnableAMXTileBinding.cpp | 244 +++++++++++++++++-
2 files changed, 242 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index f5f242d687a0ab..0c77b8b6ad90fc 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -142,7 +142,7 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
```
}];
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
- Variadic<Index>:$indices, TileRegisterIndexAttr:$srcRegIndex);
+ Variadic<Index>:$indices, TileRegisterIndexAttr:$dstRegIndex);
let results = (outs
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
let extraClassDeclaration = [{
@@ -153,7 +153,7 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
return ::llvm::cast<VectorType>(getRes().getType());
}
}];
- let assemblyFormat = "$base `[` $indices `]` attr-dict (`into` `[` $srcRegIndex^ `]`)? `:` "
+ let assemblyFormat = "$base `[` $indices `]` attr-dict (`into` `[` $dstRegIndex^ `]`)? `:` "
"type($base) `into` type($res)";
let hasVerifier = 1;
}
@@ -175,7 +175,7 @@ def TileStoreOp : AMX_Op<"tile_store"> {
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
Variadic<Index>:$indices,
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val,
- TileRegisterIndexAttr:$dstRegIndex);
+ TileRegisterIndexAttr:$srcRegIndex);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -184,7 +184,7 @@ def TileStoreOp : AMX_Op<"tile_store"> {
return ::llvm::cast<VectorType>(getVal().getType());
}
}];
- let assemblyFormat = "$base `[` $indices `]` `,` $val (`[` $dstRegIndex^ `]`)? attr-dict `:` "
+ let assemblyFormat = "$base `[` $indices `]` `,` $val (`[` $srcRegIndex^ `]`)? attr-dict `:` "
"type($base) `,` type($val)";
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
index 74e25ca4550422..0308002c27c908 100644
--- a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -21,6 +21,10 @@
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/AMX/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -36,28 +40,256 @@ namespace amx {
// Analysis
//===----------------------------------------------------------------------===//
-/// A class for analyzing tile register binding for each tile vector.
+/// A class for analyzing (propagating) tile register binding for each tile
+/// vector.
class TileBindingAnalysis {
+private:
+ bool isValidAnalysis;
+ DenseMap<Value, int> bindings;
+
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileBindingAnalysis)
explicit TileBindingAnalysis(Operation *);
+ bool isValid() const { return isValidAnalysis; }
+ void setValid(bool v) { isValidAnalysis = v; }
+ int getBinding(Value val) const {
+ auto iter = bindings.find(val);
+ if (iter == bindings.end())
+ return -1;
+ return iter->second;
+ }
+ void setBinding(Value val, int index) { bindings[val] = index; }
};
-TileBindingAnalysis::TileBindingAnalysis(Operation *root) {}
+static bool isTileOp(Operation *op) {
+  // Null-safe: getDefiningOp() returns nullptr for block arguments.
+  return op && llvm::isa<TileZeroOp, TileLoadOp, TileMulFOp, TileMulIOp,
+                         TileStoreOp>(op);
+}
+
+template <typename Op>
+static bool TileMulCheck(Operation *op) {
+ auto tile_mul = dyn_cast_or_null<Op>(op);
+ assert(tile_mul);
+
+ auto lhsOp = tile_mul.getLhs().getDefiningOp();
+ auto rhsOp = tile_mul.getRhs().getDefiningOp();
+ auto accOp = tile_mul.getAcc().getDefiningOp();
+ if (!isTileOp(lhsOp) || !isTileOp(rhsOp) || !isTileOp(accOp))
+ return false;
+ return true;
+}
+
+// Mixing tile Ops with normal vector Ops is not allowed; a tile Op whose
+// operands come from non-tile Ops is considered unacceptable.
+static bool isAcceptableTileOp(Operation *op) {
+ if (!isTileOp(op))
+ return false;
+
+ if (llvm::isa<TileMulFOp>(op)) {
+ return TileMulCheck<TileMulFOp>(op);
+ } else if (llvm::isa<TileMulIOp>(op)) {
+ return TileMulCheck<TileMulIOp>(op);
+ } else if (auto tileStore = dyn_cast_or_null<TileStoreOp>(op)) {
+ auto valOp = tileStore.getVal().getDefiningOp();
+ if (!isTileOp(valOp))
+ return false;
+ }
+ return true;
+}
+
+template <typename Op>
+static bool TileDstPropagate(TileBindingAnalysis *analysis, Operation *op) {
+ auto tileDst = dyn_cast_or_null<Op>(op);
+ assert(tileDst);
+ std::optional<int8_t> tmmIndex = tileDst.getDstRegIndex();
+ if (!tmmIndex) {
+ return false;
+ }
+ analysis->setBinding(tileDst.getRes(), *tmmIndex);
+ return true;
+}
+
+template <typename Op>
+static bool TileMulPropagate(TileBindingAnalysis *analysis, Operation *op) {
+ auto tileMul = dyn_cast_or_null<Op>(op);
+ assert(tileMul);
+ auto accVal = tileMul.getAcc();
+ auto accIndex = analysis->getBinding(accVal);
+ if (accIndex < 0)
+ return false;
+
+ analysis->setBinding(tileMul.getRes(), accIndex);
+ return true;
+}
+
+TileBindingAnalysis::TileBindingAnalysis(Operation *root) {
+ isValidAnalysis = false;
+ func::FuncOp func = dyn_cast_or_null<func::FuncOp>(root);
+ if (!func)
+ return;
+
+ isValidAnalysis = true;
+ func->walk<WalkOrder::PreOrder>([this](Operation *op) {
+ if (!isValidAnalysis)
+ return;
+ if (!isTileOp(op))
+ return;
+ if (!isAcceptableTileOp(op)) {
+ isValidAnalysis = false;
+ return;
+ }
+
+ if (llvm::isa<TileZeroOp>(op)) {
+ if (!TileDstPropagate<TileZeroOp>(this, op)) {
+ isValidAnalysis = false;
+ return;
+ }
+ } else if (llvm::isa<TileLoadOp>(op)) {
+ if (!TileDstPropagate<TileLoadOp>(this, op)) {
+ isValidAnalysis = false;
+ return;
+ }
+ } else if (llvm::isa<TileMulFOp>(op)) {
+ if (!TileMulPropagate<TileMulFOp>(this, op)) {
+ isValidAnalysis = false;
+ return;
+ }
+ } else if (llvm::isa<TileMulIOp>(op)) {
+ if (!TileMulPropagate<TileMulIOp>(this, op)) {
+ isValidAnalysis = false;
+ return;
+ }
+ }
+ });
+}
//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//
+class TileStoreBindingRewriter : public OpRewritePattern<TileStoreOp> {
+private:
+  TileBindingAnalysis &analysis;
+
+public:
+  using OpRewritePattern<TileStoreOp>::OpRewritePattern;
+
+  TileStoreBindingRewriter(MLIRContext *context, TileBindingAnalysis &ana)
+      : OpRewritePattern(context), analysis{ana} {}
+
+  LogicalResult matchAndRewrite(TileStoreOp op,
+                                PatternRewriter &rewriter) const final {
+    auto srcIndex = analysis.getBinding(op.getVal());
+    if (srcIndex < 0)
+      return failure();
+    // Skip conflicting hints and already-bound ops: re-binding an identical
+    // index on every visit would keep the greedy driver from converging.
+    if (op.getSrcRegIndex())
+      return failure();
+
+    rewriter.modifyOpInPlace(op, [&] {
+      op.setSrcRegIndexAttr(rewriter.getI8IntegerAttr(srcIndex));
+    });
+    return success();
+  }
+};
+
+class TileMulFBindingRewriter : public OpRewritePattern<TileMulFOp> {
+private:
+  TileBindingAnalysis &analysis;
+
+public:
+  using OpRewritePattern<TileMulFOp>::OpRewritePattern;
+
+  TileMulFBindingRewriter(MLIRContext *context, TileBindingAnalysis &ana)
+      : OpRewritePattern(context), analysis{ana} {}
+
+  LogicalResult matchAndRewrite(TileMulFOp op,
+                                PatternRewriter &rewriter) const final {
+    auto lhsIndex = analysis.getBinding(op.getLhs());
+    auto rhsIndex = analysis.getBinding(op.getRhs());
+    auto accIndex = analysis.getBinding(op.getAcc());
+    if (lhsIndex < 0 || rhsIndex < 0 || accIndex < 0)
+      return failure();
+    auto existingLhsIndex = op.getLhsRegIndex();
+    auto existingRhsIndex = op.getRhsRegIndex();
+    auto existingAccIndex = op.getAccRegIndex();
+    // Skip conflicting hints and already-bound ops (greedy convergence).
+    if ((existingLhsIndex && existingRhsIndex && existingAccIndex) ||
+        (existingLhsIndex && *existingLhsIndex != lhsIndex) ||
+        (existingRhsIndex && *existingRhsIndex != rhsIndex) ||
+        (existingAccIndex && *existingAccIndex != accIndex))
+      return failure();
+
+    // Update in place: the result Value is a key in `analysis`.
+    rewriter.modifyOpInPlace(op, [&] {
+      op.setLhsRegIndexAttr(rewriter.getI8IntegerAttr(lhsIndex));
+      op.setRhsRegIndexAttr(rewriter.getI8IntegerAttr(rhsIndex));
+      op.setAccRegIndexAttr(rewriter.getI8IntegerAttr(accIndex));
+    });
+    return success();
+  }
+};
+
+class TileMulIBindingRewriter : public OpRewritePattern<TileMulIOp> {
+private:
+  TileBindingAnalysis &analysis;
+
+public:
+  using OpRewritePattern<TileMulIOp>::OpRewritePattern;
+
+  TileMulIBindingRewriter(MLIRContext *context, TileBindingAnalysis &ana)
+      : OpRewritePattern(context), analysis{ana} {}
+
+  LogicalResult matchAndRewrite(TileMulIOp op,
+                                PatternRewriter &rewriter) const final {
+    auto lhsIndex = analysis.getBinding(op.getLhs());
+    auto rhsIndex = analysis.getBinding(op.getRhs());
+    auto accIndex = analysis.getBinding(op.getAcc());
+    if (lhsIndex < 0 || rhsIndex < 0 || accIndex < 0)
+      return failure();
+    auto existingLhsIndex = op.getLhsRegIndex();
+    auto existingRhsIndex = op.getRhsRegIndex();
+    auto existingAccIndex = op.getAccRegIndex();
+    // Skip conflicting hints and already-bound ops (greedy convergence).
+    if ((existingLhsIndex && existingRhsIndex && existingAccIndex) ||
+        (existingLhsIndex && *existingLhsIndex != lhsIndex) ||
+        (existingRhsIndex && *existingRhsIndex != rhsIndex) ||
+        (existingAccIndex && *existingAccIndex != accIndex))
+      return failure();
+
+    // Update in place: keeps the result Value and the zext unit attributes.
+    rewriter.modifyOpInPlace(op, [&] {
+      op.setLhsRegIndexAttr(rewriter.getI8IntegerAttr(lhsIndex));
+      op.setRhsRegIndexAttr(rewriter.getI8IntegerAttr(rhsIndex));
+      op.setAccRegIndexAttr(rewriter.getI8IntegerAttr(accIndex));
+    });
+    return success();
+  }
+};
+
struct EnableAMXTileBindingPass
: public impl::EnableAMXTileBindingBase<EnableAMXTileBindingPass> {
void runOnOperation() override {
- // 0. Get AnalyseInfo for each concerned Value (mixed used of tmul & normal
- // vector operations?)
+ // 0. Compute binding info for each relevant Value (mixing tile ops with
+ // normal vector ops is not allowed).
TileBindingAnalysis &analysis = getAnalysis<TileBindingAnalysis>();
+ if (!analysis.isValid())
+ return;
+
+ // 1. Set propagated binding info to AMX Ops
+ RewritePatternSet patterns(&getContext());
+ patterns.add<TileStoreBindingRewriter>(&getContext(), analysis);
+ patterns.add<TileMulFBindingRewriter>(&getContext(), analysis);
+ patterns.add<TileMulIBindingRewriter>(&getContext(), analysis);
+ FrozenRewritePatternSet patternSet(std::move(patterns));
+
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
+ analysis.setValid(false);
+ return;
+ }
- // 1. Propagate binding info to AMX Ops
- //
// 2. Analyse tile scopes & expand them maximally
//
// 3. insert tile config/release according to tile scopes
>From a44abcfa36b85e70dd26103310b453ba04aa30ef Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Wed, 26 Jun 2024 01:16:17 -0700
Subject: [PATCH 04/17] [WIP] add aux info collection
---
.../AMX/Transforms/EnableAMXTileBinding.cpp | 111 ++++++++++++++++--
1 file changed, 104 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
index 0308002c27c908..0dae119ecce85e 100644
--- a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -51,7 +51,7 @@ class TileBindingAnalysis {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileBindingAnalysis)
explicit TileBindingAnalysis(Operation *);
bool isValid() const { return isValidAnalysis; }
- void setValid(bool v) { isValidAnalysis = v; }
+ // void setValid(bool v) { isValidAnalysis = v; }
int getBinding(Value val) const {
auto iter = bindings.find(val);
if (iter == bindings.end())
@@ -164,6 +164,102 @@ TileBindingAnalysis::TileBindingAnalysis(Operation *root) {
});
}
+// NOTE(review): first draft of TileScopeAnalysis; PATCH 05/17 later in this
+// series rewrites most of these declarations. Issues visible in this revision:
+// - llvm::iterator_range takes a single iterator type parameter, so the
+// two-argument BlockSeg typedef presumably does not compile (PATCH 05 drops
+// the second argument).
+// - isScopeBreaker() returns true when `op` is NOT in `scopeBreaker`, which
+// reads inverted relative to its name (carried into PATCH 05 as
+// isParallelOp()).
+// - getTileUsage() takes no argument, so it cannot look up a specific Op.
+// A class for analyzing tile configuration domination (a.k.a. tile scope)
+class TileScopeAnalysis {
+private:
+ typedef llvm::iterator_range<Block::iterator, Block::iterator> BlockSeg;
+ typedef SmallVector<SmallVector<int, 2>, 8> Palette;
+ struct TileScope {
+ BlockSeg seg;
+ Palette palette;
+ };
+
+ bool isValidAnalysis;
+ // Storing Ops that would break tile context & scope (usually parallel Ops)
+ DenseSet<Operation *> scopeBreaker;
+ DenseMap<Operation *, BlockSeg> tileUsage;
+ SmallVector<TileScope, 10> tileScopes;
+
+ void addScopeBreaker(Operation *op) { scopeBreaker.insert(op); }
+ bool isScopeBreaker(Operation *op) {
+ return scopeBreaker.find(op) == scopeBreaker.end();
+ }
+
+ void setTileUsage(Operation *op, BlockSeg seg);
+ BlockSeg getTileUsage();
+ void doTileScope(Block &block);
+ void doTileScope(BlockSeg seg);
+ void doTileScope(Operation *op);
+
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileScopeAnalysis)
+ explicit TileScopeAnalysis(Operation *);
+ bool isValid() const { return isValidAnalysis; }
+};
+
+// Returns true if `op` is a TileZeroOp, TileLoadOp, TileMulFOp, TileMulIOp or
+// TileStoreOp, i.e. one of the AMX tile Ops this analysis tracks.
+static bool isTileOp(Operation *op) {
+ return llvm::isa<TileZeroOp>(op) || llvm::isa<TileLoadOp>(op) ||
+ llvm::isa<TileMulFOp>(op) || llvm::isa<TileMulIOp>(op) ||
+ llvm::isa<TileStoreOp>(op);
+}
+
+// Builds the tile scope analysis. Only a func::FuncOp root is supported; for
+// any other root the analysis stays invalid.
+// Step 0 marks every scf::ForallOp / scf::ParallelOp / omp::ParallelOp /
+// omp::WsloopOp, together with its ancestors below `root`, as a scope breaker.
+// Step 1 records, for each tile Op, the segment of its block from the Op to the
+// same-block ancestor of the last user returned by getUsers() (use-list order,
+// not necessarily program order); reaching a null parent invalidates the
+// analysis. Step 2 runs the recursive scoping over the function's entry block.
+// NOTE(review): the first walk lambda captures only `this` but reads the
+// constructor parameter `root`, which presumably does not compile; `[&]` or an
+// explicit `root` capture is needed.
+// NOTE(review): a tile Op with no users leaves `lastUser` null, and the code
+// then builds Block::iterator(nullptr) as the segment end — confirm intent.
+TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
+ isValidAnalysis = false;
+ func::FuncOp func = dyn_cast_or_null<func::FuncOp>(root);
+ if (!func)
+ return;
+
+ isValidAnalysis = true;
+ // 0. First walk to mark tile scope breakers
+ func->walk<WalkOrder::PostOrder>([this](Operation *op) {
+ if (!isScopeBreaker(op))
+ return;
+
+ if (llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op) ||
+ llvm::isa<omp::ParallelOp>(op) || llvm::isa<omp::WsloopOp>(op)) {
+ while (op != root) {
+ addScopeBreaker(op);
+ op = op->getParentOp();
+ }
+ }
+ });
+
+ // 1. Second walk to analyse usage scope for each tile Op
+ func->walk<WalkOrder::PreOrder>([this](Operation *op) {
+ if (!isValidAnalysis)
+ return;
+ if (!isTileOp(op))
+ return;
+ Operation *lastUser = nullptr;
+ for (auto user : op->getUsers())
+ lastUser = user;
+ while (lastUser && op->getBlock() != lastUser->getBlock()) {
+ lastUser = lastUser->getParentOp();
+ if (!lastUser)
+ isValidAnalysis = false;
+ }
+ setTileUsage(op, BlockSeg(Block::iterator(op), Block::iterator(lastUser)));
+ });
+ if (!isValidAnalysis)
+ return;
+
+ // 2. Tile scoping for each segmented region in a recursive manner
+ doTileScope(func.getRegion(0).front());
+}
+
+// Scopes a whole block by forwarding its full [begin, end) Op range to the
+// BlockSeg overload.
+void TileScopeAnalysis::doTileScope(Block &block) {
+ doTileScope(BlockSeg(block.begin(), block.end()));
+}
+
+// Placeholder in this revision: returns early for an empty segment, otherwise
+// loops over the segment without doing any work.
+// NOTE(review): Block::iterator is a bidirectional list iterator, so
+// `probe < seg.end()` presumably does not compile; PATCH 05 switches to `!=`.
+void TileScopeAnalysis::doTileScope(BlockSeg seg) {
+ if (seg.empty())
+ return;
+ for (auto probe = seg.begin(); probe < seg.end(); probe++) {
+ }
+}
+
+// Stub in this revision (empty body): per-Op scoping is filled in by later
+// patches of this series.
+void TileScopeAnalysis::doTileScope(Operation *op) {}
+
//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//
@@ -274,8 +370,8 @@ struct EnableAMXTileBindingPass
void runOnOperation() override {
// 0. Get AnalyseInfo for each concerned Value (Does not allow mixed used of
// tmul & normal vector operations)
- TileBindingAnalysis &analysis = getAnalysis<TileBindingAnalysis>();
- if (!analysis.isValid())
+ TileBindingAnalysis &bindingAna = getAnalysis<TileBindingAnalysis>();
+ if (!bindingAna.isValid())
return;
// 1. Set propagated binding info to AMX Ops
@@ -285,13 +381,14 @@ struct EnableAMXTileBindingPass
patterns.add<TileMulIBindingRewriter>(&getContext(), analysis);
FrozenRewritePatternSet patternSet(std::move(patterns));
- if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
- analysis.setValid(false);
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
return;
- }
// 2. Analyse tile scopes & expand them maximally
- //
+ TileScopeAnalysis &scopeAna = getAnalysis<TileScopeAnalysis>();
+ if (!scopeAna.isValid())
+ return;
+
// 3. insert tile config/release according to tile scopes
}
};
>From 769b74e3b41cdfff18ba5b089fa26e9ab269dcf6 Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Fri, 28 Jun 2024 02:35:02 -0700
Subject: [PATCH 05/17] [TBD] add tile scoping algorithm
---
.../AMX/Transforms/EnableAMXTileBinding.cpp | 336 +++++++++++++++++-
1 file changed, 319 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
index 0dae119ecce85e..87459540929853 100644
--- a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -167,27 +167,61 @@ TileBindingAnalysis::TileBindingAnalysis(Operation *root) {
// A class for analyzing tile configuration domination (a.k.a. tile scope)
class TileScopeAnalysis {
private:
- typedef llvm::iterator_range<Block::iterator, Block::iterator> BlockSeg;
- typedef SmallVector<SmallVector<int, 2>, 8> Palette;
+ typedef llvm::iterator_range<Block::iterator> BlockSeg;
+ // A list of 2-dim shapes representing tmm register shape, the length should
+ // always be 8
+ struct PaletteInfo {
+ bool overflow;
+ SmallVector<pair<int, int>, 8> palette;
+ PaletteInfo() {
+ palette.resize(8, {0, 0});
+ clear();
+ }
+ void clear();
+ bool isEmpty(int idx) {
+ return palette[idx].first == 0 && palette[idx].second == 0;
+ }
+ void set(int idx, pair<int, int> shape) { palette[idx] = shape; }
+ void merge(const PaletteInfo &rhs);
+ bool isConflict(const PaletteInfo &rhs);
+ };
struct TileScope {
+ // The BlockSeg here is inclusive (including end Op)
BlockSeg seg;
- Palette palette;
+ PaletteInfo pi;
+ TileScope() { clear(); }
+ void clear() { pi.clear(); }
};
bool isValidAnalysis;
- // Storing Ops that would break tile context & scope (usually parallel Ops)
- DenseSet<Operation *> scopeBreaker;
+ // Storing parallel Ops that would break tile context & scope
+ DenseSet<Operation *> parallelOps;
+ // Storing needed palette info for each concerned Op
+ DenseMap<Operation *, PaletteInfo> neededPalette;
+ // Storing the usage scope for each concerned tile Op
+ // The BlockSeg here is inclusive (including end Op)
DenseMap<Operation *, BlockSeg> tileUsage;
+ // Storing final tile scope results for injecting tilecfg/tilerelease
SmallVector<TileScope, 10> tileScopes;
- void addScopeBreaker(Operation *op) { scopeBreaker.insert(op); }
- bool isScopeBreaker(Operation *op) {
- return scopeBreaker.find(op) == scopeBreaker.end();
+ void addParallelOp(Operation *op) { parallelOps.insert(op); }
+ bool isParallelOp(Operation *op) {
+ return parallelOps.find(op) == parallelOps.end();
}
- void setTileUsage(Operation *op, BlockSeg seg);
- BlockSeg getTileUsage();
+ void setTileUsage(Operation *op, BlockSeg seg) { tileUsage[op] = seg; }
+
+ PaletteInfo collectRegionPalette(Region ®ion);
+ PaletteInfo collectPalette(Operation *op);
+ // Below two functions are the leaf functinos of recursive collection, will
+ // actually insert PaletteInfo into map storage
+ PaletteInfo collectPaletteForScf(Operation *op);
+ PaletteInfo collectPaletteForTile(Operation *op);
+ std::optional<PaletteInfo> getPalette(Operation *op);
+ std::optional<PaletteInfo> getPalette(BlockSeg seg);
+
void doTileScope(Block &block);
+ // The BlockSeg here is exclusive (excluding end Op)
void doTileScope(BlockSeg seg);
void doTileScope(Operation *op);
@@ -197,6 +231,51 @@ class TileScopeAnalysis {
bool isValid() const { return isValidAnalysis; }
};
+void TileScopeAnalysis::PaletteInfo::clear() {
+ // Back to "no tile register in use": no overflow, every shape {0, 0}.
+ // `palette` always holds 8 entries (sized in the constructor).
+ overflow = false;
+ for (auto &shape : palette)
+ shape = {0, 0};
+}
+
+void TileScopeAnalysis::PaletteInfo::merge(const PaletteInfo &rhs) {
+ // Union the register shapes of `rhs` into this palette. An incoming overflow,
+ // or two non-empty entries for the same register with different shapes
+ // (which cannot share one tile config), marks the result as overflowed.
+ if (overflow || rhs.overflow) {
+ overflow = true;
+ return;
+ }
+ for (int idx = 0; idx < 8; idx++) {
+ // `rhs` is const and isEmpty() is a non-const member, so inspect the
+ // entry directly instead of calling rhs.isEmpty(idx).
+ const auto &rhsShape = rhs.palette[idx];
+ if (rhsShape.first == 0 && rhsShape.second == 0)
+ continue;
+ if (isEmpty(idx)) {
+ palette[idx] = rhsShape;
+ } else if (palette[idx] != rhsShape) {
+ overflow = true;
+ break;
+ }
+ }
+}
+
+bool TileScopeAnalysis::PaletteInfo::isConflict(const PaletteInfo &rhs) {
+ // Two palettes conflict when either has overflowed, or when some register is
+ // used by both with different {rows, colBytes} shapes.
+ if (overflow || rhs.overflow)
+ return true;
+ for (int idx = 0; idx < 8; idx++) {
+ // `rhs` is const; test its entry directly rather than via isEmpty().
+ const auto &rhsShape = rhs.palette[idx];
+ bool rhsEmpty = rhsShape.first == 0 && rhsShape.second == 0;
+ if (!isEmpty(idx) && !rhsEmpty && palette[idx] != rhsShape)
+ return true;
+ }
+ // No conflicting register found. (The original fell off the end of this
+ // bool function here, which is undefined behavior.)
+ return false;
+}
+
+// Returns true for the scf Ops that collectPalette() hands to
+// collectPaletteForScf(): execute_region, for, forall, if, index_switch,
+// parallel and while (checked via their C++ op classes below).
+static bool isConcernedScfOp(Operation *op) {
+ return llvm::isa<scf::ExecuteRegionOp>(op) || llvm::isa<scf::ForOp>(op) ||
+ llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::IfOp>(op) ||
+ llvm::isa<scf::IndexSwitchOp>(op) || llvm::isa<scf::ParallelOp>(op) ||
+ llvm::isa<scf::WhileOp>(op);
+}
+
static bool isTileOp(Operation *op) {
return llvm::isa<TileZeroOp>(op) || llvm::isa<TileLoadOp>(op) ||
llvm::isa<TileMulFOp>(op) || llvm::isa<TileMulIOp>(op) ||
@@ -210,21 +289,24 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
return;
isValidAnalysis = true;
- // 0. First walk to mark tile scope breakers
+ // 0. First walk to mark parallel Ops
func->walk<WalkOrder::PostOrder>([this](Operation *op) {
- if (!isScopeBreaker(op))
+ if (!isParallelOp(op))
return;
if (llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op) ||
llvm::isa<omp::ParallelOp>(op) || llvm::isa<omp::WsloopOp>(op)) {
while (op != root) {
- addScopeBreaker(op);
+ addParallelOp(op);
op = op->getParentOp();
}
}
});
- // 1. Second walk to analyse usage scope for each tile Op
+ // 1. Second walk to collect needed palette for each concerned Op
+ collectNeededPalette(root);
+
+ // 2. Third walk to analyse usage scope for each tile Op
func->walk<WalkOrder::PreOrder>([this](Operation *op) {
if (!isValidAnalysis)
return;
@@ -243,22 +325,242 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
if (!isValidAnalysis)
return;
- // 2. Tile scoping for each segmented region in a recursive manner
+ // 3. Tile scoping for each segmented region in a recursive manner
doTileScope(func.getRegion(0).front());
}
+// Merges the palettes needed by every Op directly inside `block`.
+// NOTE(review): the in-class declaration takes `Region &`, while every caller
+// passes a Block (`.front()`); the declaration should be updated to match.
+TileScopeAnalysis::PaletteInfo
+TileScopeAnalysis::collectRegionPalette(Block &block) {
+ PaletteInfo pi;
+ // Block::getOps() needs an op-type template argument; iterate all Ops.
+ for (Operation &op : block)
+ pi.merge(collectPalette(&op));
+ return pi;
+}
+
+// Returns the palette needed by `op`, memoized in `neededPalette`. A
+// func::FuncOp is meant to be handled by merging the palette of its entry block
+// (not memoized); concerned scf Ops and tile Ops go to the leaf collectors;
+// any other Op needs no tile registers. Returns an empty PaletteInfo once the
+// analysis has been invalidated.
+// NOTE(review): the FuncOp check reads `root`, which is neither a parameter
+// nor a visible member here — presumably it should test `op`. The out-of-line
+// return type also needs to be spelled `TileScopeAnalysis::PaletteInfo`.
+PaletteInfo TileScopeAnalysis::collectPalette(Operation *op) {
+ if (!isValidAnalysis)
+ return PaletteInfo();
+ if (auto func = dyn_cast_or_null<func::FuncOp>(root))
+ // No need to store PaletteInfo for func
+ return collectRegionPalette(func.getRegion(0).front());
+
+ auto iter = neededPalette.find(op);
+ if (iter != neededPalette.end())
+ return iter->second;
+
+ // For now, we only concern certain control flow Ops and tile Ops
+ if (isConcernedScfOp(op))
+ return collectPaletteForScf(op);
+ if (isTileOp(op))
+ return collectPaletteForTile(op);
+ return PaletteInfo();
+}
+
+// Merges the palettes needed by the regions of a concerned scf Op and memoizes
+// the result in `neededPalette`. Any other Op needs nothing and is not recorded.
+TileScopeAnalysis::PaletteInfo
+TileScopeAnalysis::collectPaletteForScf(Operation *op) {
+ PaletteInfo pi;
+ if (llvm::isa<scf::ExecuteRegionOp>(op) || llvm::isa<scf::ForOp>(op) ||
+ llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op)) {
+ // NOTE(review): only the entry block is inspected; scf.execute_region may
+ // contain more than one block.
+ pi = collectRegionPalette(op->getRegion(0).front());
+ } else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
+ pi.merge(collectRegionPalette(ifOp.getThenRegion().front()));
+ // An scf.if without `else` has an empty else region: no front() block.
+ if (!ifOp.getElseRegion().empty())
+ pi.merge(collectRegionPalette(ifOp.getElseRegion().front()));
+ } else if (auto indexOp = dyn_cast<scf::IndexSwitchOp>(op)) {
+ pi = collectRegionPalette(indexOp.getDefaultRegion().front());
+ for (auto &caseRegion : indexOp.getCaseRegions())
+ pi.merge(collectRegionPalette(caseRegion.front()));
+ } else if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
+ pi.merge(collectRegionPalette(whileOp.getRegion(0).front()));
+ pi.merge(collectRegionPalette(whileOp.getRegion(1).front()));
+ } else {
+ // Not one of the concerned scf Ops: nothing to collect or memoize.
+ return PaletteInfo();
+ }
+ neededPalette[op] = pi;
+ return pi;
+}
+
+static inline pair<int, int> getPaletteShape(VectorType type) {
+ // Maps a tile vector type to its palette entry { rows, colBytes }: rows are
+ // shape[0], colBytes are shape[1] elements of 1 (i8), 2 (bf16) or 4 (i32/f32)
+ // bytes. Assumes rank >= 2; only the first two dims are read.
+ ArrayRef<int64_t> shape = type.getShape();
+ auto elementType = type.getElementType();
+ // Zero-initialized: under NDEBUG the assert below is compiled out, and an
+ // unsupported element type must not read an indeterminate typeSize.
+ int typeSize = 0;
+ if (elementType.isInteger(8))
+ typeSize = 1;
+ else if (elementType.isBF16())
+ typeSize = 2;
+ else if (elementType.isInteger(32) || elementType.isF32())
+ typeSize = 4;
+ else
+ assert(false && "Invalid type for palette");
+
+ // Palette shape is { rows, colBytes }
+ return {shape[0], shape[1] * typeSize};
+}
+
+// Builds the palette one tile Op needs from its tmm register bindings and
+// memoizes it in `neededPalette`. A missing or out-of-range (>= 8) register
+// index invalidates the analysis and yields an empty PaletteInfo.
+TileScopeAnalysis::PaletteInfo
+TileScopeAnalysis::collectPaletteForTile(Operation *op) {
+ if (!isTileOp(op))
+ return PaletteInfo();
+
+ PaletteInfo pi;
+ // Records one register binding; `index` is the Op's optional reg-index.
+ auto bind = [&](auto index, VectorType type) -> bool {
+ if (!index || *index >= 8)
+ return false;
+ pi.set(*index, getPaletteShape(type));
+ return true;
+ };
+ // tmul Ops bind three registers; all must be assigned. Uses the op it is
+ // given (the old macro always read `tileMulFOp`, which is null for muli).
+ auto bindMul = [&](auto mulOp) -> bool {
+ auto lhsIndex = mulOp.getLhsRegIndex();
+ auto rhsIndex = mulOp.getRhsRegIndex();
+ auto accIndex = mulOp.getAccRegIndex();
+ if (!lhsIndex || !rhsIndex || !accIndex)
+ return false;
+ return bind(lhsIndex, mulOp.getLhsVectorType()) &&
+ bind(rhsIndex, mulOp.getRhsVectorType()) &&
+ bind(accIndex, mulOp.getAccVectorType());
+ };
+
+ bool bound = false;
+ if (auto tileLoadOp = dyn_cast<TileLoadOp>(op))
+ bound = bind(tileLoadOp.getDstRegIndex(), tileLoadOp.getVectorType());
+ else if (auto tileMulFOp = dyn_cast<TileMulFOp>(op))
+ bound = bindMul(tileMulFOp);
+ else if (auto tileMulIOp = dyn_cast<TileMulIOp>(op))
+ bound = bindMul(tileMulIOp);
+ else if (auto tileStoreOp = dyn_cast<TileStoreOp>(op))
+ bound = bind(tileStoreOp.getSrcRegIndex(), tileStoreOp.getVectorType());
+ else if (auto tileZeroOp = dyn_cast<TileZeroOp>(op))
+ bound = bind(tileZeroOp.getDstRegIndex(), tileZeroOp.getVectorType());
+
+ if (!bound) {
+ isValidAnalysis = false;
+ return PaletteInfo();
+ }
+ neededPalette[op] = pi;
+ return pi;
+}
+
+// Returns the memoized palette of `op`, or std::nullopt if none was collected.
+std::optional<TileScopeAnalysis::PaletteInfo>
+TileScopeAnalysis::getPalette(Operation *op) {
+ auto iter = neededPalette.find(op);
+ if (iter == neededPalette.end())
+ return std::nullopt;
+ return iter->second;
+}
+
+// Merges the memoized palettes of every Op in `seg` (iterated as a half-open
+// range); returns std::nullopt when none of them has a recorded palette.
+std::optional<TileScopeAnalysis::PaletteInfo>
+TileScopeAnalysis::getPalette(BlockSeg seg) {
+ bool hasPaletteInfo = false;
+ PaletteInfo pi;
+ for (Operation &op : seg) {
+ if (auto opPalette = getPalette(&op)) {
+ hasPaletteInfo = true;
+ pi.merge(*opPalette);
+ }
+ }
+ // `cond ? pi : std::nullopt` has no common type, so branch explicitly.
+ if (!hasPaletteInfo)
+ return std::nullopt;
+ return pi;
+}
+
void TileScopeAnalysis::doTileScope(Block &block) {
doTileScope(BlockSeg(block.begin(), block.end()));
}
void TileScopeAnalysis::doTileScope(BlockSeg seg) {
+ if (!isValidAnalysis)
+ return;
if (seg.empty())
return;
- for (auto probe = seg.begin(); probe < seg.end(); probe++) {
+ SmallVector<BlockSeg, 3> blockSegs;
+ SmallVector<Operation *, 3> paraOps;
+ auto currBegin = seg.begin();
+ for (auto probe = seg.begin(); probe != seg.end(); probe++) {
+ Operation *op = &(*probe);
+ if (isParallelOp(op) {
+ blockSegs.push_back(BlockSeg(currBegin, probe));
+ paraOps.push_back(op);
+ currBegin = probe;
+ currBegin++;
+ }
}
+ if (breakers.size()) {
+ assert(blockSegs.size() == paraOps.size());
+ for (int idx = 0; idx < paraOps.size(); idx++) {
+ doTileScope(blockSegs[idx]);
+ doTileScope(paraOps[idx]);
+ }
+ doTileScope(BlockSeg(currBegin, seg.end()));
+ return;
+ }
+
+ // Do tile scope on parallel-free BlockSeg
+ TileScope currScope;
+ std::optional<Block::iterator> currSegStart;
+ Block::iterator currIter = seg.begin();
+ // Traverse BlockSeg and greedily do tile scoping without look ahead
+ while (currIter != seg.end()) {
+ Operation *currOp = &(*currIter);
+ if (!currSegStart)
+ currSegStart = currIter;
+
+ Block::iterator nextIterIfMerge;
+ std::optional<PaletteInfo> pi = std::null_opt;
+ if (isConcernedScfOp(currOp)) {
+ pi = getPalette(currOp);
+ nextIterIfMerge = currIter;
+ nextIterIfMerge++;
+ } else if (isTileOp(currOp)) {
+ auto iter = tileUsage.find(currOp);
+ if (iter == tileUsage.end()) {
+ isValidAnalysis = false;
+ return;
+ }
+ pi = getPalette(iter->second);
+ if (pi && pi->overflow) {
+ // This means the binding info exceeds the hardware capability
+ isValidAnalysis = false;
+ return;
+ }
+ nextIterIfMerge = iter->second.end();
+ nextIterIfMerge++;
+ }
+ if (!pi) {
+ currIter++;
+ continue;
+ }
+
+#define ADD_PREVIOUS_SCOPE() \
+ auto prevIter = currIter; \
+ prevIter--; \
+ currScope.seg = BlockSeg(*currSegStart, prevIter); \
+ tileScopes.push_back(currScope); \
+ currScope.clear(); \
+ currSegStart = std::null_opt;
+
+ if (pi->overflow) {
+ // Only scf Ops could go through this possibility
+ if (currSegStart && *currSegStart != currIter) {
+ ADD_PREVIOUS_SCOPE()
+ }
+ doTileScope(currOp);
+ currIter++;
+ } else {
+ if (currScope.pi.isConflict(*pi)) {
+ ADD_PREVIOUS_SCOPE();
+ currIter++;
+ } else {
+ currScope.pi.merge(*pi);
+ currIter = nextIterIfMerge;
+ }
+ }
+ }
+
+ ADD_PREVIOUS_SCOPE();
}
-void TileScopeAnalysis::doTileScope(Operation *op) {}
+void TileScopeAnalysis::doTileScope(Operation *op) {
+ // TODO: do tile scope for scf here
+}
//===----------------------------------------------------------------------===//
// Pass
>From 986064a75631b610c242887dad759413733bfd7e Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Mon, 1 Jul 2024 02:26:35 -0700
Subject: [PATCH 06/17] finish tile scoping procedure
---
.../AMX/Transforms/EnableAMXTileBinding.cpp | 129 ++++++++++++++----
1 file changed, 106 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
index 87459540929853..24586ac966a728 100644
--- a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -186,7 +186,7 @@ class TileScopeAnalysis {
bool isConflict(const PaletteInfo &rhs);
};
struct TileScope {
- // The BlockSeg here is inclusive (including end Op)
+ // The BlockSeg here is inclusive (containing `end` Op)
BlockSeg seg;
PaletteInfo pi;
TileScope() { clear(); }
@@ -199,7 +199,7 @@ class TileScopeAnalysis {
// Storing needed palette info for each concerned Op
DenseMap<Operation *, PaletteInfo> neededPalette;
// Storing the usage scope for each concerned tile Op
- // The BlockSeg here is inclusive (including end Op)
+ // The BlockSeg here is inclusive (containing `end` Op)
DenseMap<Operation *, BlockSeg> tileUsage;
// Storing final tile scope results for injecting tilecfg/tilerelease
SmallVector<TileScope, 10> tileScopes;
@@ -221,7 +221,7 @@ class TileScopeAnalysis {
std::optional<PaletteInfo> getPalette(BlockSeg seg);
void doTileScope(Block &block);
- // The BlockSeg here is exclusive (excluding end Op)
+ // The input BlockSeg here is exclusive (not containing `end` Op)
void doTileScope(BlockSeg seg);
void doTileScope(Operation *op);
@@ -269,7 +269,8 @@ bool TileScopeAnalysis::PaletteInfo::isConflict(const PaletteInfo &rhs) {
}
}
-static bool isConcernedScfOp(Operation *op) {
+// Currently we only operate on scf Ops
+static bool isConcernedControlFlowOp(Operation *op) {
return llvm::isa<scf::ExecuteRegionOp>(op) || llvm::isa<scf::ForOp>(op) ||
llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::IfOp>(op) ||
llvm::isa<scf::IndexSwitchOp>(op) || llvm::isa<scf::ParallelOp>(op) ||
@@ -294,8 +295,7 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
if (!isParallelOp(op))
return;
- if (llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op) ||
- llvm::isa<omp::ParallelOp>(op) || llvm::isa<omp::WsloopOp>(op)) {
+ if (llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op)) {
while (op != root) {
addParallelOp(op);
op = op->getParentOp();
@@ -348,7 +348,7 @@ PaletteInfo TileScopeAnalysis::collectPalette(Operation *op) {
return iter->second;
// For now, we only concern certain control flow Ops and tile Ops
- if (isConcernedScfOp(op))
+ if (isConcernedControlFlowOp(op))
return collectPaletteForScf(op);
if (isTileOp(op))
return collectPaletteForTile(op);
@@ -505,7 +505,7 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
Block::iterator nextIterIfMerge;
std::optional<PaletteInfo> pi = std::null_opt;
- if (isConcernedScfOp(currOp)) {
+ if (isConcernedControlFlowOp(currOp)) {
pi = getPalette(currOp);
nextIterIfMerge = currIter;
nextIterIfMerge++;
@@ -517,7 +517,8 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
}
pi = getPalette(iter->second);
if (pi && pi->overflow) {
- // This means the binding info exceeds the hardware capability
+ // This means the binding info in tile Ops exceeds the hardware
+ // capability
isValidAnalysis = false;
return;
}
@@ -529,24 +530,26 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
continue;
}
-#define ADD_PREVIOUS_SCOPE() \
- auto prevIter = currIter; \
- prevIter--; \
- currScope.seg = BlockSeg(*currSegStart, prevIter); \
- tileScopes.push_back(currScope); \
- currScope.clear(); \
- currSegStart = std::null_opt;
+#define TRY_ADD_PREVIOUS_SCOPE() \
+ if (currSegStart && *currSegStart != currIter) { \
+ auto prevIter = currIter; \
+ prevIter--; \
+ currScope.seg = BlockSeg(*currSegStart, prevIter); \
+ tileScopes.push_back(currScope); \
+ currScope.clear(); \
+ currSegStart = std::null_opt; \
+ }
if (pi->overflow) {
// Only scf Ops could go through this possibility
- if (currSegStart && *currSegStart != currIter) {
- ADD_PREVIOUS_SCOPE()
- }
+ TRY_ADD_PREVIOUS_SCOPE();
doTileScope(currOp);
currIter++;
} else {
if (currScope.pi.isConflict(*pi)) {
- ADD_PREVIOUS_SCOPE();
+ TRY_ADD_PREVIOUS_SCOPE();
+ currScope.pi = *pi;
+ currSegStart = currIter;
currIter++;
} else {
currScope.pi.merge(*pi);
@@ -555,11 +558,39 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
}
}
- ADD_PREVIOUS_SCOPE();
+ TRY_ADD_PREVIOUS_SCOPE();
}
void TileScopeAnalysis::doTileScope(Operation *op) {
- // TODO: do tile scope for scf here
+ // This func try to collect tile scopes for a single control flow Op
+ // This func is not for tile Ops
+ if (isTileOp(op))
+ return;
+ // Ops that invoke this func are either parallelOps or scfOps with overflowed
+ // paletteInfo, and neither of them can form a tile scope by itself, so we
+ // omit checking self-formed tile scope in this func
+ if (llvm::isa<scf::ExecuteRegionOp>(op) || llvm::isa<scf::ForOp>(op) ||
+ llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op)) {
+ auto &block = op->getRegion(0).front();
+ doTileScope(BlockSeg(block.begin(), block.end()));
+ } else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
+ auto &ifBlock = op->getThenRegion().front();
+ auto &elseBlock = op->getElseRegion().front();
+ doTileScope(BlockSeg(ifBlock.begin(), ifBlock.end()));
+ doTileScope(BlockSeg(elseBlock.begin(), elseBlock.end()));
+ } else if (auto indexOp = dyn_cast<scf::IndexSwitchOp>(op)) {
+ auto &defaultBlock = indexOp.getDefaultRegion().front();
+ doTileScope(BlockSeg(defaultBlock.begin(), defaultBlock.end()));
+ for (auto &caseRegion : indexOp.getCaseRegions()) {
+ auto &caseBlock = indexOp.getDefaultRegion().front();
+ doTileScope(BlockSeg(caseBlock.begin(), caseBlock.end()));
+ }
+ } else if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
+ auto &beforeBlock = whileOp.getRegion(0).front();
+ auto &afterBlock = whileOp.getRegion(1).front();
+ doTileScope(BlockSeg(beforeBlock.begin(), beforeBlock.end()));
+ doTileScope(BlockSeg(afterBlock.begin(), afterBlock.end()));
+ }
}
//===----------------------------------------------------------------------===//
@@ -669,7 +700,40 @@ class TileMulIBindingRewriter : public OpRewritePattern<TileMulIOp> {
struct EnableAMXTileBindingPass
: public impl::EnableAMXTileBindingBase<EnableAMXTileBindingPass> {
+private:
+ bool isViableTileOps() {
+ // True when the pass root is a func::FuncOp and every tile Op inside it is
+ // nested (up to the func) only in concerned scf control-flow Ops.
+ Operation *root = getOperation();
+ auto func = dyn_cast<func::FuncOp>(root);
+ if (!func)
+ return false;
+
+ bool isViable = true;
+ // Capture by reference: the lambda reads `root` and updates `isViable`
+ // (a `[this]`-only capture does not make these locals visible).
+ func->walk<WalkOrder::PreOrder>([&](Operation *op) {
+ if (!isViable || !isTileOp(op))
+ return;
+ for (Operation *probe = op->getParentOp(); probe != root;
+ probe = probe->getParentOp()) {
+ if (!isConcernedControlFlowOp(probe)) {
+ isViable = false;
+ break;
+ }
+ }
+ });
+ return isViable;
+ }
+
+ // TODO: build (or reuse) an LLVM global holding the tile config for `pi`.
+ // Until then return a null op instead of flowing off the end of a non-void
+ // function (UB); the caller asserts on a null result.
+ // NOTE(review): PaletteInfo is a private nested type of TileScopeAnalysis and
+ // needs to be exposed/qualified to be usable here.
+ LLVM::GlobalOp getOrCreateGlobalPalette(const PaletteInfo &pi) { return {}; }
+
+public:
void runOnOperation() override {
+ // Ensure that tile Ops are not wrapped by out-of-scope Ops, else cannot do
+ // enabling
+ if (!isViableTileOps())
+ return;
+
// 0. Get AnalyseInfo for each concerned Value (Does not allow mixed used of
// tmul & normal vector operations)
TileBindingAnalysis &bindingAna = getAnalysis<TileBindingAnalysis>();
@@ -691,7 +755,26 @@ struct EnableAMXTileBindingPass
if (!scopeAna.isValid())
return;
- // 3. insert tile config/release according to tile scopes
+ // 3. Insert tile config/release according to tile scopes
+ OpBuilder builder(getOperation());
+ for (auto &scope : tileScopes) {
+ assert(!scope.pi.overflow && "Expecting legal AMX palette info");
+ auto paletteGlobal = getOrCreateGlobalPalette(scope.pi);
+ assert(paletteGlobal && "Failed to create global palette");
+
+ Operation *begin = &(*scope.seg.begin());
+ Loc loc = begin->getLoc();
+
+ builder.setInsertionPoint(begin);
+ Value paletteGlobalPtr =
+ builder.create<LLVM::AddressOfOp>(loc, paletteGlobal);
+ builder.create<amx::x86_amx_ldtilecfg_plain>(loc, paletteGlobalPtr);
+
+ Operation *end = &(*scope.seg.end());
+ loc = end->getLoc();
+ builder.setInsertionPointAfter(end);
+ builder.create<amx::x86_amx_tilerelease_plain>(loc);
+ }
}
};
>From 5faad5a944534894f24eb1d95be070aa1181a040 Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Tue, 2 Jul 2024 00:25:42 -0700
Subject: [PATCH 07/17] add tile Ops lowering with binding
---
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 3 +-
.../AMX/Transforms/EnableAMXTileBinding.cpp | 121 ++++++++++-----
.../AMX/Transforms/LegalizeForLLVMExport.cpp | 138 ++++++++++++++++--
3 files changed, 214 insertions(+), 48 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 55143d5939ba25..c3fd9399a6c9e6 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -105,8 +105,9 @@ void LowerVectorToLLVMPass::runOnOperation() {
populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
}
if (amx) {
+ auto &analysis = getCachedAnalysis<TileScopeAnalysis>();
configureAMXLegalizeForExportTarget(target);
- populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
+ populateAMXLegalizeForLLVMExportPatterns(converter, analysis, patterns);
}
if (x86Vector) {
configureX86VectorLegalizeForExportTarget(target);
diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
index 24586ac966a728..54a609fa783e04 100644
--- a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -51,7 +51,6 @@ class TileBindingAnalysis {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileBindingAnalysis)
explicit TileBindingAnalysis(Operation *);
bool isValid() const { return isValidAnalysis; }
- // void setValid(bool v) { isValidAnalysis = v; }
int getBinding(Value val) const {
auto iter = bindings.find(val);
if (iter == bindings.end())
@@ -81,7 +80,7 @@ static bool TileMulCheck(Operation *op) {
}
// Not allow mixed use of tile Ops and normal vector Ops, any mixing is
-// considered unacceptable
+// considered unacceptable.
static bool isAcceptableTileOp(Operation *op) {
if (!isTileOp(op))
return false;
@@ -164,12 +163,12 @@ TileBindingAnalysis::TileBindingAnalysis(Operation *root) {
});
}
-// A class for analyzing tile configuration domination (a.k.a. tile scope)
+// A class for analyzing tile configuration domination (a.k.a. tile scope).
class TileScopeAnalysis {
private:
typedef llvm::iterator_range<Block::iterator> BlockSeg;
- // A list of 2-dim shapes representing tmm register shape, the length should
- // always be 8
+ // A list of 2-dim {rows x colBytes} shapes representing tmm register shape,
+ // the length should always be 8.
struct PaletteInfo {
bool overflow;
SmallVector<pair<int, int>, 8> palette;
@@ -186,7 +185,7 @@ class TileScopeAnalysis {
bool isConflict(const PaletteInfo &rhs);
};
struct TileScope {
- // The BlockSeg here is inclusive (containing `end` Op)
+ // The BlockSeg here is inclusive (containing `end` Op).
BlockSeg seg;
PaletteInfo pi;
TileScope() { clear(); }
@@ -194,14 +193,14 @@ class TileScopeAnalysis {
};
bool isValidAnalysis;
- // Storing parallel Ops that would break tile context & scope
+ // Storing parallel Ops that would break tile context & scope.
DenseSet<Operation *> parallelOps;
- // Storing needed palette info for each concerned Op
+ // Storing needed palette info for each concerned Op.
DenseMap<Operation *, PaletteInfo> neededPalette;
- // Storing the usage scope for each concerned tile Op
- // The BlockSeg here is inclusive (containing `end` Op)
+ // Storing the usage scope for each concerned tile Op.
+ // The BlockSeg here is inclusive (containing `end` Op).
DenseMap<Operation *, BlockSeg> tileUsage;
- // Storing final tile scope results for injecting tilecfg/tilerelease
+ // Storing final tile scope results for injecting tilecfg/tilerelease.
SmallVector<TileScope, 10> tileScopes;
void addParallelOp(Operation *op) { parallelOps.insert(op); }
@@ -214,14 +213,14 @@ class TileScopeAnalysis {
PaletteInfo collectRegionPalette(Region ®ion);
PaletteInfo collectPalette(Operation *op);
// Below two functions are the leaf functinos of recursive collection, will
- // actually insert PaletteInfo into map storage
+ // actually insert PaletteInfo into map storage.
PaletteInfo collectPaletteForScf(Operation *op);
PaletteInfo collectPaletteForTile(Operation *op);
std::optional<PaletteInfo> getPalette(Operation *op);
std::optional<PaletteInfo> getPalette(BlockSeg seg);
void doTileScope(Block &block);
- // The input BlockSeg here is exclusive (not containing `end` Op)
+ // The input BlockSeg here is exclusive (not containing `end` Op).
void doTileScope(BlockSeg seg);
void doTileScope(Operation *op);
@@ -269,7 +268,7 @@ bool TileScopeAnalysis::PaletteInfo::isConflict(const PaletteInfo &rhs) {
}
}
-// Currently we only operate on scf Ops
+// Currently we only operate on scf Ops.
static bool isConcernedControlFlowOp(Operation *op) {
return llvm::isa<scf::ExecuteRegionOp>(op) || llvm::isa<scf::ForOp>(op) ||
llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::IfOp>(op) ||
@@ -290,7 +289,7 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
return;
isValidAnalysis = true;
- // 0. First walk to mark parallel Ops
+ // 0. First walk to mark parallel Ops.
func->walk<WalkOrder::PostOrder>([this](Operation *op) {
if (!isParallelOp(op))
return;
@@ -303,10 +302,10 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
}
});
- // 1. Second walk to collect needed palette for each concerned Op
+ // 1. Second walk to collect needed palette for each concerned Op.
collectNeededPalette(root);
- // 2. Third walk to analyse usage scope for each tile Op
+ // 2. Third walk to analyse usage scope for each tile Op.
func->walk<WalkOrder::PreOrder>([this](Operation *op) {
if (!isValidAnalysis)
return;
@@ -325,7 +324,7 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
if (!isValidAnalysis)
return;
- // 3. Tile scoping for each segmented region in a recursive manner
+ // 3. Tile scoping for each segmented region in a recursive manner.
doTileScope(func.getRegion(0).front());
}
@@ -340,14 +339,14 @@ PaletteInfo TileScopeAnalysis::collectPalette(Operation *op) {
if (!isValidAnalysis)
return PaletteInfo();
if (auto func = dyn_cast_or_null<func::FuncOp>(root))
- // No need to store PaletteInfo for func
+ // No need to store PaletteInfo for func.
return collectRegionPalette(func.getRegion(0).front());
auto iter = neededPalette.find(op);
if (iter != neededPalette.end())
return iter->second;
- // For now, we only concern certain control flow Ops and tile Ops
+ // For now, we only concern certain control flow Ops and tile Ops.
if (isConcernedControlFlowOp(op))
return collectPaletteForScf(op);
if (isTileOp(op))
@@ -396,7 +395,7 @@ static inline pair<int, int> getPaletteShape(VectorType type) {
else
assert(false && "Invalid type for palette");
- // Palette shape is { rows, colBytes }
+ // Palette shape is { rows, colBytes }.
return {shape[0], shape[1] * typeSize};
}
@@ -493,11 +492,11 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
return;
}
- // Do tile scope on parallel-free BlockSeg
+ // Do tile scope on parallel-free BlockSeg.
TileScope currScope;
std::optional<Block::iterator> currSegStart;
Block::iterator currIter = seg.begin();
- // Traverse BlockSeg and greedily do tile scoping without look ahead
+ // Traverse BlockSeg and greedily do tile scoping without look ahead.
while (currIter != seg.end()) {
Operation *currOp = &(*currIter);
if (!currSegStart)
@@ -518,7 +517,7 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
pi = getPalette(iter->second);
if (pi && pi->overflow) {
// This means the binding info in tile Ops exceeds the hardware
- // capability
+ // capability.
isValidAnalysis = false;
return;
}
@@ -541,7 +540,7 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
}
if (pi->overflow) {
- // Only scf Ops could go through this possibility
+ // Only scf Ops could go through this possibility.
TRY_ADD_PREVIOUS_SCOPE();
doTileScope(currOp);
currIter++;
@@ -563,12 +562,12 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
void TileScopeAnalysis::doTileScope(Operation *op) {
// This func try to collect tile scopes for a single control flow Op
- // This func is not for tile Ops
+ // This func is not for tile Ops.
if (isTileOp(op))
return;
// Ops that invoke this func are either parallelOps or scfOps with overflowed
// paletteInfo, and neither of them can form a tile scope by itself, so we
- // omit checking self-formed tile scope in this func
+ // omit checking self-formed tile scope in this func.
if (llvm::isa<scf::ExecuteRegionOp>(op) || llvm::isa<scf::ForOp>(op) ||
llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op)) {
auto &block = op->getRegion(0).front();
@@ -698,6 +697,14 @@ class TileMulIBindingRewriter : public OpRewritePattern<TileMulIOp> {
}
};
+/// Appends `size` bytes starting at `array` to `out`, each rendered as two
+/// upper-case hexadecimal digits without a `0x` prefix.
+static inline void uint8ArrayToHex(std::string &out, uint8_t array[],
+                                   int size) {
+  llvm::raw_string_ostream hexStream(out);
+  for (; size > 0; --size, ++array)
+    hexStream << format_hex_no_prefix(*array, 2, /*Upper=*/true);
+}
+
struct EnableAMXTileBindingPass
: public impl::EnableAMXTileBindingBase<EnableAMXTileBindingPass> {
private:
@@ -725,22 +732,70 @@ struct EnableAMXTileBindingPass
return isViable;
}
- LLVM::GlobalOp getOrCreateGlobalPalette(const PaletteInfo &pi) {}
+  /// Returns a private, constant LLVM::GlobalOp holding the 64-byte tile
+  /// configuration encoded from `pi`, creating it at the start of the parent
+  /// module on first use. The symbol name embeds the hex encoding of the
+  /// configuration bytes, so identical palettes share a single global.
+  ///
+  /// NOTE(review): this mutates the parent module; if the pass is anchored on
+  /// individual functions, that is unsafe under multi-threaded pass
+  /// execution — TODO confirm the pass anchor.
+  LLVM::GlobalOp getOrCreateGlobalPalette(const PaletteInfo &pi) {
+    assert(!pi.overflow && "Expecting valid palette");
+// Pack struct so it occupies exactly 64 bytes with no padding.
+#pragma pack(push, 1)
+    struct PaletteConfig {
+      uint8_t paletteId;
+      uint8_t startRow;
+      uint8_t reserved[14];
+      uint16_t cols[16];
+      uint8_t rows[16];
+    };
+#pragma pack(pop)
+    static_assert(sizeof(PaletteConfig) == 64,
+                  "AMX tile configuration must be 64 bytes");
+
+    PaletteConfig paletteConfig;
+    memset(&paletteConfig, 0x0, sizeof(paletteConfig));
+    // Intel AMX: The only legal non-INIT value for palette_id is 1.
+    // TODO(haixin): fetch from CPUID ?
+    paletteConfig.paletteId = 1;
+    // `pi` describes 8 tmm registers; the remaining entries stay zero.
+    for (int index = 0; index < 8; index++) {
+      const auto &regShape = pi.palette[index];
+      paletteConfig.rows[index] = static_cast<uint8_t>(regShape.first);
+      paletteConfig.cols[index] = static_cast<uint16_t>(regShape.second);
+    }
+
+    // View the packed struct as raw bytes for hashing and the initializer.
+    auto *paletteAsArray = reinterpret_cast<uint8_t *>(&paletteConfig);
+    const size_t paletteArraySize = sizeof(paletteConfig);
+
+    std::string paletteSymName = "g_intel_amx_palette_";
+    uint8ArrayToHex(paletteSymName, paletteAsArray,
+                    static_cast<int>(paletteArraySize));
+
+    ModuleOp moduleOp = getOperation()->getParentOfType<ModuleOp>();
+    assert(moduleOp && "Expecting the operation to be nested in a module");
+    if (auto global = moduleOp.lookupSymbol<LLVM::GlobalOp>(paletteSymName))
+      return global;
+
+    // Create a global symbol containing palette config.
+    OpBuilder builder(moduleOp);
+    builder.setInsertionPointToStart(moduleOp.getBody());
+
+    SmallVector<uint8_t> elementVals(paletteAsArray,
+                                     paletteAsArray + paletteArraySize);
+    auto dataAttrType = RankedTensorType::get(
+        {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
+    auto dataAttr =
+        DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
+    auto arrayTy =
+        LLVM::LLVMArrayType::get(builder.getI8Type(), elementVals.size());
+    return builder.create<LLVM::GlobalOp>(
+        moduleOp.getLoc(), arrayTy, /*isConstant=*/true,
+        LLVM::Linkage::Private, paletteSymName, dataAttr, /*alignment=*/64);
+  }
public:
void runOnOperation() override {
// Ensure that tile Ops are not wrapped by out-of-scope Ops, else cannot do
- // enabling
+ // enabling.
if (!isViableTileOps())
return;
// 0. Get AnalyseInfo for each concerned Value (Does not allow mixed used of
- // tmul & normal vector operations)
+ // tmul & normal vector operations).
TileBindingAnalysis &bindingAna = getAnalysis<TileBindingAnalysis>();
if (!bindingAna.isValid())
return;
- // 1. Set propagated binding info to AMX Ops
+ // 1. Set propagated binding info to AMX Ops.
RewritePatternSet patterns(&getContext());
patterns.add<TileStoreBindingRewriter>(&getContext(), analysis);
patterns.add<TileMulFBindingRewriter>(&getContext(), analysis);
@@ -750,12 +805,12 @@ struct EnableAMXTileBindingPass
if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
return;
- // 2. Analyse tile scopes & expand them maximally
+ // 2. Analyse tile scopes & expand them maximally.
TileScopeAnalysis &scopeAna = getAnalysis<TileScopeAnalysis>();
if (!scopeAna.isValid())
return;
- // 3. Insert tile config/release according to tile scopes
+ // 3. Insert tile config/release according to tile scopes.
OpBuilder builder(getOperation());
for (auto &scope : tileScopes) {
assert(!scope.pi.overflow && "Expecting legal AMX palette info");
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 5e2aa216cf4123..154987303be673 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -74,10 +74,28 @@ Value getStride(ConversionPatternRewriter &rewriter,
}
struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
+private:
+ const std::optional<TileScopeAnalysis> &enablingAnalysis;
+
+public:
using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern;
+ TileZeroConversion(const LLVMTypeConverter &typeConverter,
+ const std::optional<TileScopeAnalysis> &analysis)
+ : ConvertOpToLLVMPattern<TileZeroOp>(typeConverter),
+ enablingAnalysis(analysis) {}
+
LogicalResult
matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (enablingAnalysis && enablingAnalysis->isValid()) {
+ // Routine for lowering tile Ops with binding info.
+ auto dstRegIndex = op.getDstRegIndex();
+ assert(dstRegIndex && "Incomplete operation attribute for tile binding");
+ rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero_plain>(op,
+ *dstRegIndex);
+ return success();
+ }
+
VectorType vType = op.getVectorType();
// Determine m x n tile sizes.
std::pair<Value, Value> tsz =
@@ -91,24 +109,42 @@ struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
};
struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
+private:
+ const std::optional<TileScopeAnalysis> &enablingAnalysis;
+
+public:
using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern;
+ TileLoadConversion(const LLVMTypeConverter &typeConverter,
+ const std::optional<TileScopeAnalysis> &analysis)
+ : ConvertOpToLLVMPattern<TileLoadOp>(typeConverter),
+ enablingAnalysis(analysis) {}
LogicalResult
matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefType mType = op.getMemRefType();
- VectorType vType = op.getVectorType();
- // Determine m x n tile sizes.
- std::pair<Value, Value> tsz =
- getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
// Determine stride.
if (failed(verifyStride(mType)))
return failure();
Value stride = getStride(rewriter, *getTypeConverter(), mType,
adaptor.getBase(), op.getLoc());
- // Replace operation with intrinsic.
Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
+
+ if (enablingAnalysis && enablingAnalysis->isValid()) {
+ // Routine for lowering tile Ops with binding info.
+ auto dstRegIndex = op.getDstRegIndex();
+ assert(dstRegIndex && "Incomplete operation attribute for tile binding");
+ rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64_plain>(
+ op, ptr, stride, *dstRegIndex);
+ return success();
+ }
+
+ VectorType vType = op.getVectorType();
+ // Determine m x n tile sizes.
+ std::pair<Value, Value> tsz =
+ getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+ // Replace operation with intrinsic.
Type resType = typeConverter->convertType(vType);
rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
op, resType, tsz.first, tsz.second, ptr, stride);
@@ -117,24 +153,42 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
};
struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
+private:
+ const std::optional<TileScopeAnalysis> &enablingAnalysis;
+
+public:
using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;
+ TileStoreConversion(const LLVMTypeConverter &typeConverter,
+ const std::optional<TileScopeAnalysis> &analysis)
+ : ConvertOpToLLVMPattern<TileStoreOp>(typeConverter),
+ enablingAnalysis(analysis) {}
LogicalResult
matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefType mType = op.getMemRefType();
- VectorType vType = op.getVectorType();
- // Determine m x n tile sizes.
- std::pair<Value, Value> tsz =
- getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
// Determine stride.
if (failed(verifyStride(mType)))
return failure();
Value stride = getStride(rewriter, *getTypeConverter(), mType,
adaptor.getBase(), op.getLoc());
- // Replace operation with intrinsic.
Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
+
+ if (enablingAnalysis && enablingAnalysis->isValid()) {
+ // Routine for lowering tile Ops with binding info.
+ auto srcRegIndex = op.getSrcRegIndex();
+ assert(srcRegIndex && "Incomplete operation attribute for tile binding");
+ rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64_plain>(
+ op, ptr, stride, *srcRegIndex);
+ return success();
+ }
+
+ VectorType vType = op.getVectorType();
+ // Determine m x n tile sizes.
+ std::pair<Value, Value> tsz =
+ getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+ // Replace operation with intrinsic.
rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
op, tsz.first, tsz.second, ptr, stride, adaptor.getVal());
return success();
@@ -142,10 +196,32 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
};
struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
+private:
+ const std::optional<TileScopeAnalysis> &enablingAnalysis;
+
+public:
using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern;
+ TileMulFConversion(const LLVMTypeConverter &typeConverter,
+ const std::optional<TileScopeAnalysis> &analysis)
+ : ConvertOpToLLVMPattern<TileMulFOp>(typeConverter),
+ enablingAnalysis(analysis) {}
+
LogicalResult
matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (enablingAnalysis && enablingAnalysis->isValid()) {
+ // Routine for lowering tile Ops with binding info.
+ auto lhsRegIndex = op.getSrcRegIndex();
+ auto rhsRegIndex = op.getRhsRegIndex();
+ auto accRegIndex = op.getAccRegIndex();
+
+ assert(lhsRegIndex && rhsRegIndex && accRegIndex &&
+ "Incomplete operation attribute for tile binding");
+ rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps_plain>(
+ op, *accRegIndex, *lhsRegIndex, *rhsRegIndex);
+ return success();
+ }
+
VectorType aType = op.getLhsVectorType();
VectorType bType = op.getRhsVectorType();
VectorType cType = op.getVectorType();
@@ -164,10 +240,45 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
};
struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
+private:
+ const std::optional<TileScopeAnalysis> &enablingAnalysis;
+
+public:
using ConvertOpToLLVMPattern<TileMulIOp>::ConvertOpToLLVMPattern;
+ TileMulIConversion(const LLVMTypeConverter &typeConverter,
+ const std::optional<TileScopeAnalysis> &analysis)
+ : ConvertOpToLLVMPattern<TileMulIOp>(typeConverter),
+ enablingAnalysis(analysis) {}
+
LogicalResult
matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ bool zexta = op.getIsZextLhs();
+ bool zextb = op.getIsZextRhs();
+
+ if (enablingAnalysis && enablingAnalysis->isValid()) {
+ // Routine for lowering tile Ops with binding info.
+ auto lhsRegIndex = op.getSrcRegIndex();
+ auto rhsRegIndex = op.getRhsRegIndex();
+ auto accRegIndex = op.getAccRegIndex();
+
+ assert(lhsRegIndex && rhsRegIndex && accRegIndex &&
+ "Incomplete operation attribute for tile binding");
+ if (zexta && zextb)
+ rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud_plain>(
+ op, *accRegIndex, *lhsRegIndex, *rhsRegIndex);
+ else if (zexta && !zextb)
+ rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbusd_plain>(
+ op, *accRegIndex, *lhsRegIndex, *rhsRegIndex);
+ else if (!zexta && zextb)
+ rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbsud_plain>(
+ op, *accRegIndex, *lhsRegIndex, *rhsRegIndex);
+ else
+ rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbssd_plain>(
+ op, *accRegIndex, *lhsRegIndex, *rhsRegIndex);
+ return success();
+ }
+
VectorType aType = op.getLhsVectorType();
VectorType bType = op.getRhsVectorType();
VectorType cType = op.getVectorType();
@@ -178,8 +289,6 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
// Replace operation with intrinsic.
Type resType = typeConverter->convertType(cType);
- bool zexta = op.getIsZextLhs();
- bool zextb = op.getIsZextRhs();
if (zexta && zextb)
rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>(
op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
@@ -203,9 +312,10 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
} // namespace
+/// Registers the AMX-to-LLVM conversion patterns for tile zero/load/store/mul
+/// Ops. Each pattern stores a reference to `analysis`, so `analysis` must
+/// outlive `patterns`; when it holds a valid analysis, the patterns lower tile
+/// Ops using their bound tmm register indices instead of tile sizes.
void mlir::populateAMXLegalizeForLLVMExportPatterns(
- LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+ LLVMTypeConverter &converter, std::optional<TileScopeAnalysis> &analysis,
+ RewritePatternSet &patterns) {
patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
- TileMulFConversion, TileMulIConversion>(converter);
+ TileMulFConversion, TileMulIConversion>(converter, analysis);
}
void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
>From eb7672581ce234d62568587b9b16a87a982ebf1b Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Tue, 2 Jul 2024 01:30:43 -0700
Subject: [PATCH 08/17] rearrange code structure
---
.../Dialect/AMX/Analysis/AMXBindingAnalysis.h | 121 ++++
.../AMX/Analysis/AMXBindingAnalysis.cpp | 487 +++++++++++++++
mlir/lib/Dialect/AMX/Analysis/CMakeLists.txt | 10 +
mlir/lib/Dialect/AMX/CMakeLists.txt | 1 +
.../lib/Dialect/AMX/Transforms/CMakeLists.txt | 1 +
.../AMX/Transforms/EnableAMXTileBinding.cpp | 562 +-----------------
.../AMX/Transforms/LegalizeForLLVMExport.cpp | 1 +
7 files changed, 623 insertions(+), 560 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
create mode 100644 mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
create mode 100644 mlir/lib/Dialect/AMX/Analysis/CMakeLists.txt
diff --git a/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h b/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
new file mode 100644
index 00000000000000..05ff7f3d4a74cb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
@@ -0,0 +1,121 @@
+//===- AMXBindingAnalysis.h - Analyse AMX Binding Info --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares two analyses:
+// 1. TileBindingAnalysis: verify and propagate pre-set tile register binding
+// info for `vector`s used by tile operations.
+// 2. TileScopeAnalysis: verify and find out proper tile configuration
+// domination scopes for tile operations w.r.t correctness and performance.
+// These analyses are invoked by pass `--enable-amx-tile-binding` and are used
+// as an indicator for tile binding info validation in pass
+// `--convert-vector-to-llvm`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AMX_ANALYSIS_AMXBINDINGANALYSIS_H_
+#define MLIR_DIALECT_AMX_ANALYSIS_AMXBINDINGANALYSIS_H_
+
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+namespace amx {
+
+/// Analyses (propagates) the tile register binding of every tile vector: each
+/// tile `Value` is mapped to the index of the tmm register it lives in.
+class TileBindingAnalysis {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileBindingAnalysis)
+  explicit TileBindingAnalysis(Operation *);
+  bool isValid() const { return isValidAnalysis; }
+
+  /// Returns the register index bound to `val`, or -1 if it has none.
+  int getBinding(Value val) const {
+    auto found = bindings.find(val);
+    return found == bindings.end() ? -1 : found->second;
+  }
+
+  /// Binds `val` to tmm register `index`, overwriting any previous binding.
+  void setBinding(Value val, int index) { bindings[val] = index; }
+
+private:
+  bool isValidAnalysis;
+  DenseMap<Value, int> bindings;
+};
+
+// A class for analyzing tile configuration domination (a.k.a. tile scope).
+class TileScopeAnalysis {
+private:
+  using BlockSeg = llvm::iterator_range<Block::iterator>;
+  // A list of 2-dim {rows x colBytes} shapes representing tmm register shape,
+  // the length should always be 8. An all-zero entry marks an unused register.
+  struct PaletteInfo {
+    // Set when the same register would need two different shapes.
+    bool overflow;
+    SmallVector<std::pair<int, int>, 8> palette;
+    PaletteInfo() {
+      palette.resize(8, {0, 0});
+      clear();
+    }
+    void clear();
+    // `const` so it can be queried on the `const PaletteInfo &` operands of
+    // merge()/isConflict().
+    bool isEmpty(int idx) const {
+      return palette[idx].first == 0 && palette[idx].second == 0;
+    }
+    void set(int idx, std::pair<int, int> shape) { palette[idx] = shape; }
+    void merge(const PaletteInfo &rhs);
+    bool isConflict(const PaletteInfo &rhs);
+  };
+  struct TileScope {
+    // The BlockSeg here is inclusive (containing `end` Op). iterator_range has
+    // no default constructor, so start from an empty range.
+    BlockSeg seg{Block::iterator(), Block::iterator()};
+    PaletteInfo pi;
+    TileScope() { clear(); }
+    void clear() { pi.clear(); }
+  };
+
+  bool isValidAnalysis = false;
+  // Storing parallel Ops that would break tile context & scope.
+  DenseSet<Operation *> parallelOps;
+  // Storing needed palette info for each concerned Op.
+  DenseMap<Operation *, PaletteInfo> neededPalette;
+  // Storing the usage scope for each concerned tile Op.
+  // The BlockSeg here is inclusive (containing `end` Op).
+  DenseMap<Operation *, BlockSeg> tileUsage;
+  // Storing final tile scope results for injecting tilecfg/tilerelease.
+  SmallVector<TileScope, 10> tileScopes;
+
+  void addParallelOp(Operation *op) { parallelOps.insert(op); }
+  // NOTE(review): despite its name, this returns true when `op` has NOT been
+  // recorded in `parallelOps`. Audit its callers before flipping it.
+  bool isParallelOp(Operation *op) {
+    return parallelOps.find(op) == parallelOps.end();
+  }
+
+  // BlockSeg is not default-constructible, so DenseMap::operator[] cannot be
+  // used here.
+  void setTileUsage(Operation *op, BlockSeg seg) {
+    auto [iter, inserted] = tileUsage.try_emplace(op, seg);
+    if (!inserted)
+      iter->second = seg;
+  }
+
+  PaletteInfo collectRegionPalette(Block &block);
+  PaletteInfo collectPalette(Operation *op);
+  // Below two functions are the leaf functions of recursive collection, will
+  // actually insert PaletteInfo into map storage.
+  PaletteInfo collectPaletteForScf(Operation *op);
+  PaletteInfo collectPaletteForTile(Operation *op);
+  std::optional<PaletteInfo> getPalette(Operation *op);
+  std::optional<PaletteInfo> getPalette(BlockSeg seg);
+
+  void doTileScope(Block &block);
+  // The input BlockSeg here is exclusive (not containing `end` Op).
+  void doTileScope(BlockSeg seg);
+  void doTileScope(Operation *op);
+
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileScopeAnalysis)
+  explicit TileScopeAnalysis(Operation *);
+  bool isValid() const { return isValidAnalysis; }
+};
+
+} // namespace amx
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AMX_ANALYSIS_AMXBINDINGANALYSIS_H_
diff --git a/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
new file mode 100644
index 00000000000000..8ed992ac498fd9
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
@@ -0,0 +1,487 @@
+//===- AMXBindingAnalysis.cpp - Binding info analysis for Intel AMX -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains implementations of two analyses:
+// 1. TileBindingAnalysis
+// 2. TileScopeAnalysis
+// For more details, please refer to the header.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h"
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "llvm/Support/ErrorHandling.h"
+#include <optional>
+
+using namespace mlir;
+using namespace mlir::amx;
+
+#define DEBUG_TYPE "amx-binding-analysis"
+
+// Returns true if `op` is one of the AMX tile Ops handled by these analyses.
+// A null `op` (e.g. the "defining op" of a block argument) is not a tile Op;
+// plain llvm::isa would assert on it.
+static bool isTileOp(Operation *op) {
+  return llvm::isa_and_nonnull<TileZeroOp, TileLoadOp, TileMulFOp, TileMulIOp,
+                               TileStoreOp>(op);
+}
+
+// Returns true if all operands (lhs, rhs, acc) of the tile multiplication `op`
+// are produced by tile Ops. Operands without a defining Op (block arguments)
+// are rejected instead of being passed to isTileOp() as null.
+template <typename Op>
+static bool TileMulCheck(Operation *op) {
+  auto tileMul = llvm::cast<Op>(op);
+  auto isProducedByTileOp = [](Value value) {
+    Operation *defOp = value.getDefiningOp();
+    return defOp && isTileOp(defOp);
+  };
+  return isProducedByTileOp(tileMul.getLhs()) &&
+         isProducedByTileOp(tileMul.getRhs()) &&
+         isProducedByTileOp(tileMul.getAcc());
+}
+
+// Not allow mixed use of tile Ops and normal vector Ops, any mixing is
+// considered unacceptable: every tile operand must itself come from a tile Op.
+static bool isAcceptableTileOp(Operation *op) {
+  if (!isTileOp(op))
+    return false;
+  if (llvm::isa<TileMulFOp>(op))
+    return TileMulCheck<TileMulFOp>(op);
+  if (llvm::isa<TileMulIOp>(op))
+    return TileMulCheck<TileMulIOp>(op);
+  if (auto tileStore = dyn_cast<TileStoreOp>(op)) {
+    // The stored value may be a block argument with no defining Op.
+    Operation *valOp = tileStore.getVal().getDefiningOp();
+    return valOp && isTileOp(valOp);
+  }
+  // TileZeroOp / TileLoadOp have no tile operands to check.
+  return true;
+}
+
+// Records the register index pre-set on a tile-producing Op (via its
+// `getDstRegIndex()` attribute) as the binding of the Op's result. Returns
+// false when the Op carries no such index.
+template <typename Op>
+static bool TileDstPropagate(TileBindingAnalysis *analysis, Operation *op) {
+  auto tileDst = dyn_cast_or_null<Op>(op);
+  assert(tileDst);
+  if (std::optional<int8_t> tmmIndex = tileDst.getDstRegIndex()) {
+    analysis->setBinding(tileDst.getRes(), *tmmIndex);
+    return true;
+  }
+  return false;
+}
+
+// Propagates the accumulator's binding to the result of a tile multiplication,
+// i.e. the result reuses the accumulator's register. Returns false if the
+// accumulator is unbound.
+template <typename Op>
+static bool TileMulPropagate(TileBindingAnalysis *analysis, Operation *op) {
+  auto tileMul = dyn_cast_or_null<Op>(op);
+  assert(tileMul);
+  int accIndex = analysis->getBinding(tileMul.getAcc());
+  const bool accIsBound = accIndex >= 0;
+  if (accIsBound)
+    analysis->setBinding(tileMul.getRes(), accIndex);
+  return accIsBound;
+}
+
+// Builds the binding map for a func::FuncOp root; any other root yields an
+// invalid analysis. Ops are visited in pre-order; the first unacceptable or
+// unbindable tile Op marks the analysis invalid and stops the walk.
+TileBindingAnalysis::TileBindingAnalysis(Operation *root) {
+  isValidAnalysis = false;
+  auto funcOp = dyn_cast_or_null<func::FuncOp>(root);
+  if (!funcOp)
+    return;
+
+  // Returns false if `op` produces a tile whose binding cannot be determined.
+  auto propagate = [this](Operation *op) -> bool {
+    if (llvm::isa<TileZeroOp>(op))
+      return TileDstPropagate<TileZeroOp>(this, op);
+    if (llvm::isa<TileLoadOp>(op))
+      return TileDstPropagate<TileLoadOp>(this, op);
+    if (llvm::isa<TileMulFOp>(op))
+      return TileMulPropagate<TileMulFOp>(this, op);
+    if (llvm::isa<TileMulIOp>(op))
+      return TileMulPropagate<TileMulIOp>(this, op);
+    // TileStoreOp defines no tile value, so there is nothing to propagate.
+    return true;
+  };
+
+  isValidAnalysis = true;
+  funcOp->walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
+    if (!isTileOp(op))
+      return WalkResult::advance();
+    if (!isAcceptableTileOp(op) || !propagate(op)) {
+      isValidAnalysis = false;
+      return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+}
+
+// Resets the palette: no overflow and every tmm register shape unused.
+void TileScopeAnalysis::PaletteInfo::clear() {
+  overflow = false;
+  for (int idx = 7; idx >= 0; --idx)
+    palette[idx] = {0, 0};
+}
+
+// Merges the register shapes used by `rhs` into this palette. If a register is
+// used by both sides with different shapes, the result overflows.
+void TileScopeAnalysis::PaletteInfo::merge(const PaletteInfo &rhs) {
+  // Overflow is sticky: once either side cannot fit, neither can the result.
+  if (overflow || rhs.overflow) {
+    overflow = true;
+    return;
+  }
+  for (int idx = 0; idx < 8; idx++) {
+    // Inspect `rhs` directly: it is const, so only const members may be used.
+    const auto &rhsShape = rhs.palette[idx];
+    if (rhsShape.first == 0 && rhsShape.second == 0)
+      continue; // Register unused by `rhs`.
+    if (isEmpty(idx)) {
+      palette[idx] = rhsShape;
+    } else if (palette[idx] != rhsShape) {
+      // The same register would need two different shapes.
+      overflow = true;
+      break;
+    }
+  }
+}
+
+// Returns true if this palette and `rhs` cannot share one tile configuration:
+// either side overflowed, or some register is used by both with different
+// shapes.
+bool TileScopeAnalysis::PaletteInfo::isConflict(const PaletteInfo &rhs) {
+  if (overflow || rhs.overflow)
+    return true;
+  for (int idx = 0; idx < 8; idx++) {
+    const auto &rhsShape = rhs.palette[idx];
+    bool rhsUsed = rhsShape.first != 0 || rhsShape.second != 0;
+    if (rhsUsed && !isEmpty(idx) && palette[idx] != rhsShape)
+      return true;
+  }
+  // Previously control fell off the end here (undefined behavior).
+  return false;
+}
+
+// Returns true for the control-flow Ops whose regions are scanned for tile
+// Ops. Currently we only operate on scf Ops.
+static bool isConcernedControlFlowOp(Operation *op) {
+  return llvm::isa<scf::ExecuteRegionOp, scf::ForOp, scf::ForallOp, scf::IfOp,
+                   scf::IndexSwitchOp, scf::ParallelOp, scf::WhileOp>(op);
+}
+
+// isTileOp() is defined once, near the top of this file.
+
+// Runs the tile-scope analysis on a func::FuncOp definition; any other root
+// (or a body-less function) yields an invalid analysis.
+TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
+  isValidAnalysis = false;
+  auto funcOp = dyn_cast_or_null<func::FuncOp>(root);
+  if (!funcOp || funcOp->getRegion(0).empty())
+    return;
+
+  isValidAnalysis = true;
+  // 0. First walk to mark parallel Ops: each scf.forall/scf.parallel Op and
+  //    all of its ancestors below `root` are recorded.
+  funcOp->walk<WalkOrder::PostOrder>([this, root](Operation *op) {
+    // Already recorded, e.g. as the ancestor of a nested parallel Op.
+    if (parallelOps.contains(op))
+      return;
+    if (!llvm::isa<scf::ForallOp, scf::ParallelOp>(op))
+      return;
+    for (Operation *curr = op; curr && curr != root;
+         curr = curr->getParentOp())
+      addParallelOp(curr);
+  });
+
+  // 1. Second walk to collect needed palette for each concerned Op.
+  collectPalette(root);
+  if (!isValidAnalysis)
+    return;
+
+  // 2. Third walk to analyse usage scope for each tile Op: from the Op itself
+  //    to the latest Op in the same block that (possibly through nested
+  //    regions) uses its results. Without users, the scope is the Op alone.
+  funcOp->walk<WalkOrder::PreOrder>([this](Operation *op) {
+    if (!isValidAnalysis || !isTileOp(op))
+      return;
+    Block *block = op->getBlock();
+    Operation *lastUser = op;
+    for (Operation *user : op->getUsers()) {
+      // Use-list order is not guaranteed to follow program order, so map each
+      // user to its ancestor in `block` and keep the latest one.
+      Operation *ancestor = block->findAncestorOpInBlock(*user);
+      if (!ancestor) {
+        isValidAnalysis = false;
+        return;
+      }
+      if (lastUser->isBeforeInBlock(ancestor))
+        lastUser = ancestor;
+    }
+    setTileUsage(op, BlockSeg(Block::iterator(op), Block::iterator(lastUser)));
+  });
+  if (!isValidAnalysis)
+    return;
+
+  // 3. Tile scoping for each segmented region in a recursive manner.
+  doTileScope(funcOp->getRegion(0).front());
+}
+
+// Merges the palettes needed by every Op directly inside `block`. The return
+// type must be qualified: it precedes the `TileScopeAnalysis::` scope.
+TileScopeAnalysis::PaletteInfo
+TileScopeAnalysis::collectRegionPalette(Block &block) {
+  PaletteInfo pi;
+  for (Operation &op : block)
+    pi.merge(collectPalette(&op));
+  return pi;
+}
+
+// Returns the palette needed by `op`, computing and caching it (for scf and
+// tile Ops) on first request.
+TileScopeAnalysis::PaletteInfo
+TileScopeAnalysis::collectPalette(Operation *op) {
+  if (!isValidAnalysis)
+    return PaletteInfo();
+  // No need to store PaletteInfo for func; a body-less func needs nothing.
+  if (auto funcOp = dyn_cast_or_null<func::FuncOp>(op)) {
+    Region &body = funcOp->getRegion(0);
+    if (body.empty())
+      return PaletteInfo();
+    return collectRegionPalette(body.front());
+  }
+
+  auto iter = neededPalette.find(op);
+  if (iter != neededPalette.end())
+    return iter->second;
+
+  // For now, we only concern certain control flow Ops and tile Ops.
+  if (isConcernedControlFlowOp(op))
+    return collectPaletteForScf(op);
+  if (isTileOp(op))
+    return collectPaletteForTile(op);
+  return PaletteInfo();
+}
+
+// Computes, caches and returns the merged palette of all regions of a
+// concerned scf Op. Only the entry block of each region is scanned.
+TileScopeAnalysis::PaletteInfo
+TileScopeAnalysis::collectPaletteForScf(Operation *op) {
+  if (!isConcernedControlFlowOp(op))
+    return PaletteInfo();
+
+  PaletteInfo pi;
+  // Merges the palette of `region`'s entry block; empty regions (e.g. an
+  // absent scf.if `else`) contribute nothing.
+  auto mergeRegion = [&](Region &region) {
+    if (!region.empty())
+      pi.merge(collectRegionPalette(region.front()));
+  };
+  if (llvm::isa<scf::ExecuteRegionOp, scf::ForOp, scf::ForallOp,
+                scf::ParallelOp>(op)) {
+    mergeRegion(op->getRegion(0));
+  } else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
+    mergeRegion(ifOp.getThenRegion());
+    mergeRegion(ifOp.getElseRegion());
+  } else if (auto indexOp = dyn_cast<scf::IndexSwitchOp>(op)) {
+    mergeRegion(indexOp.getDefaultRegion());
+    for (Region &caseRegion : indexOp.getCaseRegions())
+      mergeRegion(caseRegion);
+  } else if (llvm::isa<scf::WhileOp>(op)) {
+    // Region 0 is the "before" region, region 1 the "after" region.
+    mergeRegion(op->getRegion(0));
+    mergeRegion(op->getRegion(1));
+  }
+  neededPalette[op] = pi;
+  return pi;
+}
+
+// Returns the {rows, colBytes} tmm register shape for a 2-D tile vector type.
+// Only i8, bf16, i32 and f32 elements are supported.
+static inline std::pair<int, int> getPaletteShape(VectorType type) {
+  ArrayRef<int64_t> shape = type.getShape();
+  auto elementType = type.getElementType();
+  int typeSize = 0;
+  if (elementType.isInteger(8))
+    typeSize = 1;
+  else if (elementType.isBF16())
+    typeSize = 2;
+  else if (elementType.isInteger(32) || elementType.isF32())
+    typeSize = 4;
+  else
+    // Unlike assert(false), this never leaves typeSize silently unusable in
+    // release builds.
+    llvm_unreachable("Invalid type for palette");
+
+  // Palette shape is { rows, colBytes }.
+  return {static_cast<int>(shape[0]), static_cast<int>(shape[1] * typeSize)};
+}
+
+// Computes, caches and returns the palette (register shapes) needed by a
+// single tile Op. A missing register-index attribute invalidates the analysis.
+TileScopeAnalysis::PaletteInfo
+TileScopeAnalysis::collectPaletteForTile(Operation *op) {
+  if (!isTileOp(op))
+    return PaletteInfo();
+
+  PaletteInfo pi;
+  // Records the shape of the register selected by `index`; returns false when
+  // the binding attribute is absent.
+  auto bind = [&pi](auto index, VectorType type) {
+    if (!index)
+      return false;
+    pi.set(*index, getPaletteShape(type));
+    return true;
+  };
+  // Tile multiplications bind three registers. This uses its own argument;
+  // the former macro always read `tileMulFOp`, which is null for TileMulIOp.
+  auto bindMul = [&bind](auto mulOp) {
+    return bind(mulOp.getLhsRegIndex(), mulOp.getLhsVectorType()) &&
+           bind(mulOp.getRhsRegIndex(), mulOp.getRhsVectorType()) &&
+           bind(mulOp.getAccRegIndex(), mulOp.getAccVectorType());
+  };
+
+  bool bound = true;
+  if (auto tileLoadOp = dyn_cast<TileLoadOp>(op))
+    bound = bind(tileLoadOp.getDstRegIndex(), tileLoadOp.getVectorType());
+  else if (auto tileMulFOp = dyn_cast<TileMulFOp>(op))
+    bound = bindMul(tileMulFOp);
+  else if (auto tileMulIOp = dyn_cast<TileMulIOp>(op))
+    bound = bindMul(tileMulIOp);
+  else if (auto tileStoreOp = dyn_cast<TileStoreOp>(op))
+    bound = bind(tileStoreOp.getSrcRegIndex(), tileStoreOp.getVectorType());
+  else if (auto tileZeroOp = dyn_cast<TileZeroOp>(op))
+    bound = bind(tileZeroOp.getDstRegIndex(), tileZeroOp.getVectorType());
+
+  if (!bound) {
+    isValidAnalysis = false;
+    return PaletteInfo();
+  }
+  neededPalette[op] = pi;
+  return pi;
+}
+
+/// Returns the palette cached for `op` in `neededPalette`, or std::nullopt
+/// if no palette was recorded for it.
+std::optional<TileScopeAnalysis::PaletteInfo>
+TileScopeAnalysis::getPalette(Operation *op) {
+  auto iter = neededPalette.find(op);
+  if (iter == neededPalette.end())
+    return std::nullopt;
+  return iter->second;
+}
+
+/// Merges the cached palettes of all Ops in `seg`, or returns std::nullopt if
+/// none of them has one.
+///
+/// `seg` is treated as inclusive: its `end()` Op is merged as well. This is
+/// the convention of the `tileUsage` segments passed here, whose end is the
+/// block-level last user of the tile. `seg.end()` must therefore be a real
+/// Op, not the block's end sentinel.
+std::optional<TileScopeAnalysis::PaletteInfo>
+TileScopeAnalysis::getPalette(BlockSeg seg) {
+  bool hasPaletteInfo = false;
+  PaletteInfo pi;
+  auto mergeFrom = [&](Operation &op) {
+    if (std::optional<PaletteInfo> opPi = getPalette(&op)) {
+      hasPaletteInfo = true;
+      pi.merge(*opPi);
+    }
+  };
+  for (Operation &op : seg)
+    mergeFrom(op);
+  // Include the last user itself; the caller skips past it after merging.
+  mergeFrom(*seg.end());
+  if (!hasPaletteInfo)
+    return std::nullopt;
+  return pi;
+}
+
+/// Runs tile scoping over every Op of `block`, as one exclusive segment.
+void TileScopeAnalysis::doTileScope(Block &block) {
+  BlockSeg wholeBlock(block.begin(), block.end());
+  doTileScope(wholeBlock);
+}
+
+/// Partitions the exclusive segment `seg` (its `end()` Op is not included)
+/// into tile scopes and appends them to `tileScopes`.
+///
+/// Ops recorded in `parallelOps` split `seg`: the pieces between them and the
+/// regions of those Ops are scoped independently. A parallel-free segment is
+/// then scanned greedily without look-ahead: palettes of consecutive Ops are
+/// merged into the current scope until one conflicts, which closes it. An scf
+/// Op whose palette overflows is not part of any scope; its regions are
+/// scoped recursively instead. Scopes that use no tile register are dropped.
+void TileScopeAnalysis::doTileScope(BlockSeg seg) {
+  if (!isValidAnalysis || seg.empty())
+    return;
+
+  // 1. Split on parallel Ops. Query `parallelOps` directly: isParallelOp()
+  //    is true for Ops that are *not* in the set, which would make every
+  //    ordinary Op a split point.
+  SmallVector<BlockSeg, 3> blockSegs;
+  SmallVector<Operation *, 3> paraOps;
+  Block::iterator currBegin = seg.begin();
+  for (Block::iterator probe = seg.begin(); probe != seg.end(); ++probe) {
+    Operation *op = &*probe;
+    if (!parallelOps.contains(op))
+      continue;
+    blockSegs.push_back(BlockSeg(currBegin, probe));
+    paraOps.push_back(op);
+    currBegin = probe;
+    ++currBegin;
+  }
+  if (!paraOps.empty()) {
+    assert(blockSegs.size() == paraOps.size());
+    for (size_t idx = 0, e = paraOps.size(); idx < e; ++idx) {
+      doTileScope(blockSegs[idx]);
+      doTileScope(paraOps[idx]);
+    }
+    doTileScope(BlockSeg(currBegin, seg.end()));
+    return;
+  }
+
+  // 2. Greedy tile scoping of the parallel-free segment.
+  TileScope currScope;
+  std::optional<Block::iterator> currSegStart;
+  // A scope needs a tile configuration only if it references some register.
+  auto usesTiles = [](PaletteInfo &palette) {
+    int numRegs = static_cast<int>(palette.palette.size());
+    for (int idx = 0; idx < numRegs; ++idx)
+      if (!palette.isEmpty(idx))
+        return true;
+    return false;
+  };
+  // Closes the scope under construction so that it ends (inclusively) right
+  // before `iter`, then resets the scope state.
+  auto closeScopeBefore = [&](Block::iterator iter) {
+    if (currSegStart && *currSegStart != iter && usesTiles(currScope.pi)) {
+      Block::iterator last = iter;
+      --last;
+      currScope.seg = BlockSeg(*currSegStart, last);
+      tileScopes.push_back(currScope);
+    }
+    currScope.clear();
+    currSegStart.reset();
+  };
+
+  Block::iterator currIter = seg.begin();
+  while (currIter != seg.end()) {
+    Operation *currOp = &*currIter;
+    if (!currSegStart)
+      currSegStart = currIter;
+
+    // Palette needed if `currOp` joins the current scope, and where scanning
+    // resumes in that case. For a tile Op the palette covers its whole
+    // (inclusive) usage segment, so the scan jumps past its last user.
+    // NOTE(review): tile Ops inside the skipped range whose own usage extends
+    // further are not inspected; a later conflict could then close the scope
+    // while such a tile is still live. Confirm this is acceptable.
+    std::optional<PaletteInfo> pi;
+    Block::iterator nextIterIfMerge = currIter;
+    ++nextIterIfMerge;
+    if (isConcernedControlFlowOp(currOp)) {
+      pi = getPalette(currOp);
+    } else if (isTileOp(currOp)) {
+      auto usage = tileUsage.find(currOp);
+      if (usage == tileUsage.end()) {
+        isValidAnalysis = false;
+        return;
+      }
+      Block::iterator usageEnd = usage->second.end();
+      // The tile must not stay live past this segment (e.g. across a parallel
+      // Op); otherwise jumping past its last user would overshoot `seg.end()`.
+      if (seg.end() != currOp->getBlock()->end() &&
+          !usageEnd->isBeforeInBlock(&*seg.end())) {
+        isValidAnalysis = false;
+        return;
+      }
+      pi = getPalette(usage->second);
+      if (pi && pi->overflow) {
+        // The binding info in tile Ops exceeds the hardware capability.
+        isValidAnalysis = false;
+        return;
+      }
+      nextIterIfMerge = usageEnd;
+      ++nextIterIfMerge;
+    }
+    if (!pi) {
+      ++currIter;
+      continue;
+    }
+
+    if (pi->overflow) {
+      // Only scf Ops get here. Close the current scope, scope the Op's regions
+      // recursively, and start afresh after it so no later scope contains it.
+      closeScopeBefore(currIter);
+      doTileScope(currOp);
+      ++currIter;
+    } else if (currScope.pi.isConflict(*pi)) {
+      closeScopeBefore(currIter);
+      currScope.pi = *pi;
+      currSegStart = currIter;
+      ++currIter;
+    } else {
+      currScope.pi.merge(*pi);
+      currIter = nextIterIfMerge;
+    }
+  }
+
+  closeScopeBefore(seg.end());
+}
+
+/// Scopes the regions of a single scf Op that cannot form a tile scope by
+/// itself. Callers pass either parallel Ops or scf Ops whose palette
+/// overflows. Ops rejected by isConcernedControlFlowOp (including tile Ops)
+/// are ignored.
+void TileScopeAnalysis::doTileScope(Operation *op) {
+  if (!isConcernedControlFlowOp(op))
+    return;
+  // Scope every block of every region independently. For the supported Ops
+  // this covers:
+  // - the body of scf.execute_region/for/forall/parallel;
+  // - the then/else regions of scf.if (a missing else region has no block);
+  // - the default and each case region of scf.index_switch;
+  // - the before/after regions of scf.while.
+  for (Region &region : op->getRegions())
+    for (Block &block : region)
+      doTileScope(block);
+}
diff --git a/mlir/lib/Dialect/AMX/Analysis/CMakeLists.txt b/mlir/lib/Dialect/AMX/Analysis/CMakeLists.txt
new file mode 100644
index 00000000000000..008fda4357a6b3
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/Analysis/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_dialect_library(MLIRAMXAnalysis
+  AMXBindingAnalysis.cpp
+
+  DEPENDS
+  MLIRAMXConversionsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRAMXDialect
+  MLIRFuncDialect
+  MLIRIR
+  MLIRSCFDialect
+  )
diff --git a/mlir/lib/Dialect/AMX/CMakeLists.txt b/mlir/lib/Dialect/AMX/CMakeLists.txt
index 9f57627c321fb0..cd0e2051f32095 100644
--- a/mlir/lib/Dialect/AMX/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMX/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(Analysis)
diff --git a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
index a0c1c10f3c1701..351f788959a6de 100644
--- a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRAMXTransforms
DEPENDS
MLIRAMXConversionsIncGen
+ MLIRAMXAnalysis
LINK_LIBS PUBLIC
MLIRAMXDialect
diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
index 54a609fa783e04..3a181ed79d94ce 100644
--- a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -20,6 +20,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h"
#include "mlir/Dialect/AMX/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
@@ -36,566 +37,6 @@ namespace amx {
#define GEN_PASS_DEF_ENABLEAMXTILEBINDING
#include "mlir/Dialect/AMX/Passes.h.inc"
-//===----------------------------------------------------------------------===//
-// Analysis
-//===----------------------------------------------------------------------===//
-
-/// A class for analyzing (propagating) tile register binding for each tile
-/// vector.
-class TileBindingAnalysis {
-private:
- bool isValidAnalysis;
- DenseMap<Value, int> bindings;
-
-public:
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileBindingAnalysis)
- explicit TileBindingAnalysis(Operation *);
- bool isValid() const { return isValidAnalysis; }
- int getBinding(Value val) const {
- auto iter = bindings.find(val);
- if (iter == bindings.end())
- return -1;
- return iter->second;
- }
- void setBinding(Value val, int index) { bindings[val] = index; }
-};
-
-static bool isTileOp(Operation *op) {
- return llvm::isa<TileZeroOp>(op) || llvm::isa<TileLoadOp>(op) ||
- llvm::isa<TileMulFOp>(op) || llvm::isa<TileMulIOp>(op) ||
- llvm::isa<TileStoreOp>(op);
-}
-
-template <typename Op>
-static bool TileMulCheck(Operation *op) {
- auto tile_mul = dyn_cast_or_null<Op>(op);
- assert(tile_mul);
-
- auto lhsOp = tile_mul.getLhs().getDefiningOp();
- auto rhsOp = tile_mul.getRhs().getDefiningOp();
- auto accOp = tile_mul.getAcc().getDefiningOp();
- if (!isTileOp(lhsOp) || !isTileOp(rhsOp) || !isTileOp(accOp))
- return false;
- return true;
-}
-
-// Not allow mixed use of tile Ops and normal vector Ops, any mixing is
-// considered unacceptable.
-static bool isAcceptableTileOp(Operation *op) {
- if (!isTileOp(op))
- return false;
-
- if (llvm::isa<TileMulFOp>(op)) {
- return TileMulCheck<TileMulFOp>(op);
- } else if (llvm::isa<TileMulIOp>(op)) {
- return TileMulCheck<TileMulIOp>(op);
- } else if (auto tileStore = dyn_cast_or_null<TileStoreOp>(op)) {
- auto valOp = tileStore.getVal().getDefiningOp();
- if (!isTileOp(valOp))
- return false;
- }
- return true;
-}
-
-template <typename Op>
-static bool TileDstPropagate(TileBindingAnalysis *analysis, Operation *op) {
- auto tileDst = dyn_cast_or_null<Op>(op);
- assert(tileDst);
- std::optional<int8_t> tmmIndex = tileDst.getDstRegIndex();
- if (!tmmIndex) {
- return false;
- }
- analysis->setBinding(tileDst.getRes(), *tmmIndex);
- return true;
-}
-
-template <typename Op>
-static bool TileMulPropagate(TileBindingAnalysis *analysis, Operation *op) {
- auto tileMul = dyn_cast_or_null<Op>(op);
- assert(tileMul);
- auto accVal = tileMul.getAcc();
- auto accIndex = analysis->getBinding(accVal);
- if (accIndex < 0)
- return false;
-
- analysis->setBinding(tileMul.getRes(), accIndex);
- return true;
-}
-
-TileBindingAnalysis::TileBindingAnalysis(Operation *root) {
- isValidAnalysis = false;
- func::FuncOp func = dyn_cast_or_null<func::FuncOp>(root);
- if (!func)
- return;
-
- isValidAnalysis = true;
- func->walk<WalkOrder::PreOrder>([this](Operation *op) {
- if (!isValidAnalysis)
- return;
- if (!isTileOp(op))
- return;
- if (!isAcceptableTileOp(op)) {
- isValidAnalysis = false;
- return;
- }
-
- if (llvm::isa<TileZeroOp>(op)) {
- if (!TileDstPropagate<TileZeroOp>(this, op)) {
- isValidAnalysis = false;
- return;
- }
- } else if (llvm::isa<TileLoadOp>(op)) {
- if (!TileDstPropagate<TileLoadOp>(this, op)) {
- isValidAnalysis = false;
- return;
- }
- } else if (llvm::isa<TileMulFOp>(op)) {
- if (!TileMulPropagate<TileMulFOp>(this, op)) {
- isValidAnalysis = false;
- return;
- }
- } else if (llvm::isa<TileMulIOp>(op)) {
- if (!TileMulPropagate<TileMulIOp>(this, op)) {
- isValidAnalysis = false;
- return;
- }
- }
- });
-}
-
-// A class for analyzing tile configuration domination (a.k.a. tile scope).
-class TileScopeAnalysis {
-private:
- typedef llvm::iterator_range<Block::iterator> BlockSeg;
- // A list of 2-dim {rows x colBytes} shapes representing tmm register shape,
- // the length should always be 8.
- struct PaletteInfo {
- bool overflow;
- SmallVector<pair<int, int>, 8> palette;
- PaletteInfo() {
- palette.resize(8, {0, 0});
- clear();
- }
- void clear();
- bool isEmpty(int idx) {
- return palette[idx].first == 0 && palette[idx].second == 0;
- }
- void set(int idx, pair<int, int> shape) { palette[idx] = shape; }
- void merge(const PaletteInfo &rhs);
- bool isConflict(const PaletteInfo &rhs);
- };
- struct TileScope {
- // The BlockSeg here is inclusive (containing `end` Op).
- BlockSeg seg;
- PaletteInfo pi;
- TileScope() { clear(); }
- void clear() { pi.clear(); }
- };
-
- bool isValidAnalysis;
- // Storing parallel Ops that would break tile context & scope.
- DenseSet<Operation *> parallelOps;
- // Storing needed palette info for each concerned Op.
- DenseMap<Operation *, PaletteInfo> neededPalette;
- // Storing the usage scope for each concerned tile Op.
- // The BlockSeg here is inclusive (containing `end` Op).
- DenseMap<Operation *, BlockSeg> tileUsage;
- // Storing final tile scope results for injecting tilecfg/tilerelease.
- SmallVector<TileScope, 10> tileScopes;
-
- void addParallelOp(Operation *op) { parallelOps.insert(op); }
- bool isParallelOp(Operation *op) {
- return parallelOps.find(op) == parallelOps.end();
- }
-
- void setTileUsage(Operation *op, BlockSeg seg) { tileUsage[op] = seg; }
-
- PaletteInfo collectRegionPalette(Region ®ion);
- PaletteInfo collectPalette(Operation *op);
- // Below two functions are the leaf functinos of recursive collection, will
- // actually insert PaletteInfo into map storage.
- PaletteInfo collectPaletteForScf(Operation *op);
- PaletteInfo collectPaletteForTile(Operation *op);
- std::optional<PaletteInfo> getPalette(Operation *op);
- std::optional<PaletteInfo> getPalette(BlockSeg seg);
-
- void doTileScope(Block &block);
- // The input BlockSeg here is exclusive (not containing `end` Op).
- void doTileScope(BlockSeg seg);
- void doTileScope(Operation *op);
-
-public:
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileScopeAnalysis)
- explicit TileScopeAnalysis(Operation *);
- bool isValid() const { return isValidAnalysis; }
-};
-
-void TileScopeAnalysis::PaletteInfo::clear() {
- overflow = false;
- for (int idx = 0; idx < 8; idx++)
- palette[idx] = {0, 0};
-}
-
-void TileScopeAnalysis::PaletteInfo::merge(const PaletteInfo &rhs) {
- if (overflow || rhs.overflow) {
- overflow = true;
- return;
- }
- for (int idx = 0; idx < 8; idx++) {
- if (!isEmpty(idx) && !rhs.isEmpty(idx)) {
- if (palette[idx].first != rhs.palette[idx].first ||
- palette[idx].second != rhs.palette[idx].second) {
- overflow = true;
- break;
- }
- } else if (!rhs.isEmpty(idx)) {
- palette[idx] = rhs.palette[idx];
- }
- }
-}
-
-bool TileScopeAnalysis::PaletteInfo::isConflict(const PaletteInfo &rhs) {
- if (overflow || rhs.overflow) {
- return true;
- }
- for (int idx = 0; idx < 8; idx++) {
- if (!isEmpty(idx) && !rhs.isEmpty(idx)) {
- if (palette[idx].first != rhs.palette[idx].first ||
- palette[idx].second != rhs.palette[idx].second) {
- return true;
- }
- }
- }
-}
-
-// Currently we only operate on scf Ops.
-static bool isConcernedControlFlowOp(Operation *op) {
- return llvm::isa<scf::ExecuteRegionOp>(op) || llvm::isa<scf::ForOp>(op) ||
- llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::IfOp>(op) ||
- llvm::isa<scf::IndexSwitchOp>(op) || llvm::isa<scf::ParallelOp>(op) ||
- llvm::isa<scf::WhileOp>(op);
-}
-
-static bool isTileOp(Operation *op) {
- return llvm::isa<TileZeroOp>(op) || llvm::isa<TileLoadOp>(op) ||
- llvm::isa<TileMulFOp>(op) || llvm::isa<TileMulIOp>(op) ||
- llvm::isa<TileStoreOp>(op);
-}
-
-TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
- isValidAnalysis = false;
- func::FuncOp func = dyn_cast_or_null<func::FuncOp>(root);
- if (!func)
- return;
-
- isValidAnalysis = true;
- // 0. First walk to mark parallel Ops.
- func->walk<WalkOrder::PostOrder>([this](Operation *op) {
- if (!isParallelOp(op))
- return;
-
- if (llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op)) {
- while (op != root) {
- addParallelOp(op);
- op = op->getParentOp();
- }
- }
- });
-
- // 1. Second walk to collect needed palette for each concerned Op.
- collectNeededPalette(root);
-
- // 2. Third walk to analyse usage scope for each tile Op.
- func->walk<WalkOrder::PreOrder>([this](Operation *op) {
- if (!isValidAnalysis)
- return;
- if (!isTileOp(op))
- return;
- Operation *lastUser = nullptr;
- for (auto user : op->getUsers())
- lastUser = user;
- while (lastUser && op->getBlock() != lastUser->getBlock()) {
- lastUser = lastUser->getParentOp();
- if (!lastUser)
- isValidAnalysis = false;
- }
- setTileUsage(op, BlockSeg(Block::iterator(op), Block::iterator(lastUser)));
- });
- if (!isValidAnalysis)
- return;
-
- // 3. Tile scoping for each segmented region in a recursive manner.
- doTileScope(func.getRegion(0).front());
-}
-
-PaletteInfo TileScopeAnalysis::collectRegionPalette(Block &block) {
- PaletteInfo pi;
- for (auto op : block.getOps())
- pi.merge(collectPalette(op));
- return pi;
-}
-
-PaletteInfo TileScopeAnalysis::collectPalette(Operation *op) {
- if (!isValidAnalysis)
- return PaletteInfo();
- if (auto func = dyn_cast_or_null<func::FuncOp>(root))
- // No need to store PaletteInfo for func.
- return collectRegionPalette(func.getRegion(0).front());
-
- auto iter = neededPalette.find(op);
- if (iter != neededPalette.end())
- return iter->second;
-
- // For now, we only concern certain control flow Ops and tile Ops.
- if (isConcernedControlFlowOp(op))
- return collectPaletteForScf(op);
- if (isTileOp(op))
- return collectPaletteForTile(op);
- return PaletteInfo();
-}
-
-PaletteInfo TileScopeAnalysis::collectPaletteForScf(Operation *op) {
- if (!isConcerendScfOp(op))
- return PaletteInfo();
-
- PaletteInfo pi;
- if (llvm::isa<scf::ExecuteRegionOp>(op) || llvm::isa<scf::ForOp>(op) ||
- llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op)) {
- pi = collectNeededPalette(op->getRegion(0).front());
- } else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
- auto thenPalette = collectRegionPalette(ifOp.getThenRegion().front());
- auto elsePalette = collectRegionPalette(ifOp.getElseRegion().front());
- pi.merge(thenPalette);
- pi.merge(elsePalette);
- } else if (auto indexOp = dyn_cast<scf::IndexSwitchOp>(op)) {
- pi = collectRegionPalette(indexOp.getDefaultRegion().front());
- for (auto &caseRegion : indexOp.getCaseRegions()) {
- pi.merge(collectRegionPalette(caseRegion.front()));
- }
- } else if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
- auto beforePalette = collectRegionPalette(whileOp.getRegion(0).front());
- auto afterPalette = collectRegionPalette(whileOp.getRegion(1).front());
- pi.merge(beforePalette);
- pi.merge(afterPalette);
- }
- neededPalette[op] = pi;
- return pi;
-}
-
-static inline pair<int, int> getPaletteShape(VectorType type) {
- ArrayRef<int64_t> shape = type.getShape();
- auto elementType = type.getElementType();
- int typeSize;
- if (elementType.isInteger(8))
- typeSize = 1;
- else if (elementType.isBF16())
- typeSize = 2;
- else if (elementType.isInteger(32) || elementType.isF32())
- typeSize = 4;
- else
- assert(false && "Invalid type for palette");
-
- // Palette shape is { rows, colBytes }.
- return {shape[0], shape[1] * typeSize};
-}
-
-PaletteInfo TileScopeAnalysis::collectPaletteForTile(Operation *op) {
- if (!isTileOp(op))
- return PaletteInfo();
-
-#define PROCESS_UNARY_TILE_OP(op, method) \
- auto index = op.method(); \
- if (!index) { \
- isValidAnalysis = false; \
- return PaletteInfo(); \
- } \
- pi.set(*index, getPaletteShape(op.getVectorType()));
-
-#define PROCESS_TRINARY_TILE_OP(op) \
- auto lhsIndex = tileMulFOp.getLhsRegIndex(); \
- auto rhsIndex = tileMulFOp.getRhsRegIndex(); \
- auto accIndex = tileMulFOp.getAccRegIndex(); \
- if (!lhsIndex || !rhsIndex || !accIndex) { \
- isValidAnalysis = false; \
- return PaletteInfo(); \
- } \
- pi.set(*lhsIndex, getPaletteShape(op.getLhsVectorType())); \
- pi.set(*rhsIndex, getPaletteShape(op.getRhsVectorType())); \
- pi.set(*accIndex, getPaletteShape(op.getAccVectorType()));
-
- PaletteInfo pi;
- if (auto tileLoadOp = dyn_cast<TileLoadOp>(op)) {
- PROCESS_UNARY_TILE_OP(tileLoadOp, getDstRegIndex);
- } else if (auto tileMulFOp = dyn_cast<TileMulFOp>(op)) {
- PROCESS_TRINARY_TILE_OP(tileMulFOp);
- } else if (auto tileMulIOp = dyn_cast<TileMulIOp>(op)) {
- PROCESS_TRINARY_TILE_OP(tileMulIOp);
- } else if (auto tileStoreOp = dyn_cast<TileStoreOp>(op)) {
- PROCESS_UNARY_TILE_OP(tileStoreOp, getSrcRegIndex);
- } else if (auto tileZeroOp = dyn_cast<TileZeroOp>(op)) {
- PROCESS_UNARY_TILE_OP(tileZeroOp, getDstRegIndex);
- }
- neededPalette[op] = pi;
- return pi;
-}
-
-std::optional<PaletteInfo> TileScopeAnalysis::getPalette(Operation *op) {
- auto iter = neededPalette.find(op);
- if (iter == neededPalette.end()) {
- return std::null_opt;
- }
- return iter->second;
-}
-
-std::optional<PaletteInfo> TileScopeAnalysis::getPalette(BlockSeg seg) {
- bool hasPaletteInfo = false;
- PaletteInfo pi;
- for (Operation &opIns : seg) {
- auto *op = &opIns;
- auto tmpPi = getPalette(&opIns);
- if (tmpPi) {
- hasPaletteInfo = true;
- pi.merge(*tmpPi);
- }
- }
- return hasPaletteInfo ? pi : std::null_opt;
-}
-
-void TileScopeAnalysis::doTileScope(Block &block) {
- doTileScope(BlockSeg(block.begin(), block.end()));
-}
-
-void TileScopeAnalysis::doTileScope(BlockSeg seg) {
- if (!isValidAnalysis)
- return;
- if (seg.empty())
- return;
- SmallVector<BlockSeg, 3> blockSegs;
- SmallVector<Operation *, 3> paraOps;
- auto currBegin = seg.begin();
- for (auto probe = seg.begin(); probe != seg.end(); probe++) {
- Operation *op = &(*probe);
- if (isParallelOp(op) {
- blockSegs.push_back(BlockSeg(currBegin, probe));
- paraOps.push_back(op);
- currBegin = probe;
- currBegin++;
- }
- }
- if (breakers.size()) {
- assert(blockSegs.size() == paraOps.size());
- for (int idx = 0; idx < paraOps.size(); idx++) {
- doTileScope(blockSegs[idx]);
- doTileScope(paraOps[idx]);
- }
- doTileScope(BlockSeg(currBegin, seg.end()));
- return;
- }
-
- // Do tile scope on parallel-free BlockSeg.
- TileScope currScope;
- std::optional<Block::iterator> currSegStart;
- Block::iterator currIter = seg.begin();
- // Traverse BlockSeg and greedily do tile scoping without look ahead.
- while (currIter != seg.end()) {
- Operation *currOp = &(*currIter);
- if (!currSegStart)
- currSegStart = currIter;
-
- Block::iterator nextIterIfMerge;
- std::optional<PaletteInfo> pi = std::null_opt;
- if (isConcernedControlFlowOp(currOp)) {
- pi = getPalette(currOp);
- nextIterIfMerge = currIter;
- nextIterIfMerge++;
- } else if (isTileOp(currOp)) {
- auto iter = tileUsage.find(currOp);
- if (iter == tileUsage.end()) {
- isValidAnalysis = false;
- return;
- }
- pi = getPalette(iter->second);
- if (pi && pi->overflow) {
- // This means the binding info in tile Ops exceeds the hardware
- // capability.
- isValidAnalysis = false;
- return;
- }
- nextIterIfMerge = iter->second.end();
- nextIterIfMerge++;
- }
- if (!pi) {
- currIter++;
- continue;
- }
-
-#define TRY_ADD_PREVIOUS_SCOPE() \
- if (currSegStart && *currSegStart != currIter) { \
- auto prevIter = currIter; \
- prevIter--; \
- currScope.seg = BlockSeg(*currSegStart, prevIter); \
- tileScopes.push_back(currScope); \
- currScope.clear(); \
- currSegStart = std::null_opt; \
- }
-
- if (pi->overflow) {
- // Only scf Ops could go through this possibility.
- TRY_ADD_PREVIOUS_SCOPE();
- doTileScope(currOp);
- currIter++;
- } else {
- if (currScope.pi.isConflict(*pi)) {
- TRY_ADD_PREVIOUS_SCOPE();
- currScope.pi = *pi;
- currSegStart = currIter;
- currIter++;
- } else {
- currScope.pi.merge(*pi);
- currIter = nextIterIfMerge;
- }
- }
- }
-
- TRY_ADD_PREVIOUS_SCOPE();
-}
-
-void TileScopeAnalysis::doTileScope(Operation *op) {
- // This func try to collect tile scopes for a single control flow Op
- // This func is not for tile Ops.
- if (isTileOp(op))
- return;
- // Ops that invoke this func are either parallelOps or scfOps with overflowed
- // paletteInfo, and neither of them can form a tile scope by itself, so we
- // omit checking self-formed tile scope in this func.
- if (llvm::isa<scf::ExecuteRegionOp>(op) || llvm::isa<scf::ForOp>(op) ||
- llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op)) {
- auto &block = op->getRegion(0).front();
- doTileScope(BlockSeg(block.begin(), block.end()));
- } else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
- auto &ifBlock = op->getThenRegion().front();
- auto &elseBlock = op->getElseRegion().front();
- doTileScope(BlockSeg(ifBlock.begin(), ifBlock.end()));
- doTileScope(BlockSeg(elseBlock.begin(), elseBlock.end()));
- } else if (auto indexOp = dyn_cast<scf::IndexSwitchOp>(op)) {
- auto &defaultBlock = indexOp.getDefaultRegion().front();
- doTileScope(BlockSeg(defaultBlock.begin(), defaultBlock.end()));
- for (auto &caseRegion : indexOp.getCaseRegions()) {
- auto &caseBlock = indexOp.getDefaultRegion().front();
- doTileScope(BlockSeg(caseBlock.begin(), caseBlock.end()));
- }
- } else if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
- auto &beforeBlock = whileOp.getRegion(0).front();
- auto &afterBlock = whileOp.getRegion(1).front();
- doTileScope(BlockSeg(beforeBlock.begin(), beforeBlock.end()));
- doTileScope(BlockSeg(afterBlock.begin(), afterBlock.end()));
- }
-}
-
-//===----------------------------------------------------------------------===//
-// Pass
-//===----------------------------------------------------------------------===//
-
class TileStoreBindingRewriter : public OpRewritePattern<TileStoreOp> {
private:
TileBindingAnalysis &analysis;
@@ -830,6 +271,7 @@ struct EnableAMXTileBindingPass
builder.setInsertionPointAfter(end);
builder.create<amx::x86_amx_tilerelease_plain>(loc);
}
+ markAnalysesPreserved<TileScopeAnalysis>();
}
};
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 154987303be673..a85a480593aa38 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
>From c6ef334d8f84a017bea0f4ff70f22901ea3f263b Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Tue, 2 Jul 2024 01:36:21 -0700
Subject: [PATCH 09/17] fix tileload & tilestore lowering
---
mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index a85a480593aa38..27754901f55d5b 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -137,7 +137,7 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
auto dstRegIndex = op.getDstRegIndex();
assert(dstRegIndex && "Incomplete operation attribute for tile binding");
rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64_plain>(
- op, ptr, stride, *dstRegIndex);
+ op, *dstRegIndex, ptr, stride);
return success();
}
@@ -181,7 +181,7 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
auto srcRegIndex = op.getSrcRegIndex();
assert(srcRegIndex && "Incomplete operation attribute for tile binding");
rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64_plain>(
- op, ptr, stride, *srcRegIndex);
+ op, *srcRegIndex, ptr, stride);
return success();
}
>From 0a0f62e47b09834d439596067baf6482afe7ab53 Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Wed, 3 Jul 2024 02:16:54 -0700
Subject: [PATCH 10/17] [TBD] fix some compile issues
---
.../Dialect/AMX/Analysis/AMXBindingAnalysis.h | 21 +--
mlir/include/mlir/Dialect/AMX/Transforms.h | 10 +-
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 3 +-
.../AMX/Analysis/AMXBindingAnalysis.cpp | 133 +++++++++++-------
.../AMX/Transforms/EnableAMXTileBinding.cpp | 50 ++-----
.../AMX/Transforms/LegalizeForLLVMExport.cpp | 5 +-
6 files changed, 122 insertions(+), 100 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h b/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
index 05ff7f3d4a74cb..38e868e910b114 100644
--- a/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
+++ b/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
@@ -35,6 +35,9 @@ class TileBindingAnalysis {
bool isValidAnalysis;
DenseMap<Value, int> bindings;
+ // Ensure that tile operations are not wrapped by out-of-scope operations.
+ bool isViableTileOps(Operation *root);
+
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileBindingAnalysis)
explicit TileBindingAnalysis(Operation *);
@@ -56,24 +59,24 @@ class TileScopeAnalysis {
// the length should always be 8.
struct PaletteInfo {
bool overflow;
- SmallVector<pair<int, int>, 8> palette;
+ SmallVector<std::pair<int, int>, 8> palette;
PaletteInfo() {
palette.resize(8, {0, 0});
clear();
}
void clear();
- bool isEmpty(int idx) {
+ bool isEmpty(int idx) const {
return palette[idx].first == 0 && palette[idx].second == 0;
}
- void set(int idx, pair<int, int> shape) { palette[idx] = shape; }
+ void set(int idx, std::pair<int, int> shape) { palette[idx] = shape; }
void merge(const PaletteInfo &rhs);
- bool isConflict(const PaletteInfo &rhs);
+ bool isConflict(const PaletteInfo &rhs) const;
};
struct TileScope {
// The BlockSeg here is inclusive (containing `end` Op).
BlockSeg seg;
PaletteInfo pi;
- TileScope() { clear(); }
+ TileScope() : seg(Block::iterator(), Block::iterator()) { clear(); }
void clear() { pi.clear(); }
};
@@ -93,11 +96,13 @@ class TileScopeAnalysis {
return parallelOps.find(op) == parallelOps.end();
}
- void setTileUsage(Operation *op, BlockSeg seg) { tileUsage[op] = seg; }
+  void setTileUsage(Operation *op, BlockSeg seg) {
+    // BlockSeg (an iterator_range) has no default constructor, so
+    // DenseMap::operator[] cannot be used; insert, or overwrite if present.
+    auto [iter, inserted] = tileUsage.try_emplace(op, seg);
+    if (!inserted)
+      iter->second = std::move(seg);
+  }
- PaletteInfo collectRegionPalette(Region ®ion);
+ PaletteInfo collectBlockPalette(Block &block);
PaletteInfo collectPalette(Operation *op);
- // Below two functions are the leaf functinos of recursive collection, will
+ // Below two functions are the leaf functions of recursive collection, will
// actually insert PaletteInfo into map storage.
PaletteInfo collectPaletteForScf(Operation *op);
PaletteInfo collectPaletteForTile(Operation *op);
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index 19608dff6f160f..881c4b460862c2 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -9,6 +9,10 @@
#ifndef MLIR_DIALECT_AMX_TRANSFORMS_H
#define MLIR_DIALECT_AMX_TRANSFORMS_H
+#include <optional>
+
+#include "mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h"
+
namespace mlir {
class LLVMConversionTarget;
@@ -17,8 +21,10 @@ class RewritePatternSet;
/// Collect a set of patterns to lower AMX ops to ops that map to LLVM
/// intrinsics.
-void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
- RewritePatternSet &patterns);
+void populateAMXLegalizeForLLVMExportPatterns(
+ LLVMTypeConverter &converter,
+ std::optional<amx::TileScopeAnalysis> &analysis,
+ RewritePatternSet &patterns);
/// Configure the target to support lowering AMX ops to ops that map to LLVM
/// intrinsics.
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index c3fd9399a6c9e6..8b8b80a22ae041 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -11,6 +11,7 @@
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
@@ -105,7 +106,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
}
if (amx) {
- auto &analysis = getCachedAnalysis<TileScopeAnalysis>();
+ auto &analysis = getCachedAnalysis<amx::TileScopeAnalysis>();
configureAMXLegalizeForExportTarget(target);
populateAMXLegalizeForLLVMExportPatterns(converter, analysis, patterns);
}
diff --git a/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
index 8ed992ac498fd9..2c70704ba32597 100644
--- a/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
+++ b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
@@ -21,12 +21,23 @@
#define DEBUG_TYPE "amx-binding-analysis"
+namespace mlir {
+namespace amx {
+
static bool isTileOp(Operation *op) {
return llvm::isa<TileZeroOp>(op) || llvm::isa<TileLoadOp>(op) ||
llvm::isa<TileMulFOp>(op) || llvm::isa<TileMulIOp>(op) ||
llvm::isa<TileStoreOp>(op);
}
+// Currently we only operate on these scf Ops: the structured control-flow Ops
+// whose regions the tile analyses know how to traverse.
+static bool isConcernedControlFlowOp(Operation *op) {
+  return llvm::isa<scf::ExecuteRegionOp, scf::ForOp, scf::ForallOp,
+                   scf::IfOp, scf::IndexSwitchOp, scf::ParallelOp,
+                   scf::WhileOp>(op);
+}
+
template <typename Op>
static bool TileMulCheck(Operation *op) {
auto tile_mul = dyn_cast_or_null<Op>(op);
@@ -83,11 +94,38 @@ static bool TileMulPropagate(TileBindingAnalysis *analysis, Operation *op) {
return true;
}
+// Returns true if `root` is a func::FuncOp whose tile Ops are nested in it
+// only through the scf control-flow Ops accepted by isConcernedControlFlowOp;
+// any other enclosing Op makes tile binding unviable.
+bool TileBindingAnalysis::isViableTileOps(Operation *root) {
+  auto func = dyn_cast<func::FuncOp>(root);
+  if (!func)
+    return false;
+
+  // Pre-order walk that stops at the first tile Op with an unsupported
+  // ancestor; once one is found the verdict cannot change.
+  WalkResult result = func->walk<WalkOrder::PreOrder>([&](Operation *op) {
+    if (!isTileOp(op))
+      return WalkResult::advance();
+    for (Operation *probe = op->getParentOp(); probe && probe != root;
+         probe = probe->getParentOp()) {
+      if (!isConcernedControlFlowOp(probe))
+        return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+  return !result.wasInterrupted();
+}
+
TileBindingAnalysis::TileBindingAnalysis(Operation *root) {
isValidAnalysis = false;
func::FuncOp func = dyn_cast_or_null<func::FuncOp>(root);
if (!func)
return;
+ // Ensure that tile operations are not wrapped by out-of-scope Ops, else
+ // cannot do enabling.
+ if (!isViableTileOps(root))
+ return;
isValidAnalysis = true;
func->walk<WalkOrder::PreOrder>([this](Operation *op) {
@@ -148,7 +186,7 @@ void TileScopeAnalysis::PaletteInfo::merge(const PaletteInfo &rhs) {
}
}
-bool TileScopeAnalysis::PaletteInfo::isConflict(const PaletteInfo &rhs) {
+bool TileScopeAnalysis::PaletteInfo::isConflict(const PaletteInfo &rhs) const {
if (overflow || rhs.overflow) {
return true;
}
@@ -160,20 +198,7 @@ bool TileScopeAnalysis::PaletteInfo::isConflict(const PaletteInfo &rhs) {
}
}
}
-}
-
-// Currently we only operate on scf Ops.
-static bool isConcernedControlFlowOp(Operation *op) {
- return llvm::isa<scf::ExecuteRegionOp>(op) || llvm::isa<scf::ForOp>(op) ||
- llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::IfOp>(op) ||
- llvm::isa<scf::IndexSwitchOp>(op) || llvm::isa<scf::ParallelOp>(op) ||
- llvm::isa<scf::WhileOp>(op);
-}
-
-static bool isTileOp(Operation *op) {
- return llvm::isa<TileZeroOp>(op) || llvm::isa<TileLoadOp>(op) ||
- llvm::isa<TileMulFOp>(op) || llvm::isa<TileMulIOp>(op) ||
- llvm::isa<TileStoreOp>(op);
+ return false;
}
TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
@@ -184,7 +209,7 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
isValidAnalysis = true;
// 0. First walk to mark parallel Ops.
- func->walk<WalkOrder::PostOrder>([this](Operation *op) {
+  func->walk<WalkOrder::PostOrder>([&](Operation *op) {
if (!isParallelOp(op))
return;
@@ -197,10 +222,10 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
});
// 1. Second walk to collect needed palette for each concerned Op.
- collectNeededPalette(root);
+ collectPalette(root);
// 2. Third walk to analyse usage scope for each tile Op.
- func->walk<WalkOrder::PreOrder>([this](Operation *op) {
+  func->walk<WalkOrder::PreOrder>([&](Operation *op) {
if (!isValidAnalysis)
return;
if (!isTileOp(op))
@@ -219,22 +244,28 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
return;
// 3. Tile scoping for each segmented region in a recursive manner.
- doTileScope(func.getRegion(0).front());
+ doTileScope(root->getRegion(0).front());
}
-PaletteInfo TileScopeAnalysis::collectRegionPalette(Block &block) {
+TileScopeAnalysis::PaletteInfo
+TileScopeAnalysis::collectBlockPalette(Block &block) {
PaletteInfo pi;
- for (auto op : block.getOps())
- pi.merge(collectPalette(op));
+ auto beginIter = block.begin();
+ auto endIter = block.end();
+ for (auto iter = beginIter; iter != endIter; iter++) {
+ auto &opRef = *iter;
+ pi.merge(collectPalette(&opRef));
+ }
return pi;
}
-PaletteInfo TileScopeAnalysis::collectPalette(Operation *op) {
+TileScopeAnalysis::PaletteInfo
+TileScopeAnalysis::collectPalette(Operation *op) {
if (!isValidAnalysis)
return PaletteInfo();
- if (auto func = dyn_cast_or_null<func::FuncOp>(root))
+ if (auto func = dyn_cast_or_null<func::FuncOp>(op))
// No need to store PaletteInfo for func.
- return collectRegionPalette(func.getRegion(0).front());
+ return collectBlockPalette(op->getRegion(0).front());
auto iter = neededPalette.find(op);
if (iter != neededPalette.end())
@@ -248,27 +279,28 @@ PaletteInfo TileScopeAnalysis::collectPalette(Operation *op) {
return PaletteInfo();
}
-PaletteInfo TileScopeAnalysis::collectPaletteForScf(Operation *op) {
- if (!isConcerendScfOp(op))
+TileScopeAnalysis::PaletteInfo
+TileScopeAnalysis::collectPaletteForScf(Operation *op) {
+ if (!isConcernedControlFlowOp(op))
return PaletteInfo();
PaletteInfo pi;
if (llvm::isa<scf::ExecuteRegionOp>(op) || llvm::isa<scf::ForOp>(op) ||
llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op)) {
- pi = collectNeededPalette(op->getRegion(0).front());
+ pi = collectBlockPalette(op->getRegion(0).front());
} else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
- auto thenPalette = collectRegionPalette(ifOp.getThenRegion().front());
- auto elsePalette = collectRegionPalette(ifOp.getElseRegion().front());
+ auto thenPalette = collectBlockPalette(ifOp.getThenRegion().front());
+ auto elsePalette = collectBlockPalette(ifOp.getElseRegion().front());
pi.merge(thenPalette);
pi.merge(elsePalette);
} else if (auto indexOp = dyn_cast<scf::IndexSwitchOp>(op)) {
- pi = collectRegionPalette(indexOp.getDefaultRegion().front());
+ pi = collectBlockPalette(indexOp.getDefaultRegion().front());
for (auto &caseRegion : indexOp.getCaseRegions()) {
- pi.merge(collectRegionPalette(caseRegion.front()));
+ pi.merge(collectBlockPalette(caseRegion.front()));
}
} else if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
- auto beforePalette = collectRegionPalette(whileOp.getRegion(0).front());
- auto afterPalette = collectRegionPalette(whileOp.getRegion(1).front());
+ auto beforePalette = collectBlockPalette(whileOp.getRegion(0).front());
+ auto afterPalette = collectBlockPalette(whileOp.getRegion(1).front());
pi.merge(beforePalette);
pi.merge(afterPalette);
}
@@ -276,7 +308,7 @@ PaletteInfo TileScopeAnalysis::collectPaletteForScf(Operation *op) {
return pi;
}
-static inline pair<int, int> getPaletteShape(VectorType type) {
+static inline std::pair<int, int> getPaletteShape(VectorType type) {
ArrayRef<int64_t> shape = type.getShape();
auto elementType = type.getElementType();
int typeSize;
@@ -293,7 +325,8 @@ static inline pair<int, int> getPaletteShape(VectorType type) {
return {shape[0], shape[1] * typeSize};
}
-PaletteInfo TileScopeAnalysis::collectPaletteForTile(Operation *op) {
+TileScopeAnalysis::PaletteInfo
+TileScopeAnalysis::collectPaletteForTile(Operation *op) {
if (!isTileOp(op))
return PaletteInfo();
@@ -315,7 +348,7 @@ PaletteInfo TileScopeAnalysis::collectPaletteForTile(Operation *op) {
} \
pi.set(*lhsIndex, getPaletteShape(op.getLhsVectorType())); \
pi.set(*rhsIndex, getPaletteShape(op.getRhsVectorType())); \
- pi.set(*accIndex, getPaletteShape(op.getAccVectorType()));
+ pi.set(*accIndex, getPaletteShape(op.getVectorType()));
PaletteInfo pi;
if (auto tileLoadOp = dyn_cast<TileLoadOp>(op)) {
@@ -333,26 +366,27 @@ PaletteInfo TileScopeAnalysis::collectPaletteForTile(Operation *op) {
return pi;
}
-std::optional<PaletteInfo> TileScopeAnalysis::getPalette(Operation *op) {
+std::optional<TileScopeAnalysis::PaletteInfo>
+TileScopeAnalysis::getPalette(Operation *op) {
auto iter = neededPalette.find(op);
if (iter == neededPalette.end()) {
- return std::null_opt;
+ return std::nullopt;
}
return iter->second;
}
-std::optional<PaletteInfo> TileScopeAnalysis::getPalette(BlockSeg seg) {
+std::optional<TileScopeAnalysis::PaletteInfo>
+TileScopeAnalysis::getPalette(BlockSeg seg) {
bool hasPaletteInfo = false;
PaletteInfo pi;
for (Operation &opIns : seg) {
- auto *op = &opIns;
auto tmpPi = getPalette(&opIns);
if (tmpPi) {
hasPaletteInfo = true;
pi.merge(*tmpPi);
}
}
- return hasPaletteInfo ? pi : std::null_opt;
+ return hasPaletteInfo ? std::optional<PaletteInfo>{pi} : std::nullopt;
}
void TileScopeAnalysis::doTileScope(Block &block) {
@@ -369,14 +403,14 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
auto currBegin = seg.begin();
for (auto probe = seg.begin(); probe != seg.end(); probe++) {
Operation *op = &(*probe);
- if (isParallelOp(op) {
+ if (isParallelOp(op)) {
blockSegs.push_back(BlockSeg(currBegin, probe));
paraOps.push_back(op);
currBegin = probe;
currBegin++;
}
}
- if (breakers.size()) {
+ if (paraOps.size()) {
assert(blockSegs.size() == paraOps.size());
for (int idx = 0; idx < paraOps.size(); idx++) {
doTileScope(blockSegs[idx]);
@@ -397,7 +431,7 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
currSegStart = currIter;
Block::iterator nextIterIfMerge;
- std::optional<PaletteInfo> pi = std::null_opt;
+ std::optional<PaletteInfo> pi = std::nullopt;
if (isConcernedControlFlowOp(currOp)) {
pi = getPalette(currOp);
nextIterIfMerge = currIter;
@@ -430,7 +464,7 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
currScope.seg = BlockSeg(*currSegStart, prevIter); \
tileScopes.push_back(currScope); \
currScope.clear(); \
- currSegStart = std::null_opt; \
+ currSegStart = std::nullopt; \
}
if (pi->overflow) {
@@ -467,8 +501,8 @@ void TileScopeAnalysis::doTileScope(Operation *op) {
auto &block = op->getRegion(0).front();
doTileScope(BlockSeg(block.begin(), block.end()));
} else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
- auto &ifBlock = op->getThenRegion().front();
- auto &elseBlock = op->getElseRegion().front();
+ auto &ifBlock = ifOp.getThenRegion().front();
+ auto &elseBlock = ifOp.getElseRegion().front();
doTileScope(BlockSeg(ifBlock.begin(), ifBlock.end()));
doTileScope(BlockSeg(elseBlock.begin(), elseBlock.end()));
} else if (auto indexOp = dyn_cast<scf::IndexSwitchOp>(op)) {
@@ -485,3 +519,6 @@ void TileScopeAnalysis::doTileScope(Operation *op) {
doTileScope(BlockSeg(afterBlock.begin(), afterBlock.end()));
}
}
+
+} // namespace amx
+} // namespace mlir
diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
index 3a181ed79d94ce..484dadc92618cd 100644
--- a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -22,6 +22,7 @@
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h"
#include "mlir/Dialect/AMX/Passes.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Support/LogicalResult.h"
@@ -142,38 +143,15 @@ static inline void uint8ArrayToHex(std::string &out, uint8_t array[],
int size) {
llvm::raw_string_ostream os(out);
for (int index = 0; index < size; index++) {
- os << format_hex_no_prefix(array[index], 2, true);
+ os << llvm::format_hex_no_prefix(array[index], 2, true);
}
}
struct EnableAMXTileBindingPass
: public impl::EnableAMXTileBindingBase<EnableAMXTileBindingPass> {
private:
- bool isViableTileOps() {
- Operation *root = getOperation();
- auto func = dyn_cast<func::FuncOp>(root);
- if (!func)
- return false;
-
- bool isViable = true;
- func->walk<WalkOrder::PreOrder>([this](Operation *op) {
- if (!isViable)
- return;
- if (!isTileOp(op))
- return;
- auto probe = op->getParentOp();
- while (probe != root) {
- if (!isConcernedControlFlowOp(probe)) {
- isViable = false;
- break;
- }
- probe = probe->getParentOp();
- }
- });
- return isViable;
- }
-
- LLVM::GlobalOp getOrCreateGlobalPalette(const PaletteInfo &pi) {
+ LLVM::GlobalOp
+ getOrCreateGlobalPalette(const TileScopeAnalysis::PaletteInfo &pi) {
assert(!pi.overflow && "Expecting valid palette");
// Pack struct so it can fit into a single 64-byte cache line.
#pragma pack(push, 1)
@@ -187,7 +165,7 @@ struct EnableAMXTileBindingPass
#pragma pack(pop)
size_t paletteArraySize = 64;
- uint8_t *paletteAsArray = &paletteConfig;
+ auto paletteAsArray = reinterpret_cast<uint8_t *>(&paletteConfig);
memset(paletteAsArray, 0x0, paletteArraySize);
// Intel AMX: The only legal non-INIT value for palette_id is 1.
// TODO(haixin): fetch from CPUID ?
@@ -199,14 +177,15 @@ struct EnableAMXTileBindingPass
}
std::string paletteSymName = "g_intel_amx_palette_";
- uintArrayToHex(paletteSymName, paletteAsArray, paletteArraySize);
+ uint8ArrayToHex(paletteSymName, paletteAsArray, paletteArraySize);
+ ModuleOp module = getOperation()->template getParentOfType<ModuleOp>();
if ((global = module.lookupSymbol<LLVM::GlobalOp>(paletteSymName)))
return global;
// Create a global symbol containing palette config.
- ModuleOp moduleOp = getOperation()->template getParentOfType<ModuleOp>();
- OpBuilder builder(moduleOp);
- builder.setInsertionPointToStart(moduleOp.getBody());
+ auto ctx = module->getContext();
+ OpBuilder builder(module);
+ builder.setInsertionPointToStart(module.getBody());
SmallVector<uint8_t> elementVals;
for (size_t index = 0; index < paletteArraySize; index++)
@@ -218,18 +197,13 @@ struct EnableAMXTileBindingPass
auto arrayTy =
LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
auto global = builder.create<LLVM::GlobalOp>(
- moduleOp.getLoc(), arrayType, /*isConstant*/ true,
- LLVM::Linkage::Private, paletteSymName, dataAttr, /*alignment=*/64);
+ module.getLoc(), arrayType, /*isConstant*/ true, LLVM::Linkage::Private,
+ paletteSymName, dataAttr, /*alignment=*/64);
return global;
}
public:
void runOnOperation() override {
- // Ensure that tile Ops are not wrapped by out-of-scope Ops, else cannot do
- // enabling.
- if (!isViableTileOps())
- return;
-
// 0. Get AnalyseInfo for each concerned Value (Does not allow mixed used of
// tmul & normal vector operations).
TileBindingAnalysis &bindingAna = getAnalysis<TileBindingAnalysis>();
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 27754901f55d5b..9fd116b38ff740 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -6,7 +6,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
@@ -212,7 +211,7 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
ConversionPatternRewriter &rewriter) const override {
if (enablingAnalysis && enablingAnalysis->isValid()) {
// Routine for lowering tile Ops with binding info.
- auto lhsRegIndex = op.getSrcRegIndex();
+ auto lhsRegIndex = op.getLhsRegIndex();
auto rhsRegIndex = op.getRhsRegIndex();
auto accRegIndex = op.getAccRegIndex();
@@ -259,7 +258,7 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
if (enablingAnalysis && enablingAnalysis->isValid()) {
// Routine for lowering tile Ops with binding info.
- auto lhsRegIndex = op.getSrcRegIndex();
+ auto lhsRegIndex = op.getLhsRegIndex();
auto rhsRegIndex = op.getRhsRegIndex();
auto accRegIndex = op.getAccRegIndex();
>From aed71c4c0459d4a3e7d97d00449529b507d86a85 Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Fri, 5 Jul 2024 02:46:38 -0700
Subject: [PATCH 11/17] fix compile issue
---
mlir/include/mlir/Dialect/AMX/AMX.td | 6 +--
.../Dialect/AMX/Analysis/AMXBindingAnalysis.h | 11 ++--
mlir/include/mlir/Dialect/AMX/Transforms.h | 2 +-
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 2 +-
.../AMX/Analysis/AMXBindingAnalysis.cpp | 5 +-
.../AMX/Transforms/EnableAMXTileBinding.cpp | 17 +++---
.../AMX/Transforms/LegalizeForLLVMExport.cpp | 53 +++++++++++--------
7 files changed, 58 insertions(+), 38 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index 0c77b8b6ad90fc..0ca220d83ed4fd 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -403,21 +403,21 @@ def LLVM_x86_amx_tdpbssd_plain : AMX_IntrOpBase<"tdpbssd", 0,
Arg<I8Attr, "Index of rhs tmm registers">:$rhs_index)>;
// Dot product of i8 tiles into i32 tile (with sign/zero extension).
-def LLVM_x86_amx_tdpbsud_plain : AMX_IntrOpBase<"tdpbsud", 1,
+def LLVM_x86_amx_tdpbsud_plain : AMX_IntrOpBase<"tdpbsud", 0,
[0, 1, 2], ["dst_index", "lhs_index", "rhs_index"]>,
Arguments<(ins Arg<I8Attr, "Index of dst tmm registers">:$dst_index,
Arg<I8Attr, "Index of lhs tmm registers">:$lhs_index,
Arg<I8Attr, "Index of rhs tmm registers">:$rhs_index)>;
// Dot product of i8 tiles into i32 tile (with zero/sign extension).
-def LLVM_x86_amx_tdpbusd_plain : AMX_IntrOpBase<"tdpbusd", 1,
+def LLVM_x86_amx_tdpbusd_plain : AMX_IntrOpBase<"tdpbusd", 0,
[0, 1, 2], ["dst_index", "lhs_index", "rhs_index"]>,
Arguments<(ins Arg<I8Attr, "Index of dst tmm registers">:$dst_index,
Arg<I8Attr, "Index of lhs tmm registers">:$lhs_index,
Arg<I8Attr, "Index of rhs tmm registers">:$rhs_index)>;
// Dot product of i8 tiles into i32 tile (with zero/zero extension).
-def LLVM_x86_amx_tdpbuud_plain : AMX_IntrOpBase<"tdpbuud", 1,
+def LLVM_x86_amx_tdpbuud_plain : AMX_IntrOpBase<"tdpbuud", 0,
[0, 1, 2], ["dst_index", "lhs_index", "rhs_index"]>,
Arguments<(ins Arg<I8Attr, "Index of dst tmm registers">:$dst_index,
Arg<I8Attr, "Index of lhs tmm registers">:$lhs_index,
diff --git a/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h b/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
index 38e868e910b114..e41be602d6cb19 100644
--- a/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
+++ b/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
@@ -53,8 +53,7 @@ class TileBindingAnalysis {
// A class for analyzing tile configuration domination (a.k.a. tile scope).
class TileScopeAnalysis {
-private:
- typedef llvm::iterator_range<Block::iterator> BlockSeg;
+public:
// A list of 2-dim {rows x colBytes} shapes representing tmm register shape,
// the length should always be 8.
struct PaletteInfo {
@@ -72,6 +71,9 @@ class TileScopeAnalysis {
void merge(const PaletteInfo &rhs);
bool isConflict(const PaletteInfo &rhs) const;
};
+
+private:
+ typedef llvm::iterator_range<Block::iterator> BlockSeg;
struct TileScope {
// The BlockSeg here is inclusive (containing `end` Op).
BlockSeg seg;
@@ -97,7 +99,7 @@ class TileScopeAnalysis {
}
void setTileUsage(Operation *op, BlockSeg seg) {
- tileUsage[op] = std::move(seg);
+    tileUsage.insert_or_assign(op, std::move(seg));
}
PaletteInfo collectBlockPalette(Block &block);
@@ -117,6 +119,9 @@ class TileScopeAnalysis {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileScopeAnalysis)
explicit TileScopeAnalysis(Operation *);
+  const SmallVector<TileScope, 10> &getTileScopes() const {
+    return tileScopes;
+  }
bool isValid() const { return isValidAnalysis; }
};
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index 881c4b460862c2..3379dcc7b491b9 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -23,7 +23,7 @@ class RewritePatternSet;
/// intrinsics.
void populateAMXLegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter,
- std::optional<amx::TileScopeAnalysis> &analysis,
+ std::optional<std::reference_wrapper<amx::TileScopeAnalysis>> &analysis,
RewritePatternSet &patterns);
/// Configure the target to support lowering AMX ops to ops that map to LLVM
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 8b8b80a22ae041..0329a5fa4cb648 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -106,7 +106,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
}
if (amx) {
- auto &analysis = getCachedAnalysis<amx::TileScopeAnalysis>();
+ auto analysis = getCachedAnalysis<amx::TileScopeAnalysis>();
configureAMXLegalizeForExportTarget(target);
populateAMXLegalizeForLLVMExportPatterns(converter, analysis, patterns);
}
diff --git a/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
index 2c70704ba32597..519304df6e0295 100644
--- a/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
+++ b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
+#include "llvm/Support/Format.h"
#define DEBUG_TYPE "amx-binding-analysis"
@@ -412,7 +413,7 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
}
if (paraOps.size()) {
assert(blockSegs.size() == paraOps.size());
- for (int idx = 0; idx < paraOps.size(); idx++) {
+ for (size_t idx = 0; idx < paraOps.size(); idx++) {
doTileScope(blockSegs[idx]);
doTileScope(paraOps[idx]);
}
@@ -509,7 +510,7 @@ void TileScopeAnalysis::doTileScope(Operation *op) {
auto &defaultBlock = indexOp.getDefaultRegion().front();
doTileScope(BlockSeg(defaultBlock.begin(), defaultBlock.end()));
for (auto &caseRegion : indexOp.getCaseRegions()) {
- auto &caseBlock = indexOp.getDefaultRegion().front();
+ auto &caseBlock = caseRegion.front();
doTileScope(BlockSeg(caseBlock.begin(), caseBlock.end()));
}
} else if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
index 484dadc92618cd..17736c01daa79c 100644
--- a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -27,6 +27,9 @@
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/FormattedStream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -180,7 +183,7 @@ struct EnableAMXTileBindingPass
uint8ArrayToHex(paletteSymName, paletteAsArray, paletteArraySize);
ModuleOp module = getOperation()->template getParentOfType<ModuleOp>();
- if ((global = module.lookupSymbol<LLVM::GlobalOp>(paletteSymName)))
+ if (auto global = module.lookupSymbol<LLVM::GlobalOp>(paletteSymName))
return global;
// Create a global symbol containing palette config.
auto ctx = module->getContext();
@@ -194,7 +197,7 @@ struct EnableAMXTileBindingPass
{static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
auto dataAttr =
DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
- auto arrayTy =
+ auto arrayType =
LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
auto global = builder.create<LLVM::GlobalOp>(
module.getLoc(), arrayType, /*isConstant*/ true, LLVM::Linkage::Private,
@@ -212,9 +215,9 @@ struct EnableAMXTileBindingPass
// 1. Set propagated binding info to AMX Ops.
RewritePatternSet patterns(&getContext());
- patterns.add<TileStoreBindingRewriter>(&getContext(), analysis);
- patterns.add<TileMulFBindingRewriter>(&getContext(), analysis);
- patterns.add<TileMulIBindingRewriter>(&getContext(), analysis);
+ patterns.add<TileStoreBindingRewriter>(&getContext(), bindingAna);
+ patterns.add<TileMulFBindingRewriter>(&getContext(), bindingAna);
+ patterns.add<TileMulIBindingRewriter>(&getContext(), bindingAna);
FrozenRewritePatternSet patternSet(std::move(patterns));
if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
@@ -227,13 +230,13 @@ struct EnableAMXTileBindingPass
// 3. Insert tile config/release according to tile scopes.
OpBuilder builder(getOperation());
- for (auto &scope : tileScopes) {
+ for (auto &scope : scopeAna.getTileScopes()) {
assert(!scope.pi.overflow && "Expecting legal AMX palette info");
auto paletteGlobal = getOrCreateGlobalPalette(scope.pi);
assert(paletteGlobal && "Failed to create global palette");
Operation *begin = &(*scope.seg.begin());
- Loc loc = begin->getLoc();
+ Location loc = begin->getLoc();
builder.setInsertionPoint(begin);
Value paletteGlobalPtr =
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 9fd116b38ff740..c7f67b49ef0a53 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -75,19 +75,21 @@ Value getStride(ConversionPatternRewriter &rewriter,
struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
private:
- const std::optional<TileScopeAnalysis> &enablingAnalysis;
+ const std::optional<std::reference_wrapper<TileScopeAnalysis>>
+ &enablingAnalysis;
public:
using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern;
- TileZeroConversion(const LLVMTypeConverter &typeConverter,
- const std::optional<TileScopeAnalysis> &analysis)
+ TileZeroConversion(
+ const LLVMTypeConverter &typeConverter,
+ const std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis)
: ConvertOpToLLVMPattern<TileZeroOp>(typeConverter),
enablingAnalysis(analysis) {}
LogicalResult
matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (enablingAnalysis && enablingAnalysis->isValid()) {
+ if (enablingAnalysis && enablingAnalysis->get().isValid()) {
// Routine for lowering tile Ops with binding info.
auto dstRegIndex = op.getDstRegIndex();
assert(dstRegIndex && "Incomplete operation attribute for tile binding");
@@ -110,12 +112,14 @@ struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
private:
- const std::optional<TileScopeAnalysis> &enablingAnalysis;
+ const std::optional<std::reference_wrapper<TileScopeAnalysis>>
+ &enablingAnalysis;
public:
using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern;
- TileLoadConversion(const LLVMTypeConverter &typeConverter,
- const std::optional<TileScopeAnalysis> &analysis)
+ TileLoadConversion(
+ const LLVMTypeConverter &typeConverter,
+ const std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis)
: ConvertOpToLLVMPattern<TileLoadOp>(typeConverter),
enablingAnalysis(analysis) {}
@@ -131,7 +135,7 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
- if (enablingAnalysis && enablingAnalysis->isValid()) {
+ if (enablingAnalysis && enablingAnalysis->get().isValid()) {
// Routine for lowering tile Ops with binding info.
auto dstRegIndex = op.getDstRegIndex();
assert(dstRegIndex && "Incomplete operation attribute for tile binding");
@@ -154,12 +158,14 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
private:
- const std::optional<TileScopeAnalysis> &enablingAnalysis;
+ const std::optional<std::reference_wrapper<TileScopeAnalysis>>
+ &enablingAnalysis;
public:
using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;
- TileStoreConversion(const LLVMTypeConverter &typeConverter,
- const std::optional<TileScopeAnalysis> &analysis)
+ TileStoreConversion(
+ const LLVMTypeConverter &typeConverter,
+ const std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis)
: ConvertOpToLLVMPattern<TileStoreOp>(typeConverter),
enablingAnalysis(analysis) {}
@@ -175,7 +181,7 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
- if (enablingAnalysis && enablingAnalysis->isValid()) {
+ if (enablingAnalysis && enablingAnalysis->get().isValid()) {
// Routine for lowering tile Ops with binding info.
auto srcRegIndex = op.getSrcRegIndex();
assert(srcRegIndex && "Incomplete operation attribute for tile binding");
@@ -197,19 +203,21 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
private:
- const std::optional<TileScopeAnalysis> &enablingAnalysis;
+ const std::optional<std::reference_wrapper<TileScopeAnalysis>>
+ &enablingAnalysis;
public:
using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern;
- TileMulFConversion(const LLVMTypeConverter &typeConverter,
- const std::optional<TileScopeAnalysis> &analysis)
+ TileMulFConversion(
+ const LLVMTypeConverter &typeConverter,
+ const std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis)
: ConvertOpToLLVMPattern<TileMulFOp>(typeConverter),
enablingAnalysis(analysis) {}
LogicalResult
matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (enablingAnalysis && enablingAnalysis->isValid()) {
+ if (enablingAnalysis && enablingAnalysis->get().isValid()) {
// Routine for lowering tile Ops with binding info.
auto lhsRegIndex = op.getLhsRegIndex();
auto rhsRegIndex = op.getRhsRegIndex();
@@ -241,12 +249,14 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
private:
- const std::optional<TileScopeAnalysis> &enablingAnalysis;
+ const std::optional<std::reference_wrapper<TileScopeAnalysis>>
+ &enablingAnalysis;
public:
using ConvertOpToLLVMPattern<TileMulIOp>::ConvertOpToLLVMPattern;
- TileMulIConversion(const LLVMTypeConverter &typeConverter,
- const std::optional<TileScopeAnalysis> &analysis)
+ TileMulIConversion(
+ const LLVMTypeConverter &typeConverter,
+ const std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis)
: ConvertOpToLLVMPattern<TileMulIOp>(typeConverter),
enablingAnalysis(analysis) {}
@@ -256,7 +266,7 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
bool zexta = op.getIsZextLhs();
bool zextb = op.getIsZextRhs();
- if (enablingAnalysis && enablingAnalysis->isValid()) {
+ if (enablingAnalysis && enablingAnalysis->get().isValid()) {
// Routine for lowering tile Ops with binding info.
auto lhsRegIndex = op.getLhsRegIndex();
auto rhsRegIndex = op.getRhsRegIndex();
@@ -312,7 +322,8 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
} // namespace
void mlir::populateAMXLegalizeForLLVMExportPatterns(
- LLVMTypeConverter &converter, std::optional<TileScopeAnalysis> &analysis,
+ LLVMTypeConverter &converter,
+ std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis,
RewritePatternSet &patterns) {
patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
TileMulFConversion, TileMulIConversion>(converter, analysis);
>From 626064068aaa81bc3142acbd0fe6c913737e7970 Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Thu, 1 Aug 2024 00:01:08 -0700
Subject: [PATCH 12/17] [WIP] preliminary bug fixes
---
mlir/include/mlir/Dialect/AMX/AMX.td | 2 +-
.../Dialect/AMX/Analysis/AMXBindingAnalysis.h | 2 +-
mlir/include/mlir/InitAllPasses.h | 2 +
.../AMX/Analysis/AMXBindingAnalysis.cpp | 11 +++++
.../lib/Dialect/AMX/Transforms/CMakeLists.txt | 1 +
.../AMX/Transforms/EnableAMXTileBinding.cpp | 12 +++++-
.../AMX/Transforms/LegalizeForLLVMExport.cpp | 41 +++++++++++--------
7 files changed, 52 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index 0ca220d83ed4fd..c8300317c471aa 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -92,7 +92,7 @@ class AMX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
def TileRegisterIndexAttr : OptionalAttr<
- ConfinedAttr<SI8Attr, [IntMinValue<0>, IntMaxValue<7>]>>;
+ ConfinedAttr<I8Attr, [IntMinValue<0>, IntMaxValue<7>]>>;
//
// Tile reset.
diff --git a/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h b/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
index e41be602d6cb19..4ddeff63d980a4 100644
--- a/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
+++ b/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
@@ -95,7 +95,7 @@ class TileScopeAnalysis {
void addParallelOp(Operation *op) { parallelOps.insert(op); }
bool isParallelOp(Operation *op) {
- return parallelOps.find(op) == parallelOps.end();
+ return parallelOps.find(op) != parallelOps.end();
}
void setTileUsage(Operation *op, BlockSeg seg) {
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index fedd7737f9ea45..827f220af21152 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -16,6 +16,7 @@
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+#include "mlir/Dialect/AMX/Passes.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
@@ -93,6 +94,7 @@ inline void registerAllPasses() {
arm_sve::registerArmSVEPasses();
emitc::registerEmitCPasses();
xegpu::registerXeGPUPasses();
+ amx::registerAMXPasses();
// Dialect pipelines
bufferization::registerBufferizationPipelines();
diff --git a/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
index 519304df6e0295..26ba27ba389054 100644
--- a/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
+++ b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
+#include "llvm/Support/Debug.h"
#include "llvm/Support/Format.h"
#define DEBUG_TYPE "amx-binding-analysis"
@@ -217,6 +218,7 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
if (llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op)) {
while (op != root) {
addParallelOp(op);
+ LLVM_DEBUG(llvm::dbgs() << ">>>>> add parallel op: " << op << "\n");
op = op->getParentOp();
}
}
@@ -305,6 +307,7 @@ TileScopeAnalysis::collectPaletteForScf(Operation *op) {
pi.merge(beforePalette);
pi.merge(afterPalette);
}
+ LLVM_DEBUG(llvm::dbgs() << ">>>>> set needed palette for op: " << op << "\n");
neededPalette[op] = pi;
return pi;
}
@@ -363,6 +366,7 @@ TileScopeAnalysis::collectPaletteForTile(Operation *op) {
} else if (auto tileZeroOp = dyn_cast<TileZeroOp>(op)) {
PROCESS_UNARY_TILE_OP(tileZeroOp, getDstRegIndex);
}
+ LLVM_DEBUG(llvm::dbgs() << ">>>>> set needed palette for op: " << op << "\n");
neededPalette[op] = pi;
return pi;
}
@@ -397,8 +401,10 @@ void TileScopeAnalysis::doTileScope(Block &block) {
void TileScopeAnalysis::doTileScope(BlockSeg seg) {
if (!isValidAnalysis)
return;
+ LLVM_DEBUG(llvm::dbgs() << ">>>>> doTileScope A\n");
if (seg.empty())
return;
+ LLVM_DEBUG(llvm::dbgs() << ">>>>> doTileScope B\n");
SmallVector<BlockSeg, 3> blockSegs;
SmallVector<Operation *, 3> paraOps;
auto currBegin = seg.begin();
@@ -412,6 +418,7 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
}
}
if (paraOps.size()) {
+ LLVM_DEBUG(llvm::dbgs() << ">>>>> doTileScope BB\n");
assert(blockSegs.size() == paraOps.size());
for (size_t idx = 0; idx < paraOps.size(); idx++) {
doTileScope(blockSegs[idx]);
@@ -420,6 +427,7 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
doTileScope(BlockSeg(currBegin, seg.end()));
return;
}
+ LLVM_DEBUG(llvm::dbgs() << ">>>>> doTileScope C\n");
// Do tile scope on parallel-free BlockSeg.
TileScope currScope;
@@ -428,6 +436,8 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
// Traverse BlockSeg and greedily do tile scoping without look ahead.
while (currIter != seg.end()) {
Operation *currOp = &(*currIter);
+ LLVM_DEBUG(llvm::dbgs()
+ << ">>>>> doTileScope Checking op: " << *currOp << "\n");
if (!currSegStart)
currSegStart = currIter;
@@ -486,6 +496,7 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
}
}
+ LLVM_DEBUG(llvm::dbgs() << ">>>>> out of doTileScope\n");
TRY_ADD_PREVIOUS_SCOPE();
}
diff --git a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
index 351f788959a6de..1b8feb7f4a544c 100644
--- a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRAMXTransforms
DEPENDS
MLIRAMXConversionsIncGen
+ MLIRAMXTransformsIncGen
MLIRAMXAnalysis
LINK_LIBS PUBLIC
diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
index 17736c01daa79c..2a37751d387e91 100644
--- a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -213,6 +213,7 @@ struct EnableAMXTileBindingPass
if (!bindingAna.isValid())
return;
+ LLVM_DEBUG(llvm::dbgs() << ">>> After binding analysis\n");
// 1. Set propagated binding info to AMX Ops.
RewritePatternSet patterns(&getContext());
patterns.add<TileStoreBindingRewriter>(&getContext(), bindingAna);
@@ -220,17 +221,26 @@ struct EnableAMXTileBindingPass
patterns.add<TileMulIBindingRewriter>(&getContext(), bindingAna);
FrozenRewritePatternSet patternSet(std::move(patterns));
- if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
+ if (failed(
+ applyPatternsAndFoldGreedily(getOperation(), patternSet, config)))
return;
+ LLVM_DEBUG(llvm::dbgs() << ">>> After propagating binding info\n");
+
// 2. Analyse tile scopes & expand them maximally.
TileScopeAnalysis &scopeAna = getAnalysis<TileScopeAnalysis>();
if (!scopeAna.isValid())
return;
+ LLVM_DEBUG(llvm::dbgs() << ">>> After tile scope analysis\n");
+
// 3. Insert tile config/release according to tile scopes.
OpBuilder builder(getOperation());
for (auto &scope : scopeAna.getTileScopes()) {
+ LLVM_DEBUG(llvm::dbgs() << ">>> Processing tile scope: "
+ << *scope.seg.begin() << "\n");
assert(!scope.pi.overflow && "Expecting legal AMX palette info");
auto paletteGlobal = getOrCreateGlobalPalette(scope.pi);
assert(paletteGlobal && "Failed to create global palette");
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index c7f67b49ef0a53..afff80640bddd4 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -90,11 +90,12 @@ struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (enablingAnalysis && enablingAnalysis->get().isValid()) {
+ rewriter.setInsertionPoint(op);
// Routine for lowering tile Ops with binding info.
auto dstRegIndex = op.getDstRegIndex();
assert(dstRegIndex && "Incomplete operation attribute for tile binding");
- rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero_plain>(op,
- *dstRegIndex);
+ rewriter.create<amx::x86_amx_tilezero_plain>(op.getLoc(), *dstRegIndex);
+ rewriter.eraseOp(op);
return success();
}
@@ -136,11 +137,13 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
adaptor.getIndices(), rewriter);
if (enablingAnalysis && enablingAnalysis->get().isValid()) {
+ rewriter.setInsertionPoint(op);
// Routine for lowering tile Ops with binding info.
auto dstRegIndex = op.getDstRegIndex();
assert(dstRegIndex && "Incomplete operation attribute for tile binding");
- rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64_plain>(
- op, *dstRegIndex, ptr, stride);
+ rewriter.create<amx::x86_amx_tileloadd64_plain>(op.getLoc(), *dstRegIndex,
+ ptr, stride);
+ rewriter.eraseOp(op);
return success();
}
@@ -182,11 +185,13 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
adaptor.getIndices(), rewriter);
if (enablingAnalysis && enablingAnalysis->get().isValid()) {
+ rewriter.setInsertionPoint(op);
// Routine for lowering tile Ops with binding info.
auto srcRegIndex = op.getSrcRegIndex();
assert(srcRegIndex && "Incomplete operation attribute for tile binding");
- rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64_plain>(
- op, *srcRegIndex, ptr, stride);
+ rewriter.create<amx::x86_amx_tilestored64_plain>(
+ op.getLoc(), *srcRegIndex, ptr, stride);
+ rewriter.eraseOp(op);
return success();
}
@@ -218,6 +223,7 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (enablingAnalysis && enablingAnalysis->get().isValid()) {
+ rewriter.setInsertionPoint(op);
// Routine for lowering tile Ops with binding info.
auto lhsRegIndex = op.getLhsRegIndex();
auto rhsRegIndex = op.getRhsRegIndex();
@@ -225,8 +231,9 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
assert(lhsRegIndex && rhsRegIndex && accRegIndex &&
"Incomplete operation attribute for tile binding");
- rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps_plain>(
- op, *accRegIndex, *lhsRegIndex, *rhsRegIndex);
+ rewriter.create<amx::x86_amx_tdpbf16ps_plain>(op.getLoc(), *accRegIndex,
+ *lhsRegIndex, *rhsRegIndex);
+ rewriter.eraseOp(op);
return success();
}
@@ -267,6 +274,7 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
bool zextb = op.getIsZextRhs();
if (enablingAnalysis && enablingAnalysis->get().isValid()) {
+ rewriter.setInsertionPoint(op);
// Routine for lowering tile Ops with binding info.
auto lhsRegIndex = op.getLhsRegIndex();
auto rhsRegIndex = op.getRhsRegIndex();
@@ -275,17 +283,18 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
assert(lhsRegIndex && rhsRegIndex && accRegIndex &&
"Incomplete operation attribute for tile binding");
if (zexta && zextb)
- rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud_plain>(
- op, *accRegIndex, *lhsRegIndex, *rhsRegIndex);
+ rewriter.create<amx::x86_amx_tdpbuud_plain>(op.getLoc(), *accRegIndex,
+ *lhsRegIndex, *rhsRegIndex);
else if (zexta && !zextb)
- rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbusd_plain>(
- op, *accRegIndex, *lhsRegIndex, *rhsRegIndex);
+ rewriter.create<amx::x86_amx_tdpbusd_plain>(op.getLoc(), *accRegIndex,
+ *lhsRegIndex, *rhsRegIndex);
else if (!zexta && zextb)
- rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbsud_plain>(
- op, *accRegIndex, *lhsRegIndex, *rhsRegIndex);
+ rewriter.create<amx::x86_amx_tdpbsud_plain>(op.getLoc(), *accRegIndex,
+ *lhsRegIndex, *rhsRegIndex);
else
- rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbssd_plain>(
- op, *accRegIndex, *lhsRegIndex, *rhsRegIndex);
+ rewriter.create<amx::x86_amx_tdpbssd_plain>(op.getLoc(), *accRegIndex,
+ *lhsRegIndex, *rhsRegIndex);
+ rewriter.eraseOp(op);
return success();
}
>From 4438ffcfc3f9584d80d23eae3baaa48b3da9f5d2 Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Mon, 5 Aug 2024 02:00:17 -0700
Subject: [PATCH 13/17] Get a simple case running through successfully
---
.../Dialect/AMX/Analysis/AMXBindingAnalysis.h | 8 ++--
mlir/include/mlir/Dialect/AMX/Passes.td | 2 +-
.../AMX/Analysis/AMXBindingAnalysis.cpp | 45 ++++++++++++-------
.../AMX/Transforms/EnableAMXTileBinding.cpp | 8 ++--
4 files changed, 37 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h b/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
index 4ddeff63d980a4..12c9c8678148ab 100644
--- a/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
+++ b/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
@@ -84,7 +84,7 @@ class TileScopeAnalysis {
bool isValidAnalysis;
// Storing parallel Ops that would break tile context & scope.
- DenseSet<Operation *> parallelOps;
+ DenseSet<Operation *> breakOps;
// Storing needed palette info for each concerned Op.
DenseMap<Operation *, PaletteInfo> neededPalette;
// Storing the usage scope for each concerned tile Op.
@@ -93,10 +93,8 @@ class TileScopeAnalysis {
// Storing final tile scope results for injecting tilecfg/tilerelease.
SmallVector<TileScope, 10> tileScopes;
- void addParallelOp(Operation *op) { parallelOps.insert(op); }
- bool isParallelOp(Operation *op) {
- return parallelOps.find(op) != parallelOps.end();
- }
+ void addBreakOp(Operation *op) { breakOps.insert(op); }
+ bool isBreakOp(Operation *op) { return breakOps.find(op) != breakOps.end(); }
void setTileUsage(Operation *op, BlockSeg seg) {
tileUsage.insert({op, std::move(seg)});
diff --git a/mlir/include/mlir/Dialect/AMX/Passes.td b/mlir/include/mlir/Dialect/AMX/Passes.td
index f3c7800b7d8730..518450da1be99f 100644
--- a/mlir/include/mlir/Dialect/AMX/Passes.td
+++ b/mlir/include/mlir/Dialect/AMX/Passes.td
@@ -19,7 +19,7 @@ def EnableAMXTileBinding
by propagating specified binding information
and automatically configuring harware context
}];
- let dependentDialects = ["func::FuncDialect"];
+ let dependentDialects = ["func::FuncDialect", "amx::AMXDialect", "LLVM::LLVMDialect"];
}
#endif // MLIR_DIALECT_AMX_PASSES_TD
diff --git a/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
index 26ba27ba389054..68e26db1b04770 100644
--- a/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
+++ b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
@@ -212,13 +212,14 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
isValidAnalysis = true;
// 0. First walk to mark parallel Ops.
func->walk<WalkOrder::PostOrder>([=](Operation *op) {
- if (!isParallelOp(op))
+ if (isBreakOp(op))
return;
- if (llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op)) {
+ if (llvm::isa<func::CallOp>(op) || llvm::isa<func::ReturnOp>(op) ||
+ llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op)) {
while (op != root) {
- addParallelOp(op);
- LLVM_DEBUG(llvm::dbgs() << ">>>>> add parallel op: " << op << "\n");
+ addBreakOp(op);
+ LLVM_DEBUG(llvm::dbgs() << ">>>>> add break op: " << *op << "\n");
op = op->getParentOp();
}
}
@@ -233,15 +234,25 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
return;
if (!isTileOp(op))
return;
- Operation *lastUser = nullptr;
- for (auto user : op->getUsers())
- lastUser = user;
- while (lastUser && op->getBlock() != lastUser->getBlock()) {
- lastUser = lastUser->getParentOp();
- if (!lastUser)
- isValidAnalysis = false;
+ if (llvm::isa<TileStoreOp>(op)) {
+ LLVM_DEBUG(llvm::dbgs() << ">>>>> setTileUsage: " << *op << " ;;; " << *op
+ << " ;;; " << *op << "\n");
+ setTileUsage(op, BlockSeg(Block::iterator(op), Block::iterator(op)));
+
+ } else {
+ Operation *lastUser = nullptr;
+ for (auto user : op->getUsers())
+ lastUser = user;
+ while (lastUser && op->getBlock() != lastUser->getBlock()) {
+ lastUser = lastUser->getParentOp();
+ if (!lastUser)
+ isValidAnalysis = false;
+ }
+ LLVM_DEBUG(llvm::dbgs() << ">>>>> setTileUsage: " << *op << " ;;; " << *op
+ << " ;;; " << *lastUser << "\n");
+ setTileUsage(op,
+ BlockSeg(Block::iterator(op), Block::iterator(lastUser)));
}
- setTileUsage(op, BlockSeg(Block::iterator(op), Block::iterator(lastUser)));
});
if (!isValidAnalysis)
return;
@@ -410,7 +421,7 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
auto currBegin = seg.begin();
for (auto probe = seg.begin(); probe != seg.end(); probe++) {
Operation *op = &(*probe);
- if (isParallelOp(op)) {
+ if (isBreakOp(op)) {
blockSegs.push_back(BlockSeg(currBegin, probe));
paraOps.push_back(op);
currBegin = probe;
@@ -448,19 +459,19 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
nextIterIfMerge = currIter;
nextIterIfMerge++;
} else if (isTileOp(currOp)) {
- auto iter = tileUsage.find(currOp);
- if (iter == tileUsage.end()) {
+ auto usageIter = tileUsage.find(currOp);
+ if (usageIter == tileUsage.end()) {
isValidAnalysis = false;
return;
}
- pi = getPalette(iter->second);
+ pi = getPalette(usageIter->second);
if (pi && pi->overflow) {
// This means the binding info in tile Ops exceeds the hardware
// capability.
isValidAnalysis = false;
return;
}
- nextIterIfMerge = iter->second.end();
+ nextIterIfMerge = usageIter->second.end();
nextIterIfMerge++;
}
if (!pi) {
diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
index 2a37751d387e91..7ef0e0e449891e 100644
--- a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -180,7 +180,8 @@ struct EnableAMXTileBindingPass
}
std::string paletteSymName = "g_intel_amx_palette_";
- uint8ArrayToHex(paletteSymName, paletteAsArray, paletteArraySize);
+ uint8ArrayToHex(paletteSymName, paletteAsArray, 2);
+ uint8ArrayToHex(paletteSymName, paletteAsArray + 16, 48);
ModuleOp module = getOperation()->template getParentOfType<ModuleOp>();
if (auto global = module.lookupSymbol<LLVM::GlobalOp>(paletteSymName))
@@ -239,8 +240,9 @@ struct EnableAMXTileBindingPass
// 3. Insert tile config/release according to tile scopes.
OpBuilder builder(getOperation());
for (auto &scope : scopeAna.getTileScopes()) {
- LLVM_DEBUG(llvm::dbgs() << ">>> Processing tile scope: "
- << *scope.seg.begin() << "\n");
+ LLVM_DEBUG(llvm::dbgs()
+ << ">>> Processing tile scope: " << *scope.seg.begin()
+ << " ;;; " << *scope.seg.end() << "\n");
assert(!scope.pi.overflow && "Expecting legal AMX palette info");
auto paletteGlobal = getOrCreateGlobalPalette(scope.pi);
assert(paletteGlobal && "Failed to create global palette");
>From 39d349fcb57e208fb46b8d012756b8695daafbdd Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Tue, 13 Aug 2024 00:24:14 -0700
Subject: [PATCH 14/17] fix conversion to LLVM IR
---
.../AMX/Transforms/LegalizeForLLVMExport.cpp | 34 +++++++++++++++----
1 file changed, 28 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index afff80640bddd4..fe9f8cebda4941 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -94,8 +94,14 @@ struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
// Routine for lowering tile Ops with binding info.
auto dstRegIndex = op.getDstRegIndex();
assert(dstRegIndex && "Incomplete operation attribute for tile binding");
- rewriter.create<amx::x86_amx_tilezero_plain>(op.getLoc(), *dstRegIndex);
- rewriter.eraseOp(op);
+
+ Location loc = op.getLoc();
+ Value dstIndex = rewriter.create<LLVM::ConstantOp>(
+ loc, IntegerType::get(rewriter.getContext(), 8), *dstRegIndex);
+
+ rewriter.create<amx::x86_amx_tilezero_plain>(loc, *dstRegIndex);
+ rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+ op, op.getRes().getType(), dstIndex);
return success();
}
@@ -141,9 +147,15 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
// Routine for lowering tile Ops with binding info.
auto dstRegIndex = op.getDstRegIndex();
assert(dstRegIndex && "Incomplete operation attribute for tile binding");
+
+ Location loc = op.getLoc();
+ Value dstIndex = rewriter.create<LLVM::ConstantOp>(
+ loc, IntegerType::get(rewriter.getContext(), 8), *dstRegIndex);
+
rewriter.create<amx::x86_amx_tileloadd64_plain>(op.getLoc(), *dstRegIndex,
ptr, stride);
- rewriter.eraseOp(op);
+ rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+ op, op.getRes().getType(), dstIndex);
return success();
}
@@ -231,9 +243,14 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
assert(lhsRegIndex && rhsRegIndex && accRegIndex &&
"Incomplete operation attribute for tile binding");
- rewriter.create<amx::x86_amx_tdpbf16ps_plain>(op.getLoc(), *accRegIndex,
+ Location loc = op.getLoc();
+ Value accIndex = rewriter.create<LLVM::ConstantOp>(
+ loc, IntegerType::get(rewriter.getContext(), 8), *accRegIndex);
+
+ rewriter.create<amx::x86_amx_tdpbf16ps_plain>(loc, *accRegIndex,
*lhsRegIndex, *rhsRegIndex);
- rewriter.eraseOp(op);
+ rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+ op, op.getRes().getType(), accIndex);
return success();
}
@@ -282,6 +299,10 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
assert(lhsRegIndex && rhsRegIndex && accRegIndex &&
"Incomplete operation attribute for tile binding");
+ Location loc = op.getLoc();
+ Value accIndex = rewriter.create<LLVM::ConstantOp>(
+ loc, IntegerType::get(rewriter.getContext(), 8), *accRegIndex);
+
if (zexta && zextb)
rewriter.create<amx::x86_amx_tdpbuud_plain>(op.getLoc(), *accRegIndex,
*lhsRegIndex, *rhsRegIndex);
@@ -294,7 +315,8 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
else
rewriter.create<amx::x86_amx_tdpbssd_plain>(op.getLoc(), *accRegIndex,
*lhsRegIndex, *rhsRegIndex);
- rewriter.eraseOp(op);
+ rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+ op, op.getRes().getType(), accIndex);
return success();
}
>From ae47af275a2bef42da27c622005ce44e19fbbc07 Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Fri, 16 Aug 2024 02:46:19 -0700
Subject: [PATCH 15/17] Store enablingAnalysis by value, not by reference, in AMX conversion patterns
---
mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index fe9f8cebda4941..d8cccbfcbae592 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -76,7 +76,7 @@ Value getStride(ConversionPatternRewriter &rewriter,
struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
private:
const std::optional<std::reference_wrapper<TileScopeAnalysis>>
- &enablingAnalysis;
+ enablingAnalysis;
public:
using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern;
@@ -120,7 +120,7 @@ struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
private:
const std::optional<std::reference_wrapper<TileScopeAnalysis>>
- &enablingAnalysis;
+ enablingAnalysis;
public:
using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern;
@@ -174,7 +174,7 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
private:
const std::optional<std::reference_wrapper<TileScopeAnalysis>>
- &enablingAnalysis;
+ enablingAnalysis;
public:
using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;
@@ -221,7 +221,7 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
private:
const std::optional<std::reference_wrapper<TileScopeAnalysis>>
- &enablingAnalysis;
+ enablingAnalysis;
public:
using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern;
>From 37b27986fc49f3dd19d271508304a5edc3e5210c Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Thu, 22 Aug 2024 02:36:32 -0700
Subject: [PATCH 16/17] remove dependency between enable-amx-binding and
lowering
---
mlir/include/mlir/Dialect/AMX/Transforms.h | 6 +-
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 3 +-
.../AMX/Transforms/EnableAMXTileBinding.cpp | 1 -
.../AMX/Transforms/LegalizeForLLVMExport.cpp | 131 ++++++------------
4 files changed, 49 insertions(+), 92 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index 3379dcc7b491b9..0ce96ede533cf9 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -21,10 +21,8 @@ class RewritePatternSet;
/// Collect a set of patterns to lower AMX ops to ops that map to LLVM
/// intrinsics.
-void populateAMXLegalizeForLLVMExportPatterns(
- LLVMTypeConverter &converter,
- std::optional<std::reference_wrapper<amx::TileScopeAnalysis>> &analysis,
- RewritePatternSet &patterns);
+void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns);
/// Configure the target to support lowering AMX ops to ops that map to LLVM
/// intrinsics.
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 0329a5fa4cb648..89cfca5d9aba37 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -106,9 +106,8 @@ void LowerVectorToLLVMPass::runOnOperation() {
populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
}
if (amx) {
- auto analysis = getCachedAnalysis<amx::TileScopeAnalysis>();
configureAMXLegalizeForExportTarget(target);
- populateAMXLegalizeForLLVMExportPatterns(converter, analysis, patterns);
+ populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
}
if (x86Vector) {
configureX86VectorLegalizeForExportTarget(target);
diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
index 7ef0e0e449891e..a271ec5dab57e1 100644
--- a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -260,7 +260,6 @@ struct EnableAMXTileBindingPass
builder.setInsertionPointAfter(end);
builder.create<amx::x86_amx_tilerelease_plain>(loc);
}
- markAnalysesPreserved<TileScopeAnalysis>();
}
};
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index d8cccbfcbae592..e4c04b56b1ceda 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -74,61 +74,46 @@ Value getStride(ConversionPatternRewriter &rewriter,
}
struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
-private:
- const std::optional<std::reference_wrapper<TileScopeAnalysis>>
- enablingAnalysis;
-
public:
using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern;
- TileZeroConversion(
- const LLVMTypeConverter &typeConverter,
- const std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis)
- : ConvertOpToLLVMPattern<TileZeroOp>(typeConverter),
- enablingAnalysis(analysis) {}
+ TileZeroConversion(const LLVMTypeConverter &typeConverter)
+ : ConvertOpToLLVMPattern<TileZeroOp>(typeConverter) {}
LogicalResult
matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (enablingAnalysis && enablingAnalysis->get().isValid()) {
- rewriter.setInsertionPoint(op);
- // Routine for lowering tile Ops with binding info.
auto dstRegIndex = op.getDstRegIndex();
- assert(dstRegIndex && "Incomplete operation attribute for tile binding");
-
- Location loc = op.getLoc();
- Value dstIndex = rewriter.create<LLVM::ConstantOp>(
- loc, IntegerType::get(rewriter.getContext(), 8), *dstRegIndex);
-
- rewriter.create<amx::x86_amx_tilezero_plain>(loc, *dstRegIndex);
- rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
- op, op.getRes().getType(), dstIndex);
+ if (dstRegIndex) {
+ // Routine for lowering tile Ops with binding info.
+ rewriter.setInsertionPoint(op);
+
+ Location loc = op.getLoc();
+ Value dstIndex = rewriter.create<LLVM::ConstantOp>(
+ loc, IntegerType::get(rewriter.getContext(), 8), *dstRegIndex);
+
+ rewriter.create<amx::x86_amx_tilezero_plain>(loc, *dstRegIndex);
+ rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+ op, op.getRes().getType(), dstIndex);
+ return success();
+ }
+
+ VectorType vType = op.getVectorType();
+ // Determine m x n tile sizes.
+ std::pair<Value, Value> tsz =
+ getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+ // Replace operation with intrinsic.
+ Type resType = typeConverter->convertType(vType);
+ rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
+ tsz.second);
return success();
- }
-
- VectorType vType = op.getVectorType();
- // Determine m x n tile sizes.
- std::pair<Value, Value> tsz =
- getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
- // Replace operation with intrinsic.
- Type resType = typeConverter->convertType(vType);
- rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
- tsz.second);
- return success();
}
};
struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
-private:
- const std::optional<std::reference_wrapper<TileScopeAnalysis>>
- enablingAnalysis;
-
public:
using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern;
- TileLoadConversion(
- const LLVMTypeConverter &typeConverter,
- const std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis)
- : ConvertOpToLLVMPattern<TileLoadOp>(typeConverter),
- enablingAnalysis(analysis) {}
+ TileLoadConversion(const LLVMTypeConverter &typeConverter)
+ : ConvertOpToLLVMPattern<TileLoadOp>(typeConverter) {}
LogicalResult
matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
@@ -142,11 +127,10 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
- if (enablingAnalysis && enablingAnalysis->get().isValid()) {
- rewriter.setInsertionPoint(op);
+ auto dstRegIndex = op.getDstRegIndex();
+ if (dstRegIndex) {
// Routine for lowering tile Ops with binding info.
- auto dstRegIndex = op.getDstRegIndex();
- assert(dstRegIndex && "Incomplete operation attribute for tile binding");
+ rewriter.setInsertionPoint(op);
Location loc = op.getLoc();
Value dstIndex = rewriter.create<LLVM::ConstantOp>(
@@ -172,17 +156,10 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
};
struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
-private:
- const std::optional<std::reference_wrapper<TileScopeAnalysis>>
- enablingAnalysis;
-
public:
using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;
- TileStoreConversion(
- const LLVMTypeConverter &typeConverter,
- const std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis)
- : ConvertOpToLLVMPattern<TileStoreOp>(typeConverter),
- enablingAnalysis(analysis) {}
+ TileStoreConversion(const LLVMTypeConverter &typeConverter)
+ : ConvertOpToLLVMPattern<TileStoreOp>(typeConverter) {}
LogicalResult
matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
@@ -196,11 +173,10 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
- if (enablingAnalysis && enablingAnalysis->get().isValid()) {
- rewriter.setInsertionPoint(op);
+ auto srcRegIndex = op.getSrcRegIndex();
+ if (srcRegIndex) {
// Routine for lowering tile Ops with binding info.
- auto srcRegIndex = op.getSrcRegIndex();
- assert(srcRegIndex && "Incomplete operation attribute for tile binding");
+ rewriter.setInsertionPoint(op);
rewriter.create<amx::x86_amx_tilestored64_plain>(
op.getLoc(), *srcRegIndex, ptr, stride);
rewriter.eraseOp(op);
@@ -219,29 +195,22 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
};
struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
-private:
- const std::optional<std::reference_wrapper<TileScopeAnalysis>>
- enablingAnalysis;
-
public:
using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern;
- TileMulFConversion(
- const LLVMTypeConverter &typeConverter,
- const std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis)
- : ConvertOpToLLVMPattern<TileMulFOp>(typeConverter),
- enablingAnalysis(analysis) {}
+ TileMulFConversion(const LLVMTypeConverter &typeConverter)
+ : ConvertOpToLLVMPattern<TileMulFOp>(typeConverter) {}
LogicalResult
matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (enablingAnalysis && enablingAnalysis->get().isValid()) {
- rewriter.setInsertionPoint(op);
+ auto accRegIndex = op.getAccRegIndex();
+ if (accRegIndex) {
// Routine for lowering tile Ops with binding info.
+ rewriter.setInsertionPoint(op);
auto lhsRegIndex = op.getLhsRegIndex();
auto rhsRegIndex = op.getRhsRegIndex();
- auto accRegIndex = op.getAccRegIndex();
- assert(lhsRegIndex && rhsRegIndex && accRegIndex &&
+ assert(lhsRegIndex && rhsRegIndex &&
"Incomplete operation attribute for tile binding");
Location loc = op.getLoc();
Value accIndex = rewriter.create<LLVM::ConstantOp>(
@@ -272,17 +241,10 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
};
struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
-private:
- const std::optional<std::reference_wrapper<TileScopeAnalysis>>
- &enablingAnalysis;
-
public:
using ConvertOpToLLVMPattern<TileMulIOp>::ConvertOpToLLVMPattern;
- TileMulIConversion(
- const LLVMTypeConverter &typeConverter,
- const std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis)
- : ConvertOpToLLVMPattern<TileMulIOp>(typeConverter),
- enablingAnalysis(analysis) {}
+ TileMulIConversion(const LLVMTypeConverter &typeConverter)
+ : ConvertOpToLLVMPattern<TileMulIOp>(typeConverter) {}
LogicalResult
matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
@@ -290,14 +252,14 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
bool zexta = op.getIsZextLhs();
bool zextb = op.getIsZextRhs();
- if (enablingAnalysis && enablingAnalysis->get().isValid()) {
+ auto accRegIndex = op.getAccRegIndex();
+ if (accRegIndex) {
rewriter.setInsertionPoint(op);
// Routine for lowering tile Ops with binding info.
auto lhsRegIndex = op.getLhsRegIndex();
auto rhsRegIndex = op.getRhsRegIndex();
- auto accRegIndex = op.getAccRegIndex();
- assert(lhsRegIndex && rhsRegIndex && accRegIndex &&
+ assert(lhsRegIndex && rhsRegIndex &&
"Incomplete operation attribute for tile binding");
Location loc = op.getLoc();
Value accIndex = rewriter.create<LLVM::ConstantOp>(
@@ -354,10 +316,9 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
void mlir::populateAMXLegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter,
- std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis,
RewritePatternSet &patterns) {
patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
- TileMulFConversion, TileMulIConversion>(converter, analysis);
+ TileMulFConversion, TileMulIConversion>(converter);
}
void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
>From d910874aca188488dd0b2a7770c23c963f580cf4 Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Tue, 3 Sep 2024 00:30:35 -0700
Subject: [PATCH 17/17] tile processing macro minor fix
---
mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
index 68e26db1b04770..06154b05c0e414 100644
--- a/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
+++ b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
@@ -354,9 +354,9 @@ TileScopeAnalysis::collectPaletteForTile(Operation *op) {
pi.set(*index, getPaletteShape(op.getVectorType()));
#define PROCESS_TRINARY_TILE_OP(op) \
- auto lhsIndex = tileMulFOp.getLhsRegIndex(); \
- auto rhsIndex = tileMulFOp.getRhsRegIndex(); \
- auto accIndex = tileMulFOp.getAccRegIndex(); \
+ auto lhsIndex = op.getLhsRegIndex(); \
+ auto rhsIndex = op.getRhsRegIndex(); \
+ auto accIndex = op.getAccRegIndex(); \
if (!lhsIndex || !rhsIndex || !accIndex) { \
isValidAnalysis = false; \
return PaletteInfo(); \
More information about the Mlir-commits
mailing list