[Mlir-commits] [mlir] Introduce mlir `AMX` dialect extension (PR #107528)

Haixin Huang llvmlistbot at llvm.org
Fri Sep 6 00:05:21 PDT 2024


https://github.com/huanghaixin008 created https://github.com/llvm/llvm-project/pull/107528

None

>From 454884ba8df56714bf45b122feab3ce180b27707 Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Fri, 14 Jun 2024 01:57:18 -0700
Subject: [PATCH 01/17] add AMX Ops extension for tmm register binding

---
 mlir/include/mlir/Dialect/AMX/AMX.td          | 139 ++++++++++++++++--
 .../AMX/Transforms/LegalizeForLLVMExport.cpp  |   7 +-
 2 files changed, 132 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index fcc8d169eab5ac..1eeb5c074ca2e6 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -64,16 +64,36 @@ def AMX_Dialect : Dialect {
 class AMX_Op<string mnemonic, list<Trait> traits = []> :
   Op<AMX_Dialect, mnemonic, traits> {}
 
+class AMX_IntrOpBase<string mnemonic, int numResults, 
+                    list<int> immArgPositions = [],
+                    list<string> immArgAttrNames = [],
+                    list<Trait> traits = []>
+    : LLVM_IntrOpBase<
+          /*Dialect dialect=*/AMX_Dialect,
+          /*string opName=*/mnemonic,
+          /*string enumName=*/"x86_" # !subst(".", "_", mnemonic),
+          /*list<int> overloadedResults=*/[],
+          /*list<int> overloadedOperands=*/[],
+          /*list<Trait> traits=*/traits,
+          /*int numResults=*/numResults,
+          /*bit requiresAccessGroup=*/0,
+          /*bit requiresAliasAnalysis=*/0,
+          /*bit requiresFastmath=*/0,
+          /*list<int> immArgPositions=*/immArgPositions,
+          /*list<string> immArgAttrNames=*/immArgAttrNames>;
+
 // The "internal" intrinsics are meant for compiler usage.
 class AMX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
-  LLVM_IntrOpBase<AMX_Dialect, mnemonic,
-                  "x86_" # !subst(".", "_", mnemonic) # "_internal",
-                  [], [], traits, numResults>;
+  AMX_IntrOp<mnemonic # "_internal", numResults, [], [], traits>;
 
 //===----------------------------------------------------------------------===//
 // AMX Op definitions (user facing).
 //===----------------------------------------------------------------------===//
 
+
+def TileRegisterIndexAttr : OptionalAttr<
+	ConfinedAttr<SI8Attr, [IntMinValue<0>, IntMaxValue<7>]>>;
+
 //
 // Tile reset.
 //
@@ -91,6 +111,7 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
       %0 = amx.tile_zero : vector<16x16xbf16>
     ```
   }];
+  let arguments = (ins TileRegisterIndexAttr:$dstRegIndex);
   let results = (outs
     VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
   let extraClassDeclaration = [{
@@ -98,7 +119,7 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
       return ::llvm::cast<VectorType>(getRes().getType());
     }
   }];
-  let assemblyFormat = "attr-dict `:` type($res)";
+  let assemblyFormat = "(`[` $dstRegIndex^ `]`)? attr-dict `:` type($res)";
   let hasVerifier = 1;
 }
 
@@ -121,7 +142,7 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
     ```
   }];
   let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
-                   Variadic<Index>:$indices);
+                   Variadic<Index>:$indices, TileRegisterIndexAttr:$srcRegIndex);
   let results = (outs
     VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
   let extraClassDeclaration = [{
@@ -132,7 +153,7 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
       return ::llvm::cast<VectorType>(getRes().getType());
     }
   }];
-  let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
+  let assemblyFormat = "$base `[` $indices `]` attr-dict (`into` `[` $srcRegIndex^ `]`)? `:` "
                        "type($base) `into` type($res)";
   let hasVerifier = 1;
 }
@@ -153,7 +174,8 @@ def TileStoreOp : AMX_Op<"tile_store"> {
   }];
   let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
                    Variadic<Index>:$indices,
-                   VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val);
+                   VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val,
+                   TileRegisterIndexAttr:$dstRegIndex);
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return ::llvm::cast<MemRefType>(getBase().getType());
@@ -162,7 +184,7 @@ def TileStoreOp : AMX_Op<"tile_store"> {
       return ::llvm::cast<VectorType>(getVal().getType());
     }
   }];
-  let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
+  let assemblyFormat = "$base `[` $indices `]` `,` $val (`[` $dstRegIndex^ `]`)? attr-dict `:` "
                        "type($base) `,` type($val)";
   let hasVerifier = 1;
 }
@@ -189,7 +211,10 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
   }];
   let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]>:$lhs,
                        VectorOfRankAndType<[2], [F32, BF16]>:$rhs,
-                       VectorOfRankAndType<[2], [F32, BF16]>:$acc);
+                       VectorOfRankAndType<[2], [F32, BF16]>:$acc, 
+                       TileRegisterIndexAttr:$lhsRegIndex,
+                       TileRegisterIndexAttr:$rhsRegIndex,
+                       TileRegisterIndexAttr:$accRegIndex);
   let results = (outs VectorOfRankAndType<[2], [F32, BF16]>:$res);
   let extraClassDeclaration = [{
     VectorType getLhsVectorType() {
@@ -202,7 +227,9 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
       return ::llvm::cast<VectorType>(getRes().getType());
     }
   }];
-  let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
+  let assemblyFormat = "$lhs (`[` $lhsRegIndex^ `]`)? `,` "
+                       "$rhs (`[` $rhsRegIndex^ `]`)? `,` "
+                       "$acc (`[` $accRegIndex^ `]`)? attr-dict `:` "
                        "type($lhs) `,` type($rhs) `,` type($acc) ";
   let hasVerifier = 1;
 }
@@ -230,7 +257,10 @@ def TileMulIOp : AMX_Op<"tile_muli", [
                        VectorOfRankAndType<[2], [I32, I8]>:$rhs,
                        VectorOfRankAndType<[2], [I32, I8]>:$acc,
                        UnitAttr:$isZextLhs,
-                       UnitAttr:$isZextRhs
+                       UnitAttr:$isZextRhs,
+                       TileRegisterIndexAttr:$lhsRegIndex,
+                       TileRegisterIndexAttr:$rhsRegIndex,
+                       TileRegisterIndexAttr:$accRegIndex
                        );
   let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
   let extraClassDeclaration = [{
@@ -244,13 +274,15 @@ def TileMulIOp : AMX_Op<"tile_muli", [
       return ::llvm::cast<VectorType>(getRes().getType());
     }
   }];
-  let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
+  let assemblyFormat = "$lhs (`zext` $isZextLhs^)? (`[` $lhsRegIndex^ `]`)? `,` "
+                       "$rhs (`zext` $isZextRhs^)? (`[` $rhsRegIndex^ `]`)? `,` "
+                       "$acc (`[` $accRegIndex^ `]`)? attr-dict `:` "
                        "type($lhs) `,` type($rhs) `,` type($acc) ";
   let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
-// AMX IntrOp definitions (LLVM compiler facing).
+// AMX Internal IntrOp definitions (LLVM compiler facing).
 //===----------------------------------------------------------------------===//
 
 //
@@ -310,4 +342,85 @@ def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>,
                  AnyInteger,
 		 AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
 
+//===----------------------------------------------------------------------===//
+// AMX IntrOp definitions (direct mapping to physical assembly).
+// The `dm` suffixes in Op names stand for `direct mapping`.
+//===----------------------------------------------------------------------===//
+
+//
+// Tile palette config operations. Parameters define the pointer to the palette membuf.
+//
+
+def LLVM_x86_amx_ldtilecfg_plain : AMX_IntrOpBase<"ldtilecfg", 0>,
+  Arguments<(int LLVM_AnyPointer)>;
+
+def LLVM_x86_amx_tilerelease_plain : AMX_IntrOpBase<"tilerelease", 0>;
+
+//
+// Tile reset. Parameters define the tile size.
+//
+
+def LLVM_x86_amx_tilezero_plain : AMX_IntrOpBase<"tilezero", 0, [0], ["res_index"]>,
+  Arguments<(ins Arg<I8Attr, "Index of resulting tmm registers">:$res_index)>;
+
+//
+// Tile memory operations. Parameters define the tile size,
+// base address, and stride between consecutive rows for the
+// memory operation.
+//
+
+def LLVM_x86_amx_tileloadd64_plain : AMX_IntrOpBase<"tileloadd64", 0, [0], ["dst_index"]>,
+  Arguments<(ins Arg<I8Attr, "Index of dst tmm registers">:$dst_index,
+                 LLVM_AnyPointer, AnyInteger)>;
+
+// Non-temporal load version
+def LLVM_x86_amx_tileloaddt164_plain : AMX_IntrOpBase<"tileloaddt164", 0, [0], ["dst_index"]>,
+  Arguments<(ins Arg<I8Attr, "Index of dst tmm registers">:$dst_index,
+                 LLVM_AnyPointer, AnyInteger)>;
+
+def LLVM_x86_amx_tilestored64_plain : AMX_IntrOpBase<"tilestored64", 0, [0], ["src_index"]>,
+  Arguments<(ins Arg<I8Attr, "Index of src tmm registers">:$src_index,
+                 LLVM_AnyPointer, AnyInteger)>;
+
+//
+// Tile multiplication operations (series of dot products). Parameters
+// define the tile sizes and source and destination tiles for the
+// operation. Note that the prefix "tdp" stands for tile dot product.
+//
+
+// Dot product of bf16 tiles into f32 tile.
+def LLVM_x86_amx_tdpbf16ps_plain : AMX_IntrOpBase<"tdpbf16ps", 0, 
+		[0, 1, 2], ["dst_index", "lhs_index", "rhs_index"]>,
+  Arguments<(ins Arg<I8Attr, "Index of dst tmm registers">:$dst_index,
+                 Arg<I8Attr, "Index of lhs tmm registers">:$lhs_index,
+		 Arg<I8Attr, "Index of rhs tmm registers">:$rhs_index)>;
+
+// Dot product of i8 tiles into i32 tile (with sign/sign extension).
+def LLVM_x86_amx_tdpbssd_plain : AMX_IntrOpBase<"tdpbssd", 0,
+		[0, 1, 2], ["dst_index", "lhs_index", "rhs_index"]>,
+  Arguments<(ins Arg<I8Attr, "Index of dst tmm registers">:$dst_index,
+                 Arg<I8Attr, "Index of lhs tmm registers">:$lhs_index,
+		 Arg<I8Attr, "Index of rhs tmm registers">:$rhs_index)>;
+
+// Dot product of i8 tiles into i32 tile (with sign/zero extension).
+def LLVM_x86_amx_tdpbsud_plain : AMX_IntrOpBase<"tdpbsud", 1,
+		[0, 1, 2], ["dst_index", "lhs_index", "rhs_index"]>,
+  Arguments<(ins Arg<I8Attr, "Index of dst tmm registers">:$dst_index,
+                 Arg<I8Attr, "Index of lhs tmm registers">:$lhs_index,
+		 Arg<I8Attr, "Index of rhs tmm registers">:$rhs_index)>;
+
+// Dot product of i8 tiles into i32 tile (with zero/sign extension).
+def LLVM_x86_amx_tdpbusd_plain : AMX_IntrOpBase<"tdpbusd", 1,
+		[0, 1, 2], ["dst_index", "lhs_index", "rhs_index"]>,
+  Arguments<(ins Arg<I8Attr, "Index of dst tmm registers">:$dst_index,
+                 Arg<I8Attr, "Index of lhs tmm registers">:$lhs_index,
+		 Arg<I8Attr, "Index of rhs tmm registers">:$rhs_index)>;
+
+// Dot product of i8 tiles into i32 tile (with zero/zero extension).
+def LLVM_x86_amx_tdpbuud_plain : AMX_IntrOpBase<"tdpbuud", 1,
+		[0, 1, 2], ["dst_index", "lhs_index", "rhs_index"]>,
+  Arguments<(ins Arg<I8Attr, "Index of dst tmm registers">:$dst_index,
+                 Arg<I8Attr, "Index of lhs tmm registers">:$lhs_index,
+		 Arg<I8Attr, "Index of rhs tmm registers">:$rhs_index)>;
+
 #endif // AMX
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index a8b10f63315d41..bfa3bbc83c80a5 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -211,7 +211,12 @@ void mlir::populateAMXLegalizeForLLVMExportPatterns(
 void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
   target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
                     x86_amx_tdpbf16ps, x86_amx_tdpbssd, x86_amx_tdpbsud,
-                    x86_amx_tdpbusd, x86_amx_tdpbuud>();
+                    x86_amx_tdpbusd, x86_amx_tdpbuud, x86_amx_ldtilecfg_plain,
+                    x86_amx_tilerelease_plain, x86_amx_tilezero_plain,
+                    x86_amx_tileloadd64_plain, x86_amx_tilestored64_plain,
+                    x86_amx_tilestoreddt164_plain x86_amx_tdpbf16ps_plain,
+                    x86_amx_tdpbssd_plain, x86_amx_tdpbsud_plain,
+                    x86_amx_tdpbusd_plain, x86_amx_tdpbuud_plain>();
   target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
                       TileMulFOp>();
 }

>From 9b0e23a8c3e4b64d9096cf3c851e46d8b4c72575 Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Sun, 16 Jun 2024 22:57:12 -0700
Subject: [PATCH 02/17] add basic framework

---
 mlir/include/mlir/Dialect/AMX/AMX.td          |  6 +-
 mlir/include/mlir/Dialect/AMX/CMakeLists.txt  |  7 ++
 mlir/include/mlir/Dialect/AMX/Passes.h        | 37 ++++++++++
 mlir/include/mlir/Dialect/AMX/Passes.td       | 25 +++++++
 .../lib/Dialect/AMX/Transforms/CMakeLists.txt |  1 +
 .../AMX/Transforms/EnableAMXTileBinding.cpp   | 68 +++++++++++++++++++
 .../AMX/Transforms/LegalizeForLLVMExport.cpp  |  4 +-
 7 files changed, 143 insertions(+), 5 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/AMX/Passes.h
 create mode 100644 mlir/include/mlir/Dialect/AMX/Passes.td
 create mode 100644 mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp

diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index 1eeb5c074ca2e6..f5f242d687a0ab 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -84,7 +84,7 @@ class AMX_IntrOpBase<string mnemonic, int numResults,
 
 // The "internal" intrinsics are meant for compiler usage.
 class AMX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
-  AMX_IntrOp<mnemonic # "_internal", numResults, [], [], traits>;
+  AMX_IntrOpBase<mnemonic # "_internal", numResults, [], [], traits>;
 
 //===----------------------------------------------------------------------===//
 // AMX Op definitions (user facing).
@@ -344,7 +344,7 @@ def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>,
 
 //===----------------------------------------------------------------------===//
 // AMX IntrOp definitions (direct mapping to physical assembly).
-// The `dm` suffixes in Op names stand for `direct mapping`.
+// The `plain` suffixes in Op names stand for plain direct mapping.
 //===----------------------------------------------------------------------===//
 
 //
@@ -352,7 +352,7 @@ def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>,
 //
 
 def LLVM_x86_amx_ldtilecfg_plain : AMX_IntrOpBase<"ldtilecfg", 0>,
-  Arguments<(int LLVM_AnyPointer)>;
+  Arguments<(ins LLVM_AnyPointer)>;
 
 def LLVM_x86_amx_tilerelease_plain : AMX_IntrOpBase<"tilerelease", 0>;
 
diff --git a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
index f3f1aff5a63609..d7a494caa22192 100644
--- a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt
@@ -4,3 +4,10 @@ add_mlir_doc(AMX AMX Dialects/ -gen-dialect-doc -dialect=amx)
 set(LLVM_TARGET_DEFINITIONS AMX.td)
 mlir_tablegen(AMXConversions.inc -gen-llvmir-conversions)
 add_public_tablegen_target(MLIRAMXConversionsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name AMX)
+add_public_tablegen_target(MLIRAMXTransformsIncGen)
+add_dependencies(mlir-headers MLIRAMXTransformsIncGen)
+
+add_mlir_doc(Passes AMXPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/AMX/Passes.h b/mlir/include/mlir/Dialect/AMX/Passes.h
new file mode 100644
index 00000000000000..9a3fd458eab380
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/Passes.h
@@ -0,0 +1,37 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AMX_PASSES_H
+#define MLIR_DIALECT_AMX_PASSES_H
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+class RewritePatternSet;
+
+namespace amx {
+//===----------------------------------------------------------------------===//
+// The EnableAMXTileBinding pass.
+//===----------------------------------------------------------------------===//
+#define GEN_PASS_DECL
+#include "mlir/Dialect/AMX/Passes.h.inc"
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/AMX/Passes.h.inc"
+
+} // namespace amx
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AMX_PASSES_H
diff --git a/mlir/include/mlir/Dialect/AMX/Passes.td b/mlir/include/mlir/Dialect/AMX/Passes.td
new file mode 100644
index 00000000000000..f3c7800b7d8730
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/Passes.td
@@ -0,0 +1,25 @@
+//===-- Passes.td - AMX pass definition file ------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AMX_PASSES_TD
+#define MLIR_DIALECT_AMX_PASSES_TD
+
+include "mlir/Pass/PassBase.td"
+
+def EnableAMXTileBinding
+    : Pass<"enable-amx-tile-binding", "mlir::func::FuncOp"> {
+  let summary = "Enable AMX tile register binding";
+  let description = [{
+    Enables the AMX tile register binding for each func 
+    by propagating specified binding information 
+    and automatically configuring hardware context
+  }];
+  let dependentDialects = ["func::FuncDialect"];
+}
+
+#endif // MLIR_DIALECT_AMX_PASSES_TD
diff --git a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
index 29340d4f45dd1f..a0c1c10f3c1701 100644
--- a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRAMXTransforms
   LegalizeForLLVMExport.cpp
+  EnableAMXTileBinding.cpp
 
   DEPENDS
   MLIRAMXConversionsIncGen
diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
new file mode 100644
index 00000000000000..74e25ca4550422
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -0,0 +1,68 @@
+//===- EnableAMXTileBinding.cpp - Enable tile binding for Intel AMX -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass enables the tile register binding semantic for Intel® Advanced
+// Matrix Extensions (Intel® AMX). Intuitively, this pass analyses the tile
+// binding hints set by users, legalize the hints and automatically configures
+// needed hardware context. The AMX tile register usage in lowered intrinsics
+// would strictly respect the given hints, enforced in lowering pass
+// `--convert-vector-to-llvm`.
+//
+// Note that if this pass is not invoked prior to `--convert-vector-to-llvm`,
+// the AMX lowering would ignore the binding info and fallback to original
+// scheme.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/AMX/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+
+#define DEBUG_TYPE "enable-amx-tile-binding"
+
+namespace mlir {
+namespace amx {
+
+#define GEN_PASS_DEF_ENABLEAMXTILEBINDING
+#include "mlir/Dialect/AMX/Passes.h.inc"
+
+//===----------------------------------------------------------------------===//
+// Analysis
+//===----------------------------------------------------------------------===//
+
+/// A class for analyzing tile register binding for each tile vector.
+class TileBindingAnalysis {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileBindingAnalysis)
+  explicit TileBindingAnalysis(Operation *);
+};
+
+TileBindingAnalysis::TileBindingAnalysis(Operation *root) {}
+
+//===----------------------------------------------------------------------===//
+// Pass
+//===----------------------------------------------------------------------===//
+
+struct EnableAMXTileBindingPass
+    : public impl::EnableAMXTileBindingBase<EnableAMXTileBindingPass> {
+  void runOnOperation() override {
+    // 0. Get AnalyseInfo for each concerned Value (mixed used of tmul & normal
+    // vector operations?)
+    TileBindingAnalysis &analysis = getAnalysis<TileBindingAnalysis>();
+
+    // 1. Propagate binding info to AMX Ops
+    //
+    // 2. Analyse tile scopes & expand them maximally
+    //
+    // 3. insert tile config/release according to tile scopes
+  }
+};
+
+} // namespace amx
+} // namespace mlir
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index bfa3bbc83c80a5..5e2aa216cf4123 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -213,8 +213,8 @@ void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
                     x86_amx_tdpbf16ps, x86_amx_tdpbssd, x86_amx_tdpbsud,
                     x86_amx_tdpbusd, x86_amx_tdpbuud, x86_amx_ldtilecfg_plain,
                     x86_amx_tilerelease_plain, x86_amx_tilezero_plain,
-                    x86_amx_tileloadd64_plain, x86_amx_tilestored64_plain,
-                    x86_amx_tilestoreddt164_plain x86_amx_tdpbf16ps_plain,
+                    x86_amx_tileloadd64_plain, x86_amx_tileloaddt164_plain,
+                    x86_amx_tilestored64_plain, x86_amx_tdpbf16ps_plain,
                     x86_amx_tdpbssd_plain, x86_amx_tdpbsud_plain,
                     x86_amx_tdpbusd_plain, x86_amx_tdpbuud_plain>();
   target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,

>From 4554343bae6571d91ecd67c34c780714a8623f3a Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Mon, 24 Jun 2024 02:09:23 -0700
Subject: [PATCH 03/17] add binding info propagation

---
 mlir/include/mlir/Dialect/AMX/AMX.td          |   8 +-
 .../AMX/Transforms/EnableAMXTileBinding.cpp   | 244 +++++++++++++++++-
 2 files changed, 242 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index f5f242d687a0ab..0c77b8b6ad90fc 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -142,7 +142,7 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
     ```
   }];
   let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
-                   Variadic<Index>:$indices, TileRegisterIndexAttr:$srcRegIndex);
+                   Variadic<Index>:$indices, TileRegisterIndexAttr:$dstRegIndex);
   let results = (outs
     VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
   let extraClassDeclaration = [{
@@ -153,7 +153,7 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
       return ::llvm::cast<VectorType>(getRes().getType());
     }
   }];
-  let assemblyFormat = "$base `[` $indices `]` attr-dict (`into` `[` $srcRegIndex^ `]`)? `:` "
+  let assemblyFormat = "$base `[` $indices `]` attr-dict (`into` `[` $dstRegIndex^ `]`)? `:` "
                        "type($base) `into` type($res)";
   let hasVerifier = 1;
 }
@@ -175,7 +175,7 @@ def TileStoreOp : AMX_Op<"tile_store"> {
   let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
                    Variadic<Index>:$indices,
                    VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val,
-                   TileRegisterIndexAttr:$dstRegIndex);
+                   TileRegisterIndexAttr:$srcRegIndex);
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return ::llvm::cast<MemRefType>(getBase().getType());
@@ -184,7 +184,7 @@ def TileStoreOp : AMX_Op<"tile_store"> {
       return ::llvm::cast<VectorType>(getVal().getType());
     }
   }];
-  let assemblyFormat = "$base `[` $indices `]` `,` $val (`[` $dstRegIndex^ `]`)? attr-dict `:` "
+  let assemblyFormat = "$base `[` $indices `]` `,` $val (`[` $srcRegIndex^ `]`)? attr-dict `:` "
                        "type($base) `,` type($val)";
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
index 74e25ca4550422..0308002c27c908 100644
--- a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -21,6 +21,10 @@
 
 #include "mlir/Dialect/AMX/AMXDialect.h"
 #include "mlir/Dialect/AMX/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 
@@ -36,28 +40,256 @@ namespace amx {
 // Analysis
 //===----------------------------------------------------------------------===//
 
-/// A class for analyzing tile register binding for each tile vector.
+/// A class for analyzing (propagating) tile register binding for each tile
+/// vector.
 class TileBindingAnalysis {
+private:
+  bool isValidAnalysis;
+  DenseMap<Value, int> bindings;
+
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileBindingAnalysis)
   explicit TileBindingAnalysis(Operation *);
+  bool isValid() const { return isValidAnalysis; }
+  void setValid(bool v) { isValidAnalysis = v; }
+  int getBinding(Value val) const {
+    auto iter = bindings.find(val);
+    if (iter == bindings.end())
+      return -1;
+    return iter->second;
+  }
+  void setBinding(Value val, int index) { bindings[val] = index; }
 };
 
-TileBindingAnalysis::TileBindingAnalysis(Operation *root) {}
+static bool isTileOp(Operation *op) {
+  return llvm::isa<TileZeroOp>(op) || llvm::isa<TileLoadOp>(op) ||
+         llvm::isa<TileMulFOp>(op) || llvm::isa<TileMulIOp>(op) ||
+         llvm::isa<TileStoreOp>(op);
+}
+
+template <typename Op>
+static bool TileMulCheck(Operation *op) {
+  auto tile_mul = dyn_cast_or_null<Op>(op);
+  assert(tile_mul);
+
+  auto lhsOp = tile_mul.getLhs().getDefiningOp();
+  auto rhsOp = tile_mul.getRhs().getDefiningOp();
+  auto accOp = tile_mul.getAcc().getDefiningOp();
+  if (!isTileOp(lhsOp) || !isTileOp(rhsOp) || !isTileOp(accOp))
+    return false;
+  return true;
+}
+
+// Not allow mixed use of tile Ops and normal vector Ops, any mixing is
+// considered unacceptable
+static bool isAcceptableTileOp(Operation *op) {
+  if (!isTileOp(op))
+    return false;
+
+  if (llvm::isa<TileMulFOp>(op)) {
+    return TileMulCheck<TileMulFOp>(op);
+  } else if (llvm::isa<TileMulIOp>(op)) {
+    return TileMulCheck<TileMulIOp>(op);
+  } else if (auto tileStore = dyn_cast_or_null<TileStoreOp>(op)) {
+    auto valOp = tileStore.getVal().getDefiningOp();
+    if (!isTileOp(valOp))
+      return false;
+  }
+  return true;
+}
+
+template <typename Op>
+static bool TileDstPropagate(TileBindingAnalysis *analysis, Operation *op) {
+  auto tileDst = dyn_cast_or_null<Op>(op);
+  assert(tileDst);
+  std::optional<int8_t> tmmIndex = tileDst.getDstRegIndex();
+  if (!tmmIndex) {
+    return false;
+  }
+  analysis->setBinding(tileDst.getRes(), *tmmIndex);
+  return true;
+}
+
+template <typename Op>
+static bool TileMulPropagate(TileBindingAnalysis *analysis, Operation *op) {
+  auto tileMul = dyn_cast_or_null<Op>(op);
+  assert(tileMul);
+  auto accVal = tileMul.getAcc();
+  auto accIndex = analysis->getBinding(accVal);
+  if (accIndex < 0)
+    return false;
+
+  analysis->setBinding(tileMul.getRes(), accIndex);
+  return true;
+}
+
+TileBindingAnalysis::TileBindingAnalysis(Operation *root) {
+  isValidAnalysis = false;
+  func::FuncOp func = dyn_cast_or_null<func::FuncOp>(root);
+  if (!func)
+    return;
+
+  isValidAnalysis = true;
+  func->walk<WalkOrder::PreOrder>([this](Operation *op) {
+    if (!isValidAnalysis)
+      return;
+    if (!isTileOp(op))
+      return;
+    if (!isAcceptableTileOp(op)) {
+      isValidAnalysis = false;
+      return;
+    }
+
+    if (llvm::isa<TileZeroOp>(op)) {
+      if (!TileDstPropagate<TileZeroOp>(this, op)) {
+        isValidAnalysis = false;
+        return;
+      }
+    } else if (llvm::isa<TileLoadOp>(op)) {
+      if (!TileDstPropagate<TileLoadOp>(this, op)) {
+        isValidAnalysis = false;
+        return;
+      }
+    } else if (llvm::isa<TileMulFOp>(op)) {
+      if (!TileMulPropagate<TileMulFOp>(this, op)) {
+        isValidAnalysis = false;
+        return;
+      }
+    } else if (llvm::isa<TileMulIOp>(op)) {
+      if (!TileMulPropagate<TileMulIOp>(this, op)) {
+        isValidAnalysis = false;
+        return;
+      }
+    }
+  });
+}
 
 //===----------------------------------------------------------------------===//
 // Pass
 //===----------------------------------------------------------------------===//
 
+class TileStoreBindingRewriter : public OpRewritePattern<TileStoreOp> {
+private:
+  TileBindingAnalysis &analysis;
+
+public:
+  using OpRewritePattern<TileStoreOp>::OpRewritePattern;
+
+  TileStoreBindingRewriter(MLIRContext *context, TileBindingAnalysis &ana)
+      : OpRewritePattern(context), analysis{ana} {}
+
+  LogicalResult matchAndRewrite(TileStoreOp op,
+                                PatternRewriter &rewriter) const final {
+    auto val = op.getVal();
+    auto srcIndex = analysis.getBinding(val);
+    if (srcIndex < 0)
+      return failure();
+    auto existingAccIndex = op.getSrcRegIndex();
+    if (existingAccIndex && *existingAccIndex != srcIndex)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<TileStoreOp>(
+        op, op.getBase(), op.getIndices(), val,
+        rewriter.getI8IntegerAttr(srcIndex));
+    return success();
+  }
+};
+
+class TileMulFBindingRewriter : public OpRewritePattern<TileMulFOp> {
+private:
+  TileBindingAnalysis &analysis;
+
+public:
+  using OpRewritePattern<TileMulFOp>::OpRewritePattern;
+
+  TileMulFBindingRewriter(MLIRContext *context, TileBindingAnalysis &ana)
+      : OpRewritePattern(context), analysis{ana} {}
+
+  LogicalResult matchAndRewrite(TileMulFOp op,
+                                PatternRewriter &rewriter) const final {
+    auto lhsVal = op.getLhs();
+    auto rhsVal = op.getRhs();
+    auto accVal = op.getAcc();
+    auto lhsIndex = analysis.getBinding(lhsVal);
+    auto rhsIndex = analysis.getBinding(rhsVal);
+    auto accIndex = analysis.getBinding(accVal);
+    if (lhsIndex < 0 || rhsIndex < 0 || accIndex < 0)
+      return failure();
+    auto existingLhsIndex = op.getLhsRegIndex();
+    auto existingRhsIndex = op.getRhsRegIndex();
+    auto existingAccIndex = op.getAccRegIndex();
+    if ((existingLhsIndex && *existingLhsIndex != lhsIndex) ||
+        (existingRhsIndex && *existingRhsIndex != rhsIndex) ||
+        (existingAccIndex && *existingAccIndex != accIndex))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<TileMulFOp>(
+        op, op.getRes().getType(), lhsVal, rhsVal, accVal,
+        rewriter.getI8IntegerAttr(lhsIndex),
+        rewriter.getI8IntegerAttr(rhsIndex),
+        rewriter.getI8IntegerAttr(accIndex));
+    return success();
+  }
+};
+
+class TileMulIBindingRewriter : public OpRewritePattern<TileMulIOp> {
+private:
+  TileBindingAnalysis &analysis;
+
+public:
+  using OpRewritePattern<TileMulIOp>::OpRewritePattern;
+
+  TileMulIBindingRewriter(MLIRContext *context, TileBindingAnalysis &ana)
+      : OpRewritePattern(context), analysis{ana} {}
+
+  LogicalResult matchAndRewrite(TileMulIOp op,
+                                PatternRewriter &rewriter) const final {
+    auto lhsVal = op.getLhs();
+    auto rhsVal = op.getRhs();
+    auto accVal = op.getAcc();
+    auto lhsIndex = analysis.getBinding(lhsVal);
+    auto rhsIndex = analysis.getBinding(rhsVal);
+    auto accIndex = analysis.getBinding(accVal);
+    if (lhsIndex < 0 || rhsIndex < 0 || accIndex < 0)
+      return failure();
+    auto existingLhsIndex = op.getLhsRegIndex();
+    auto existingRhsIndex = op.getRhsRegIndex();
+    auto existingAccIndex = op.getAccRegIndex();
+    if ((existingLhsIndex && *existingLhsIndex != lhsIndex) ||
+        (existingRhsIndex && *existingRhsIndex != rhsIndex) ||
+        (existingAccIndex && *existingAccIndex != accIndex))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<TileMulIOp>(
+        op, op.getRes().getType(), lhsVal, rhsVal, accVal, op.getIsZextLhs(),
+        op.getIsZextRhs(), rewriter.getI8IntegerAttr(lhsIndex),
+        rewriter.getI8IntegerAttr(rhsIndex),
+        rewriter.getI8IntegerAttr(accIndex));
+    return success();
+  }
+};
+
 struct EnableAMXTileBindingPass
     : public impl::EnableAMXTileBindingBase<EnableAMXTileBindingPass> {
   void runOnOperation() override {
-    // 0. Get AnalyseInfo for each concerned Value (mixed used of tmul & normal
-    // vector operations?)
+    // 0. Get AnalyseInfo for each concerned Value (Does not allow mixed used of
+    // tmul & normal vector operations)
     TileBindingAnalysis &analysis = getAnalysis<TileBindingAnalysis>();
+    if (!analysis.isValid())
+      return;
+
+    // 1. Set propagated binding info to AMX Ops
+    RewritePatternSet patterns(&getContext());
+    patterns.add<TileStoreBindingRewriter>(&getContext(), analysis);
+    patterns.add<TileMulFBindingRewriter>(&getContext(), analysis);
+    patterns.add<TileMulIBindingRewriter>(&getContext(), analysis);
+    FrozenRewritePatternSet patternSet(std::move(patterns));
+
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
+      analysis.setValid(false);
+      return;
+    }
 
-    // 1. Propagate binding info to AMX Ops
-    //
     // 2. Analyse tile scopes & expand them maximally
     //
     // 3. insert tile config/release according to tile scopes

>From a44abcfa36b85e70dd26103310b453ba04aa30ef Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Wed, 26 Jun 2024 01:16:17 -0700
Subject: [PATCH 04/17] [WIP] add aux info collection

---
 .../AMX/Transforms/EnableAMXTileBinding.cpp   | 111 ++++++++++++++++--
 1 file changed, 104 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
index 0308002c27c908..0dae119ecce85e 100644
--- a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -51,7 +51,7 @@ class TileBindingAnalysis {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileBindingAnalysis)
   explicit TileBindingAnalysis(Operation *);
   bool isValid() const { return isValidAnalysis; }
-  void setValid(bool v) { isValidAnalysis = v; }
+  // void setValid(bool v) { isValidAnalysis = v; }
   int getBinding(Value val) const {
     auto iter = bindings.find(val);
     if (iter == bindings.end())
@@ -164,6 +164,102 @@ TileBindingAnalysis::TileBindingAnalysis(Operation *root) {
   });
 }
 
+// A class for analyzing tile configuration domination (a.k.a. tile scope)
+class TileScopeAnalysis {
+private:
+  typedef llvm::iterator_range<Block::iterator, Block::iterator> BlockSeg;
+  typedef SmallVector<SmallVector<int, 2>, 8> Palette;
+  struct TileScope {
+    BlockSeg seg;
+    Palette palette;
+  };
+
+  bool isValidAnalysis;
+  // Storing Ops that would break tile context & scope (usually parallel Ops)
+  DenseSet<Operation *> scopeBreaker;
+  DenseMap<Operation *, BlockSeg> tileUsage;
+  SmallVector<TileScope, 10> tileScopes;
+
+  void addScopeBreaker(Operation *op) { scopeBreaker.insert(op); }
+  bool isScopeBreaker(Operation *op) {
+    return scopeBreaker.find(op) == scopeBreaker.end();
+  }
+
+  void setTileUsage(Operation *op, BlockSeg seg);
+  BlockSeg getTileUsage();
+  void doTileScope(Block &block);
+  void doTileScope(BlockSeg seg);
+  void doTileScope(Operation *op);
+
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileScopeAnalysis)
+  explicit TileScopeAnalysis(Operation *);
+  bool isValid() const { return isValidAnalysis; }
+};
+
+static bool isTileOp(Operation *op) {
+  return llvm::isa<TileZeroOp>(op) || llvm::isa<TileLoadOp>(op) ||
+         llvm::isa<TileMulFOp>(op) || llvm::isa<TileMulIOp>(op) ||
+         llvm::isa<TileStoreOp>(op);
+}
+
+TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
+  isValidAnalysis = false;
+  func::FuncOp func = dyn_cast_or_null<func::FuncOp>(root);
+  if (!func)
+    return;
+
+  isValidAnalysis = true;
+  // 0. First walk to mark tile scope breakers
+  func->walk<WalkOrder::PostOrder>([this](Operation *op) {
+    if (!isScopeBreaker(op))
+      return;
+
+    if (llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op) ||
+        llvm::isa<omp::ParallelOp>(op) || llvm::isa<omp::WsloopOp>(op)) {
+      while (op != root) {
+        addScopeBreaker(op);
+        op = op->getParentOp();
+      }
+    }
+  });
+
+  // 1. Second walk to analyse usage scope for each tile Op
+  func->walk<WalkOrder::PreOrder>([this](Operation *op) {
+    if (!isValidAnalysis)
+      return;
+    if (!isTileOp(op))
+      return;
+    Operation *lastUser = nullptr;
+    for (auto user : op->getUsers())
+      lastUser = user;
+    while (lastUser && op->getBlock() != lastUser->getBlock()) {
+      lastUser = lastUser->getParentOp();
+      if (!lastUser)
+        isValidAnalysis = false;
+    }
+    setTileUsage(op, BlockSeg(Block::iterator(op), Block::iterator(lastUser)));
+  });
+  if (!isValidAnalysis)
+    return;
+
+  // 2. Tile scoping for each segmented region in a recursive manner
+  doTileScope(func.getRegion(0).front());
+}
+
+void TileScopeAnalysis::doTileScope(Block &block) {
+  doTileScope(BlockSeg(block.begin(), block.end()));
+}
+
+void TileScopeAnalysis::doTileScope(BlockSeg seg) {
+  if (seg.empty())
+    return;
+  for (auto probe = seg.begin(); probe < seg.end(); probe++) {
+  }
+}
+
+void TileScopeAnalysis::doTileScope(Operation *op) {}
+
 //===----------------------------------------------------------------------===//
 // Pass
 //===----------------------------------------------------------------------===//
@@ -274,8 +370,8 @@ struct EnableAMXTileBindingPass
   void runOnOperation() override {
     // 0. Get AnalyseInfo for each concerned Value (Does not allow mixed used of
     // tmul & normal vector operations)
-    TileBindingAnalysis &analysis = getAnalysis<TileBindingAnalysis>();
-    if (!analysis.isValid())
+    TileBindingAnalysis &bindingAna = getAnalysis<TileBindingAnalysis>();
+    if (!bindingAna.isValid())
       return;
 
     // 1. Set propagated binding info to AMX Ops
@@ -285,13 +381,14 @@ struct EnableAMXTileBindingPass
     patterns.add<TileMulIBindingRewriter>(&getContext(), analysis);
     FrozenRewritePatternSet patternSet(std::move(patterns));
 
-    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
-      analysis.setValid(false);
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
       return;
-    }
 
     // 2. Analyse tile scopes & expand them maximally
-    //
+    TileScopeAnalysis &scopeAna = getAnalysis<TileScopeAnalysis>();
+    if (!scopeAna.isValid())
+      return;
+
     // 3. insert tile config/release according to tile scopes
   }
 };

>From 769b74e3b41cdfff18ba5b089fa26e9ab269dcf6 Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Fri, 28 Jun 2024 02:35:02 -0700
Subject: [PATCH 05/17] [TBD] add tile scoping algorithm

---
 .../AMX/Transforms/EnableAMXTileBinding.cpp   | 336 +++++++++++++++++-
 1 file changed, 319 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
index 0dae119ecce85e..87459540929853 100644
--- a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -167,27 +167,61 @@ TileBindingAnalysis::TileBindingAnalysis(Operation *root) {
 // A class for analyzing tile configuration domination (a.k.a. tile scope)
 class TileScopeAnalysis {
 private:
-  typedef llvm::iterator_range<Block::iterator, Block::iterator> BlockSeg;
-  typedef SmallVector<SmallVector<int, 2>, 8> Palette;
+  typedef llvm::iterator_range<Block::iterator> BlockSeg;
+  // A list of 2-dim shapes representing tmm register shape, the length should
+  // always be 8
+  struct PaletteInfo {
+    bool overflow;
+    SmallVector<pair<int, int>, 8> palette;
+    PaletteInfo() {
+      palette.resize(8, {0, 0});
+      clear();
+    }
+    void clear();
+    bool isEmpty(int idx) {
+      return palette[idx].first == 0 && palette[idx].second == 0;
+    }
+    void set(int idx, pair<int, int> shape) { palette[idx] = shape; }
+    void merge(const PaletteInfo &rhs);
+    bool isConflict(const PaletteInfo &rhs);
+  };
   struct TileScope {
+    // The BlockSeg here is inclusive (including end Op)
     BlockSeg seg;
-    Palette palette;
+    PaletteInfo pi;
+    TileScope() { clear(); }
+    void clear() { pi.clear(); }
   };
 
   bool isValidAnalysis;
-  // Storing Ops that would break tile context & scope (usually parallel Ops)
-  DenseSet<Operation *> scopeBreaker;
+  // Storing parallel Ops that would break tile context & scope
+  DenseSet<Operation *> parallelOps;
+  // Storing needed palette info for each concerned Op
+  DenseMap<Operation *, PaletteInfo> neededPalette;
+  // Storing the usage scope for each concerned tile Op
+  // The BlockSeg here is inclusive (including end Op)
   DenseMap<Operation *, BlockSeg> tileUsage;
+  // Storing final tile scope results for injecting tilecfg/tilerelease
   SmallVector<TileScope, 10> tileScopes;
 
-  void addScopeBreaker(Operation *op) { scopeBreaker.insert(op); }
-  bool isScopeBreaker(Operation *op) {
-    return scopeBreaker.find(op) == scopeBreaker.end();
+  void addParallelOp(Operation *op) { parallelOps.insert(op); }
+  bool isParallelOp(Operation *op) {
+    return parallelOps.find(op) == parallelOps.end();
   }
 
-  void setTileUsage(Operation *op, BlockSeg seg);
-  BlockSeg getTileUsage();
+  void setTileUsage(Operation *op, BlockSeg seg) { tileUsage[op] = seg; }
+
+  PaletteInfo collectRegionPalette(Region &region);
+  PaletteInfo collectPalette(Operation *op);
+  // Below two functions are the leaf functinos of recursive collection, will
+  // actually insert PaletteInfo into map storage
+  PaletteInfo collectPaletteForScf(Operation *op);
+  PaletteInfo collectPaletteForTile(Operation *op);
+  std::optional<PaletteInfo> getPalette(Operation *op);
+  std::optional<PaletteInfo> getPalette(BlockSeg seg);
+
   void doTileScope(Block &block);
+  // The BlockSeg here is exclusive (excluding end Op)
   void doTileScope(BlockSeg seg);
   void doTileScope(Operation *op);
 
@@ -197,6 +231,51 @@ class TileScopeAnalysis {
   bool isValid() const { return isValidAnalysis; }
 };
 
+// Resets this palette: no overflow and all 8 slots empty ({0, 0}).
+void TileScopeAnalysis::PaletteInfo::clear() {
+  palette.assign(8, {0, 0});
+  overflow = false;
+}
+
+// Folds `rhs` into this palette. An overflow on either side, or a slot that
+// is occupied with different shapes on both sides, leaves `overflow` set.
+void TileScopeAnalysis::PaletteInfo::merge(const PaletteInfo &rhs) {
+  overflow = overflow || rhs.overflow;
+  for (int idx = 0; idx < 8 && !overflow; idx++) {
+    // rhs is const while isEmpty() is a non-const member, so the rhs slot is
+    // tested through its fields rather than via rhs.isEmpty(idx).
+    const auto &other = rhs.palette[idx];
+    if (other.first == 0 && other.second == 0)
+      continue;
+    // Adopt the rhs shape, or flag a clash with the shape already held.
+    if (isEmpty(idx))
+      palette[idx] = other;
+    else if (palette[idx] != other)
+      overflow = true;
+  }
+}
+
+// True if either palette overflowed or a slot used by both differs in shape.
+bool TileScopeAnalysis::PaletteInfo::isConflict(const PaletteInfo &rhs) {
+  if (overflow || rhs.overflow)
+    return true;
+  for (int idx = 0; idx < 8; idx++) {
+    // Read rhs fields directly: isEmpty() is non-const and rhs is const.
+    const auto &other = rhs.palette[idx];
+    bool rhsUsed = other.first != 0 || other.second != 0;
+    if (!isEmpty(idx) && rhsUsed && palette[idx] != other)
+      return true;
+  }
+  return false; // No clash; a non-void function must not fall off its end.
+}
+
+static bool isConcernedScfOp(Operation *op) {
+  return llvm::isa<scf::ExecuteRegionOp>(op) || llvm::isa<scf::ForOp>(op) ||
+         llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::IfOp>(op) ||
+         llvm::isa<scf::IndexSwitchOp>(op) || llvm::isa<scf::ParallelOp>(op) ||
+         llvm::isa<scf::WhileOp>(op);
+}
+
 static bool isTileOp(Operation *op) {
   return llvm::isa<TileZeroOp>(op) || llvm::isa<TileLoadOp>(op) ||
          llvm::isa<TileMulFOp>(op) || llvm::isa<TileMulIOp>(op) ||
@@ -210,21 +289,24 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
     return;
 
   isValidAnalysis = true;
-  // 0. First walk to mark tile scope breakers
+  // 0. First walk to mark parallel Ops
   func->walk<WalkOrder::PostOrder>([this](Operation *op) {
-    if (!isScopeBreaker(op))
+    if (!isParallelOp(op))
       return;
 
     if (llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op) ||
         llvm::isa<omp::ParallelOp>(op) || llvm::isa<omp::WsloopOp>(op)) {
       while (op != root) {
-        addScopeBreaker(op);
+        addParallelOp(op);
         op = op->getParentOp();
       }
     }
   });
 
-  // 1. Second walk to analyse usage scope for each tile Op
+  // 1. Second walk to collect needed palette for each concerned Op
+  collectNeededPalette(root);
+
+  // 2. Third walk to analyse usage scope for each tile Op
   func->walk<WalkOrder::PreOrder>([this](Operation *op) {
     if (!isValidAnalysis)
       return;
@@ -243,22 +325,242 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
   if (!isValidAnalysis)
     return;
 
-  // 2. Tile scoping for each segmented region in a recursive manner
+  // 3. Tile scoping for each segmented region in a recursive manner
   doTileScope(func.getRegion(0).front());
 }
 
+// Merges the palettes needed by every Op in `block` into one PaletteInfo.
+auto TileScopeAnalysis::collectRegionPalette(Block &block) -> PaletteInfo {
+  PaletteInfo pi;
+  llvm::for_each(block, [&](Operation &op) { pi.merge(collectPalette(&op)); });
+  return pi;
+}
+
+// Returns the palette needed by `op`, reusing the cached entry when present.
+auto TileScopeAnalysis::collectPalette(Operation *op) -> PaletteInfo {
+  if (!isValidAnalysis)
+    return PaletteInfo();
+  // No PaletteInfo is stored for a func; its entry block is collected instead.
+  if (auto func = dyn_cast_or_null<func::FuncOp>(op))
+    return collectRegionPalette(func.getRegion(0).front());
+  auto iter = neededPalette.find(op);
+  if (iter != neededPalette.end())
+    return iter->second;
+
+  // For now, we only concern certain control flow Ops and tile Ops
+  if (isConcernedScfOp(op))
+    return collectPaletteForScf(op);
+  if (isTileOp(op))
+    return collectPaletteForTile(op);
+  return PaletteInfo();
+}
+
+// Collects the palette needed by a concerned control flow Op and caches it in
+// `neededPalette`.
+//
+// The palettes of the entry blocks of all regions of `op` are merged: the
+// single body of execute_region/for/forall/parallel, the then/else regions of
+// scf.if, the default and case regions of scf.index_switch, and the before
+// and after regions of scf.while. Regions without any block (e.g. the else
+// region of an scf.if that has no else branch) are skipped.
+//
+// Returns an empty PaletteInfo, without caching, for any other Op.
+auto TileScopeAnalysis::collectPaletteForScf(Operation *op) -> PaletteInfo {
+  // Guard against callers that bypass collectPalette()'s own dispatch.
+  if (!isConcernedScfOp(op))
+    return PaletteInfo();
+
+  PaletteInfo pi;
+  for (Region &region : op->getRegions()) {
+    if (region.empty())
+      continue;
+    // collectRegionPalette() recurses into nested control flow and tile Ops
+    // through collectPalette().
+    pi.merge(collectRegionPalette(region.front()));
+  }
+  // Cache the merged result so collectPalette() can reuse it.
+  neededPalette[op] = pi;
+  return pi;
+}
+
+// Returns the tmm palette shape {rows, colBytes} described by `type`.
+static inline pair<int, int> getPaletteShape(VectorType type) {
+  ArrayRef<int64_t> shape = type.getShape();
+  auto elementType = type.getElementType();
+  int typeSize = 0; // Defined even when the assert below is compiled out.
+  if (elementType.isInteger(8))
+    typeSize = 1;
+  else if (elementType.isBF16())
+    typeSize = 2;
+  else if (elementType.isInteger(32) || elementType.isF32())
+    typeSize = 4;
+  else
+    assert(false && "Invalid type for palette");
+
+  return {shape[0], shape[1] * typeSize};
+}
+
+// Collects the palette slots used by a single tile Op and caches them in
+// `neededPalette`. If a register index is missing, the analysis is marked
+// invalid and an empty PaletteInfo is returned without caching.
+auto TileScopeAnalysis::collectPaletteForTile(Operation *op) -> PaletteInfo {
+  if (!isTileOp(op))
+    return PaletteInfo();
+
+  PaletteInfo pi;
+  // Binds the shape of `type` to slot `*index`; fails when `index` is unset.
+  auto bind = [&](auto index, VectorType type) {
+    if (index)
+      pi.set(*index, getPaletteShape(type));
+    return static_cast<bool>(index);
+  };
+  // Binds lhs, rhs and acc of a tile multiplication. The Op is a parameter so
+  // that TileMulFOp and TileMulIOp each read their own register indices.
+  auto bindMul = [&](auto mulOp) {
+    return bind(mulOp.getLhsRegIndex(), mulOp.getLhsVectorType()) &&
+           bind(mulOp.getRhsRegIndex(), mulOp.getRhsVectorType()) &&
+           bind(mulOp.getAccRegIndex(), mulOp.getAccVectorType());
+  };
+  bool bound = true;
+  if (auto tileLoadOp = dyn_cast<TileLoadOp>(op))
+    bound = bind(tileLoadOp.getDstRegIndex(), tileLoadOp.getVectorType());
+  else if (auto tileMulFOp = dyn_cast<TileMulFOp>(op))
+    bound = bindMul(tileMulFOp);
+  else if (auto tileMulIOp = dyn_cast<TileMulIOp>(op))
+    bound = bindMul(tileMulIOp);
+  else if (auto tileStoreOp = dyn_cast<TileStoreOp>(op))
+    bound = bind(tileStoreOp.getSrcRegIndex(), tileStoreOp.getVectorType());
+  else if (auto tileZeroOp = dyn_cast<TileZeroOp>(op))
+    bound = bind(tileZeroOp.getDstRegIndex(), tileZeroOp.getVectorType());
+  if (!bound) {
+    isValidAnalysis = false;
+    return PaletteInfo();
+  }
+  neededPalette[op] = pi;
+  return pi;
+}
+
+// Returns the cached palette of `op`, or std::nullopt if none is stored.
+auto TileScopeAnalysis::getPalette(Operation *op)
+    -> std::optional<PaletteInfo> {
+  auto iter = neededPalette.find(op);
+  return iter == neededPalette.end() ? std::nullopt
+                                     : std::make_optional(iter->second);
+}
+
+// Merges the cached palettes of all Ops iterated by `seg`; returns
+// std::nullopt when none of them has a cached palette.
+auto TileScopeAnalysis::getPalette(BlockSeg seg) -> std::optional<PaletteInfo> {
+  std::optional<PaletteInfo> merged;
+  for (Operation &op : seg) {
+    if (std::optional<PaletteInfo> opPalette = getPalette(&op)) {
+      if (!merged)
+        merged.emplace();
+      merged->merge(*opPalette);
+    }
+  }
+  return merged;
+}
+
 void TileScopeAnalysis::doTileScope(Block &block) {
   doTileScope(BlockSeg(block.begin(), block.end()));
 }
 
 void TileScopeAnalysis::doTileScope(BlockSeg seg) {
+  if (!isValidAnalysis)
+    return;
   if (seg.empty())
     return;
-  for (auto probe = seg.begin(); probe < seg.end(); probe++) {
+  SmallVector<BlockSeg, 3> blockSegs;
+  SmallVector<Operation *, 3> paraOps;
+  auto currBegin = seg.begin();
+  for (auto probe = seg.begin(); probe != seg.end(); probe++) {
+    Operation *op = &(*probe);
+    if (isParallelOp(op) {
+      blockSegs.push_back(BlockSeg(currBegin, probe));
+      paraOps.push_back(op);
+      currBegin = probe;
+      currBegin++;
+    }
   }
+  if (breakers.size()) {
+    assert(blockSegs.size() == paraOps.size());
+    for (int idx = 0; idx < paraOps.size(); idx++) {
+      doTileScope(blockSegs[idx]);
+      doTileScope(paraOps[idx]);
+    }
+    doTileScope(BlockSeg(currBegin, seg.end()));
+    return;
+  }
+
+  // Do tile scope on parallel-free BlockSeg
+  TileScope currScope;
+  std::optional<Block::iterator> currSegStart;
+  Block::iterator currIter = seg.begin();
+  // Traverse BlockSeg and greedily do tile scoping without look ahead
+  while (currIter != seg.end()) {
+    Operation *currOp = &(*currIter);
+    if (!currSegStart)
+      currSegStart = currIter;
+
+    Block::iterator nextIterIfMerge;
+    std::optional<PaletteInfo> pi = std::null_opt;
+    if (isConcernedScfOp(currOp)) {
+      pi = getPalette(currOp);
+      nextIterIfMerge = currIter;
+      nextIterIfMerge++;
+    } else if (isTileOp(currOp)) {
+      auto iter = tileUsage.find(currOp);
+      if (iter == tileUsage.end()) {
+        isValidAnalysis = false;
+        return;
+      }
+      pi = getPalette(iter->second);
+      if (pi && pi->overflow) {
+        // This means the binding info exceeds the hardware capability
+        isValidAnalysis = false;
+        return;
+      }
+      nextIterIfMerge = iter->second.end();
+      nextIterIfMerge++;
+    }
+    if (!pi) {
+      currIter++;
+      continue;
+    }
+
+#define ADD_PREVIOUS_SCOPE()                                                   \
+  auto prevIter = currIter;                                                    \
+  prevIter--;                                                                  \
+  currScope.seg = BlockSeg(*currSegStart, prevIter);                           \
+  tileScopes.push_back(currScope);                                             \
+  currScope.clear();                                                           \
+  currSegStart = std::null_opt;
+
+    if (pi->overflow) {
+      // Only scf Ops could go through this possibility
+      if (currSegStart && *currSegStart != currIter) {
+        ADD_PREVIOUS_SCOPE()
+      }
+      doTileScope(currOp);
+      currIter++;
+    } else {
+      if (currScope.pi.isConflict(*pi)) {
+        ADD_PREVIOUS_SCOPE();
+        currIter++;
+      } else {
+        currScope.pi.merge(*pi);
+        currIter = nextIterIfMerge;
+      }
+    }
+  }
+
+  ADD_PREVIOUS_SCOPE();
 }
 
-void TileScopeAnalysis::doTileScope(Operation *op) {}
+void TileScopeAnalysis::doTileScope(Operation *op) {
+  // TODO: do tile scope for scf here
+}
 
 //===----------------------------------------------------------------------===//
 // Pass

>From 986064a75631b610c242887dad759413733bfd7e Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Mon, 1 Jul 2024 02:26:35 -0700
Subject: [PATCH 06/17] finish tile scoping procedure

---
 .../AMX/Transforms/EnableAMXTileBinding.cpp   | 129 ++++++++++++++----
 1 file changed, 106 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
index 87459540929853..24586ac966a728 100644
--- a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -186,7 +186,7 @@ class TileScopeAnalysis {
     bool isConflict(const PaletteInfo &rhs);
   };
   struct TileScope {
-    // The BlockSeg here is inclusive (including end Op)
+    // The BlockSeg here is inclusive (containing `end` Op)
     BlockSeg seg;
     PaletteInfo pi;
     TileScope() { clear(); }
@@ -199,7 +199,7 @@ class TileScopeAnalysis {
   // Storing needed palette info for each concerned Op
   DenseMap<Operation *, PaletteInfo> neededPalette;
   // Storing the usage scope for each concerned tile Op
-  // The BlockSeg here is inclusive (including end Op)
+  // The BlockSeg here is inclusive (containing `end` Op)
   DenseMap<Operation *, BlockSeg> tileUsage;
   // Storing final tile scope results for injecting tilecfg/tilerelease
   SmallVector<TileScope, 10> tileScopes;
@@ -221,7 +221,7 @@ class TileScopeAnalysis {
   std::optional<PaletteInfo> getPalette(BlockSeg seg);
 
   void doTileScope(Block &block);
-  // The BlockSeg here is exclusive (excluding end Op)
+  // The input BlockSeg here is exclusive (not containing `end` Op)
   void doTileScope(BlockSeg seg);
   void doTileScope(Operation *op);
 
@@ -269,7 +269,8 @@ bool TileScopeAnalysis::PaletteInfo::isConflict(const PaletteInfo &rhs) {
   }
 }
 
-static bool isConcernedScfOp(Operation *op) {
+// Currently we only operate on scf Ops
+static bool isConcernedControlFlowOp(Operation *op) {
   return llvm::isa<scf::ExecuteRegionOp>(op) || llvm::isa<scf::ForOp>(op) ||
          llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::IfOp>(op) ||
          llvm::isa<scf::IndexSwitchOp>(op) || llvm::isa<scf::ParallelOp>(op) ||
@@ -294,8 +295,7 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
     if (!isParallelOp(op))
       return;
 
-    if (llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op) ||
-        llvm::isa<omp::ParallelOp>(op) || llvm::isa<omp::WsloopOp>(op)) {
+    if (llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op)) {
       while (op != root) {
         addParallelOp(op);
         op = op->getParentOp();
@@ -348,7 +348,7 @@ PaletteInfo TileScopeAnalysis::collectPalette(Operation *op) {
     return iter->second;
 
   // For now, we only concern certain control flow Ops and tile Ops
-  if (isConcernedScfOp(op))
+  if (isConcernedControlFlowOp(op))
     return collectPaletteForScf(op);
   if (isTileOp(op))
     return collectPaletteForTile(op);
@@ -505,7 +505,7 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
 
     Block::iterator nextIterIfMerge;
     std::optional<PaletteInfo> pi = std::null_opt;
-    if (isConcernedScfOp(currOp)) {
+    if (isConcernedControlFlowOp(currOp)) {
       pi = getPalette(currOp);
       nextIterIfMerge = currIter;
       nextIterIfMerge++;
@@ -517,7 +517,8 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
       }
       pi = getPalette(iter->second);
       if (pi && pi->overflow) {
-        // This means the binding info exceeds the hardware capability
+        // This means the binding info in tile Ops exceeds the hardware
+        // capability
         isValidAnalysis = false;
         return;
       }
@@ -529,24 +530,26 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
       continue;
     }
 
-#define ADD_PREVIOUS_SCOPE()                                                   \
-  auto prevIter = currIter;                                                    \
-  prevIter--;                                                                  \
-  currScope.seg = BlockSeg(*currSegStart, prevIter);                           \
-  tileScopes.push_back(currScope);                                             \
-  currScope.clear();                                                           \
-  currSegStart = std::null_opt;
+#define TRY_ADD_PREVIOUS_SCOPE()                                               \
+  if (currSegStart && *currSegStart != currIter) {                             \
+    auto prevIter = currIter;                                                  \
+    prevIter--;                                                                \
+    currScope.seg = BlockSeg(*currSegStart, prevIter);                         \
+    tileScopes.push_back(currScope);                                           \
+    currScope.clear();                                                         \
+    currSegStart = std::null_opt;                                              \
+  }
 
     if (pi->overflow) {
       // Only scf Ops could go through this possibility
-      if (currSegStart && *currSegStart != currIter) {
-        ADD_PREVIOUS_SCOPE()
-      }
+      TRY_ADD_PREVIOUS_SCOPE();
       doTileScope(currOp);
       currIter++;
     } else {
       if (currScope.pi.isConflict(*pi)) {
-        ADD_PREVIOUS_SCOPE();
+        TRY_ADD_PREVIOUS_SCOPE();
+        currScope.pi = *pi;
+        currSegStart = currIter;
         currIter++;
       } else {
         currScope.pi.merge(*pi);
@@ -555,11 +558,39 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
     }
   }
 
-  ADD_PREVIOUS_SCOPE();
+  TRY_ADD_PREVIOUS_SCOPE();
 }
 
 void TileScopeAnalysis::doTileScope(Operation *op) {
-  // TODO: do tile scope for scf here
+  // Collects tile scopes inside a single control flow Op by scoping the entry
+  // block of each of its regions. This func is not for tile Ops.
+  if (isTileOp(op))
+    return;
+  // Ops that invoke this func are either parallelOps or scfOps with overflowed
+  // paletteInfo, and neither of them can form a tile scope by itself, so we
+  // omit checking self-formed tile scope in this func.
+  auto scopeEntryBlock = [this](Region &region) {
+    // A region may have no block, e.g. the else region of an scf.if.
+    if (!region.empty())
+      doTileScope(region.front());
+  };
+  if (llvm::isa<scf::ExecuteRegionOp>(op) || llvm::isa<scf::ForOp>(op) ||
+      llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op)) {
+    // Only region 0 is scoped for these Ops.
+    scopeEntryBlock(op->getRegion(0));
+  } else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
+    // Use the typed accessors; Operation itself has no getThenRegion().
+    scopeEntryBlock(ifOp.getThenRegion());
+    scopeEntryBlock(ifOp.getElseRegion());
+  } else if (auto indexOp = dyn_cast<scf::IndexSwitchOp>(op)) {
+    scopeEntryBlock(indexOp.getDefaultRegion());
+    // Scope every case region itself, not the default region once per case.
+    for (Region &caseRegion : indexOp.getCaseRegions())
+      scopeEntryBlock(caseRegion);
+  } else if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
+    scopeEntryBlock(whileOp.getRegion(0));
+    scopeEntryBlock(whileOp.getRegion(1));
+  }
 }
 
 //===----------------------------------------------------------------------===//
@@ -669,7 +700,40 @@ class TileMulIBindingRewriter : public OpRewritePattern<TileMulIOp> {
 
 struct EnableAMXTileBindingPass
     : public impl::EnableAMXTileBindingBase<EnableAMXTileBindingPass> {
+private:
+  // Returns true iff the pass root is a func::FuncOp and every Op enclosing a
+  // tile Op, up to but excluding the root, is a concerned control flow Op.
+  bool isViableTileOps() {
+    Operation *root = getOperation();
+    auto func = dyn_cast<func::FuncOp>(root);
+    if (!func)
+      return false;
+
+    // Capture by reference: the lambda reads the local `root`, which a
+    // `[this]` capture would not make available.
+    WalkResult result = func->walk<WalkOrder::PreOrder>([&](Operation *op) {
+      if (!isTileOp(op))
+        return WalkResult::advance();
+      for (Operation *probe = op->getParentOp(); probe != root;
+           probe = probe->getParentOp()) {
+        // Stop at the first tile Op wrapped by an unsupported parent.
+        if (!isConcernedControlFlowOp(probe))
+          return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
+    return !result.wasInterrupted();
+  }
+
+  LLVM::GlobalOp getOrCreateGlobalPalette(const PaletteInfo &pi) {}
+
+public:
   void runOnOperation() override {
+    // Ensure that tile Ops are not wrapped by out-of-scope Ops, else cannot do
+    // enabling
+    if (!isViableTileOps())
+      return;
+
     // 0. Get AnalyseInfo for each concerned Value (Does not allow mixed used of
     // tmul & normal vector operations)
     TileBindingAnalysis &bindingAna = getAnalysis<TileBindingAnalysis>();
@@ -691,7 +755,26 @@ struct EnableAMXTileBindingPass
     if (!scopeAna.isValid())
       return;
 
-    // 3. insert tile config/release according to tile scopes
+    // 3. Insert tile config/release according to tile scopes
+    OpBuilder builder(getOperation());
+    for (auto &scope : tileScopes) {
+      assert(!scope.pi.overflow && "Expecting legal AMX palette info");
+      auto paletteGlobal = getOrCreateGlobalPalette(scope.pi);
+      assert(paletteGlobal && "Failed to create global palette");
+
+      Operation *begin = &(*scope.seg.begin());
+      Loc loc = begin->getLoc();
+
+      builder.setInsertionPoint(begin);
+      Value paletteGlobalPtr =
+          builder.create<LLVM::AddressOfOp>(loc, paletteGlobal);
+      builder.create<amx::x86_amx_ldtilecfg_plain>(loc, paletteGlobalPtr);
+
+      Operation *end = &(*scope.seg.end());
+      loc = end->getLoc();
+      builder.setInsertionPointAfter(end);
+      builder.create<amx::x86_amx_tilerelease_plain>(loc);
+    }
   }
 };
 

>From 5faad5a944534894f24eb1d95be070aa1181a040 Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Tue, 2 Jul 2024 00:25:42 -0700
Subject: [PATCH 07/17] add tile Ops lowering with binding

---
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |   3 +-
 .../AMX/Transforms/EnableAMXTileBinding.cpp   | 121 ++++++++++-----
 .../AMX/Transforms/LegalizeForLLVMExport.cpp  | 138 ++++++++++++++++--
 3 files changed, 214 insertions(+), 48 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 55143d5939ba25..c3fd9399a6c9e6 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -105,8 +105,9 @@ void LowerVectorToLLVMPass::runOnOperation() {
     populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
   }
   if (amx) {
+    auto &analysis = getCachedAnalysis<TileScopeAnalysis>();
     configureAMXLegalizeForExportTarget(target);
-    populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
+    populateAMXLegalizeForLLVMExportPatterns(converter, analysis, patterns);
   }
   if (x86Vector) {
     configureX86VectorLegalizeForExportTarget(target);
diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
index 24586ac966a728..54a609fa783e04 100644
--- a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -51,7 +51,6 @@ class TileBindingAnalysis {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileBindingAnalysis)
   explicit TileBindingAnalysis(Operation *);
   bool isValid() const { return isValidAnalysis; }
-  // void setValid(bool v) { isValidAnalysis = v; }
   int getBinding(Value val) const {
     auto iter = bindings.find(val);
     if (iter == bindings.end())
@@ -81,7 +80,7 @@ static bool TileMulCheck(Operation *op) {
 }
 
 // Not allow mixed use of tile Ops and normal vector Ops, any mixing is
-// considered unacceptable
+// considered unacceptable.
 static bool isAcceptableTileOp(Operation *op) {
   if (!isTileOp(op))
     return false;
@@ -164,12 +163,12 @@ TileBindingAnalysis::TileBindingAnalysis(Operation *root) {
   });
 }
 
-// A class for analyzing tile configuration domination (a.k.a. tile scope)
+// A class for analyzing tile configuration domination (a.k.a. tile scope).
 class TileScopeAnalysis {
 private:
   typedef llvm::iterator_range<Block::iterator> BlockSeg;
-  // A list of 2-dim shapes representing tmm register shape, the length should
-  // always be 8
+  // A list of 2-dim {rows x colBytes} shapes representing tmm register shape,
+  // the length should always be 8.
   struct PaletteInfo {
     bool overflow;
     SmallVector<pair<int, int>, 8> palette;
@@ -186,7 +185,7 @@ class TileScopeAnalysis {
     bool isConflict(const PaletteInfo &rhs);
   };
   struct TileScope {
-    // The BlockSeg here is inclusive (containing `end` Op)
+    // The BlockSeg here is inclusive (containing `end` Op).
     BlockSeg seg;
     PaletteInfo pi;
     TileScope() { clear(); }
@@ -194,14 +193,14 @@ class TileScopeAnalysis {
   };
 
   bool isValidAnalysis;
-  // Storing parallel Ops that would break tile context & scope
+  // Storing parallel Ops that would break tile context & scope.
   DenseSet<Operation *> parallelOps;
-  // Storing needed palette info for each concerned Op
+  // Storing needed palette info for each concerned Op.
   DenseMap<Operation *, PaletteInfo> neededPalette;
-  // Storing the usage scope for each concerned tile Op
-  // The BlockSeg here is inclusive (containing `end` Op)
+  // Storing the usage scope for each concerned tile Op.
+  // The BlockSeg here is inclusive (containing `end` Op).
   DenseMap<Operation *, BlockSeg> tileUsage;
-  // Storing final tile scope results for injecting tilecfg/tilerelease
+  // Storing final tile scope results for injecting tilecfg/tilerelease.
   SmallVector<TileScope, 10> tileScopes;
 
   void addParallelOp(Operation *op) { parallelOps.insert(op); }
@@ -214,14 +213,14 @@ class TileScopeAnalysis {
   PaletteInfo collectRegionPalette(Region &region);
   PaletteInfo collectPalette(Operation *op);
   // Below two functions are the leaf functinos of recursive collection, will
-  // actually insert PaletteInfo into map storage
+  // actually insert PaletteInfo into map storage.
   PaletteInfo collectPaletteForScf(Operation *op);
   PaletteInfo collectPaletteForTile(Operation *op);
   std::optional<PaletteInfo> getPalette(Operation *op);
   std::optional<PaletteInfo> getPalette(BlockSeg seg);
 
   void doTileScope(Block &block);
-  // The input BlockSeg here is exclusive (not containing `end` Op)
+  // The input BlockSeg here is exclusive (not containing `end` Op).
   void doTileScope(BlockSeg seg);
   void doTileScope(Operation *op);
 
@@ -269,7 +268,7 @@ bool TileScopeAnalysis::PaletteInfo::isConflict(const PaletteInfo &rhs) {
   }
 }
 
-// Currently we only operate on scf Ops
+// Currently we only operate on scf Ops.
 static bool isConcernedControlFlowOp(Operation *op) {
   return llvm::isa<scf::ExecuteRegionOp>(op) || llvm::isa<scf::ForOp>(op) ||
          llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::IfOp>(op) ||
@@ -290,7 +289,7 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
     return;
 
   isValidAnalysis = true;
-  // 0. First walk to mark parallel Ops
+  // 0. First walk to mark parallel Ops.
   func->walk<WalkOrder::PostOrder>([this](Operation *op) {
     if (!isParallelOp(op))
       return;
@@ -303,10 +302,10 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
     }
   });
 
-  // 1. Second walk to collect needed palette for each concerned Op
+  // 1. Second walk to collect needed palette for each concerned Op.
   collectNeededPalette(root);
 
-  // 2. Third walk to analyse usage scope for each tile Op
+  // 2. Third walk to analyse usage scope for each tile Op.
   func->walk<WalkOrder::PreOrder>([this](Operation *op) {
     if (!isValidAnalysis)
       return;
@@ -325,7 +324,7 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
   if (!isValidAnalysis)
     return;
 
-  // 3. Tile scoping for each segmented region in a recursive manner
+  // 3. Tile scoping for each segmented region in a recursive manner.
   doTileScope(func.getRegion(0).front());
 }
 
@@ -340,14 +339,14 @@ PaletteInfo TileScopeAnalysis::collectPalette(Operation *op) {
   if (!isValidAnalysis)
     return PaletteInfo();
   if (auto func = dyn_cast_or_null<func::FuncOp>(root))
-    // No need to store PaletteInfo for func
+    // No need to store PaletteInfo for func.
     return collectRegionPalette(func.getRegion(0).front());
 
   auto iter = neededPalette.find(op);
   if (iter != neededPalette.end())
     return iter->second;
 
-  // For now, we only concern certain control flow Ops and tile Ops
+  // For now, we only concern certain control flow Ops and tile Ops.
   if (isConcernedControlFlowOp(op))
     return collectPaletteForScf(op);
   if (isTileOp(op))
@@ -396,7 +395,7 @@ static inline pair<int, int> getPaletteShape(VectorType type) {
   else
     assert(false && "Invalid type for palette");
 
-  // Palette shape is { rows, colBytes }
+  // Palette shape is { rows, colBytes }.
   return {shape[0], shape[1] * typeSize};
 }
 
@@ -493,11 +492,11 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
     return;
   }
 
-  // Do tile scope on parallel-free BlockSeg
+  // Do tile scope on parallel-free BlockSeg.
   TileScope currScope;
   std::optional<Block::iterator> currSegStart;
   Block::iterator currIter = seg.begin();
-  // Traverse BlockSeg and greedily do tile scoping without look ahead
+  // Traverse BlockSeg and greedily do tile scoping without look ahead.
   while (currIter != seg.end()) {
     Operation *currOp = &(*currIter);
     if (!currSegStart)
@@ -518,7 +517,7 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
       pi = getPalette(iter->second);
       if (pi && pi->overflow) {
         // This means the binding info in tile Ops exceeds the hardware
-        // capability
+        // capability.
         isValidAnalysis = false;
         return;
       }
@@ -541,7 +540,7 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
   }
 
     if (pi->overflow) {
-      // Only scf Ops could go through this possibility
+      // Only scf Ops could go through this possibility.
       TRY_ADD_PREVIOUS_SCOPE();
       doTileScope(currOp);
       currIter++;
@@ -563,12 +562,12 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
 
 void TileScopeAnalysis::doTileScope(Operation *op) {
   // This func try to collect tile scopes for a single control flow Op
-  // This func is not for tile Ops
+  // This func is not for tile Ops.
   if (isTileOp(op))
     return;
   // Ops that invoke this func are either parallelOps or scfOps with overflowed
   // paletteInfo, and neither of them can form a tile scope by itself, so we
-  // omit checking self-formed tile scope in this func
+  // omit checking self-formed tile scope in this func.
   if (llvm::isa<scf::ExecuteRegionOp>(op) || llvm::isa<scf::ForOp>(op) ||
       llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op)) {
     auto &block = op->getRegion(0).front();
@@ -698,6 +697,14 @@ class TileMulIBindingRewriter : public OpRewritePattern<TileMulIOp> {
   }
 };
 
+static inline void uint8ArrayToHex(std::string &out, uint8_t array[],
+                                   int size) {
+  llvm::raw_string_ostream os(out);
+  for (int index = 0; index < size; index++) {
+    os << format_hex_no_prefix(array[index], 2, true);
+  }
+}
+
 struct EnableAMXTileBindingPass
     : public impl::EnableAMXTileBindingBase<EnableAMXTileBindingPass> {
 private:
@@ -725,22 +732,70 @@ struct EnableAMXTileBindingPass
     return isViable;
   }
 
-  LLVM::GlobalOp getOrCreateGlobalPalette(const PaletteInfo &pi) {}
+  LLVM::GlobalOp getOrCreateGlobalPalette(const PaletteInfo &pi) {
+    assert(!pi.overflow && "Expecting valid palette");
+// Pack struct so it can fit into a single 64-byte cache line.
+#pragma pack(push, 1)
+    struct {
+      uint8_t paletteId;
+      uint8_t startRow;
+      uint8_t reserved[14];
+      uint16_t cols[16];
+      uint8_t rows[16];
+    } paletteConfig;
+#pragma pack(pop)
+
+    size_t paletteArraySize = 64;
+    uint8_t *paletteAsArray = &paletteConfig;
+    memset(paletteAsArray, 0x0, paletteArraySize);
+    // Intel AMX: The only legal non-INIT value for palette_id is 1.
+    // TODO(haixin): fetch from CPUID ?
+    paletteConfig.paletteId = 1;
+    for (int index = 0; index < 8; index++) {
+      const auto &regShape = pi.palette[index];
+      paletteConfig.rows[index] = regShape.first;
+      paletteConfig.cols[index] = regShape.second;
+    }
+
+    std::string paletteSymName = "g_intel_amx_palette_";
+    uintArrayToHex(paletteSymName, paletteAsArray, paletteArraySize);
+
+    if ((global = module.lookupSymbol<LLVM::GlobalOp>(paletteSymName)))
+      return global;
+    // Create a global symbol containing palette config.
+    ModuleOp moduleOp = getOperation()->template getParentOfType<ModuleOp>();
+    OpBuilder builder(moduleOp);
+    builder.setInsertionPointToStart(moduleOp.getBody());
+
+    SmallVector<uint8_t> elementVals;
+    for (size_t index = 0; index < paletteArraySize; index++)
+      elementVals.push_back(paletteAsArray[index]);
+    auto dataAttrType = RankedTensorType::get(
+        {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
+    auto dataAttr =
+        DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
+    auto arrayTy =
+        LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
+    auto global = builder.create<LLVM::GlobalOp>(
+        moduleOp.getLoc(), arrayType, /*isConstant*/ true,
+        LLVM::Linkage::Private, paletteSymName, dataAttr, /*alignment=*/64);
+    return global;
+  }
 
 public:
   void runOnOperation() override {
     // Ensure that tile Ops are not wrapped by out-of-scope Ops, else cannot do
-    // enabling
+    // enabling.
     if (!isViableTileOps())
       return;
 
     // 0. Get AnalyseInfo for each concerned Value (Does not allow mixed used of
-    // tmul & normal vector operations)
+    // tmul & normal vector operations).
     TileBindingAnalysis &bindingAna = getAnalysis<TileBindingAnalysis>();
     if (!bindingAna.isValid())
       return;
 
-    // 1. Set propagated binding info to AMX Ops
+    // 1. Set propagated binding info to AMX Ops.
     RewritePatternSet patterns(&getContext());
     patterns.add<TileStoreBindingRewriter>(&getContext(), analysis);
     patterns.add<TileMulFBindingRewriter>(&getContext(), analysis);
@@ -750,12 +805,12 @@ struct EnableAMXTileBindingPass
     if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
       return;
 
-    // 2. Analyse tile scopes & expand them maximally
+    // 2. Analyse tile scopes & expand them maximally.
     TileScopeAnalysis &scopeAna = getAnalysis<TileScopeAnalysis>();
     if (!scopeAna.isValid())
       return;
 
-    // 3. Insert tile config/release according to tile scopes
+    // 3. Insert tile config/release according to tile scopes.
     OpBuilder builder(getOperation());
     for (auto &scope : tileScopes) {
       assert(!scope.pi.overflow && "Expecting legal AMX palette info");
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 5e2aa216cf4123..154987303be673 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -74,10 +74,28 @@ Value getStride(ConversionPatternRewriter &rewriter,
 }
 
 struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
+private:
+  const std::optional<TileScopeAnalysis> &enablingAnalysis;
+
+public:
   using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern;
+  TileZeroConversion(const LLVMTypeConverter &typeConverter,
+                     const std::optional<TileScopeAnalysis> &analysis)
+      : ConvertOpToLLVMPattern<TileZeroOp>(typeConverter),
+        enablingAnalysis(analysis) {}
+
   LogicalResult
   matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (enablingAnalysis && enablingAnalysis->isValid()) {
+      // Routine for lowering tile Ops with binding info.
+      auto dstRegIndex = op.getDstRegIndex();
+      assert(dstRegIndex && "Incomplete operation attribute for tile binding");
+      rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero_plain>(op,
+                                                               *dstRegIndex);
+      return success();
+    }
+
     VectorType vType = op.getVectorType();
     // Determine m x n tile sizes.
     std::pair<Value, Value> tsz =
@@ -91,24 +109,42 @@ struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
 };
 
 struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
+private:
+  const std::optional<TileScopeAnalysis> &enablingAnalysis;
+
+public:
   using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern;
+  TileLoadConversion(const LLVMTypeConverter &typeConverter,
+                     const std::optional<TileScopeAnalysis> &analysis)
+      : ConvertOpToLLVMPattern<TileLoadOp>(typeConverter),
+        enablingAnalysis(analysis) {}
 
   LogicalResult
   matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MemRefType mType = op.getMemRefType();
-    VectorType vType = op.getVectorType();
-    // Determine m x n tile sizes.
-    std::pair<Value, Value> tsz =
-        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
     // Determine stride.
     if (failed(verifyStride(mType)))
       return failure();
     Value stride = getStride(rewriter, *getTypeConverter(), mType,
                              adaptor.getBase(), op.getLoc());
-    // Replace operation with intrinsic.
     Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
+
+    if (enablingAnalysis && enablingAnalysis->isValid()) {
+      // Routine for lowering tile Ops with binding info.
+      auto dstRegIndex = op.getDstRegIndex();
+      assert(dstRegIndex && "Incomplete operation attribute for tile binding");
+      rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64_plain>(
+          op, ptr, stride, *dstRegIndex);
+      return success();
+    }
+
+    VectorType vType = op.getVectorType();
+    // Determine m x n tile sizes.
+    std::pair<Value, Value> tsz =
+        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+    // Replace operation with intrinsic.
     Type resType = typeConverter->convertType(vType);
     rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
         op, resType, tsz.first, tsz.second, ptr, stride);
@@ -117,24 +153,42 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
 };
 
 struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
+private:
+  const std::optional<TileScopeAnalysis> &enablingAnalysis;
+
+public:
   using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;
+  TileStoreConversion(const LLVMTypeConverter &typeConverter,
+                      const std::optional<TileScopeAnalysis> &analysis)
+      : ConvertOpToLLVMPattern<TileStoreOp>(typeConverter),
+        enablingAnalysis(analysis) {}
 
   LogicalResult
   matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MemRefType mType = op.getMemRefType();
-    VectorType vType = op.getVectorType();
-    // Determine m x n tile sizes.
-    std::pair<Value, Value> tsz =
-        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
     // Determine stride.
     if (failed(verifyStride(mType)))
       return failure();
     Value stride = getStride(rewriter, *getTypeConverter(), mType,
                              adaptor.getBase(), op.getLoc());
-    // Replace operation with intrinsic.
     Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
+
+    if (enablingAnalysis && enablingAnalysis->isValid()) {
+      // Routine for lowering tile Ops with binding info.
+      auto srcRegIndex = op.getSrcRegIndex();
+      assert(srcRegIndex && "Incomplete operation attribute for tile binding");
+      rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64_plain>(
+          op, ptr, stride, *srcRegIndex);
+      return success();
+    }
+
+    VectorType vType = op.getVectorType();
+    // Determine m x n tile sizes.
+    std::pair<Value, Value> tsz =
+        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+    // Replace operation with intrinsic.
     rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
         op, tsz.first, tsz.second, ptr, stride, adaptor.getVal());
     return success();
@@ -142,10 +196,32 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
 };
 
 struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
+private:
+  const std::optional<TileScopeAnalysis> &enablingAnalysis;
+
+public:
   using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern;
+  TileMulFConversion(const LLVMTypeConverter &typeConverter,
+                     const std::optional<TileScopeAnalysis> &analysis)
+      : ConvertOpToLLVMPattern<TileMulFOp>(typeConverter),
+        enablingAnalysis(analysis) {}
+
   LogicalResult
   matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    if (enablingAnalysis && enablingAnalysis->isValid()) {
+      // Routine for lowering tile Ops with binding info.
+      auto lhsRegIndex = op.getSrcRegIndex();
+      auto rhsRegIndex = op.getRhsRegIndex();
+      auto accRegIndex = op.getAccRegIndex();
+
+      assert(lhsRegIndex && rhsRegIndex && accRegIndex &&
+             "Incomplete operation attribute for tile binding");
+      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps_plain>(
+          op, *accRegIndex, *lhsRegIndex, *rhsRegIndex);
+      return success();
+    }
+
     VectorType aType = op.getLhsVectorType();
     VectorType bType = op.getRhsVectorType();
     VectorType cType = op.getVectorType();
@@ -164,10 +240,45 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
 };
 
 struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
+private:
+  const std::optional<TileScopeAnalysis> &enablingAnalysis;
+
+public:
   using ConvertOpToLLVMPattern<TileMulIOp>::ConvertOpToLLVMPattern;
+  TileMulIConversion(const LLVMTypeConverter &typeConverter,
+                     const std::optional<TileScopeAnalysis> &analysis)
+      : ConvertOpToLLVMPattern<TileMulIOp>(typeConverter),
+        enablingAnalysis(analysis) {}
+
   LogicalResult
   matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    bool zexta = op.getIsZextLhs();
+    bool zextb = op.getIsZextRhs();
+
+    if (enablingAnalysis && enablingAnalysis->isValid()) {
+      // Routine for lowering tile Ops with binding info.
+      auto lhsRegIndex = op.getSrcRegIndex();
+      auto rhsRegIndex = op.getRhsRegIndex();
+      auto accRegIndex = op.getAccRegIndex();
+
+      assert(lhsRegIndex && rhsRegIndex && accRegIndex &&
+             "Incomplete operation attribute for tile binding");
+      if (zexta && zextb)
+        rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud_plain>(
+            op, *accRegIndex, *lhsRegIndex, *rhsRegIndex);
+      else if (zexta && !zextb)
+        rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbusd_plain>(
+            op, *accRegIndex, *lhsRegIndex, *rhsRegIndex);
+      else if (!zexta && zextb)
+        rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbsud_plain>(
+            op, *accRegIndex, *lhsRegIndex, *rhsRegIndex);
+      else
+        rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbssd_plain>(
+            op, *accRegIndex, *lhsRegIndex, *rhsRegIndex);
+      return success();
+    }
+
     VectorType aType = op.getLhsVectorType();
     VectorType bType = op.getRhsVectorType();
     VectorType cType = op.getVectorType();
@@ -178,8 +289,6 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
         getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
     // Replace operation with intrinsic.
     Type resType = typeConverter->convertType(cType);
-    bool zexta = op.getIsZextLhs();
-    bool zextb = op.getIsZextRhs();
     if (zexta && zextb)
       rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>(
           op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
@@ -203,9 +312,10 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
 } // namespace
 
 void mlir::populateAMXLegalizeForLLVMExportPatterns(
-    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+    LLVMTypeConverter &converter, std::optional<TileScopeAnalysis> &analysis,
+    RewritePatternSet &patterns) {
   patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
-               TileMulFConversion, TileMulIConversion>(converter);
+               TileMulFConversion, TileMulIConversion>(converter, analysis);
 }
 
 void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {

>From eb7672581ce234d62568587b9b16a87a982ebf1b Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Tue, 2 Jul 2024 01:30:43 -0700
Subject: [PATCH 08/17] rearrange code structure

---
 .../Dialect/AMX/Analysis/AMXBindingAnalysis.h | 121 ++++
 .../AMX/Analysis/AMXBindingAnalysis.cpp       | 487 +++++++++++++++
 mlir/lib/Dialect/AMX/Analysis/CMakeLists.txt  |  10 +
 mlir/lib/Dialect/AMX/CMakeLists.txt           |   1 +
 .../lib/Dialect/AMX/Transforms/CMakeLists.txt |   1 +
 .../AMX/Transforms/EnableAMXTileBinding.cpp   | 562 +-----------------
 .../AMX/Transforms/LegalizeForLLVMExport.cpp  |   1 +
 7 files changed, 623 insertions(+), 560 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
 create mode 100644 mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
 create mode 100644 mlir/lib/Dialect/AMX/Analysis/CMakeLists.txt

diff --git a/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h b/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
new file mode 100644
index 00000000000000..05ff7f3d4a74cb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
@@ -0,0 +1,121 @@
+//===- AMXBindingAnalysis.h - Analyse AMX Binding Info --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares two analyses:
+// 1. TileBindingAnalysis: verifies and propagates pre-set tile register
+// binding info for `vector`s used by tile operations.
+// 2. TileScopeAnalysis: verifies and finds proper tile configuration
+// domination scopes for tile operations w.r.t. correctness and performance.
+// These analyses are invoked by the `--enable-amx-tile-binding` pass and are
+// used as an indicator of tile binding info validity in the
+// `--convert-vector-to-llvm` pass.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AMX_ANALYSIS_AMXBINDINGANALYSIS_H_
+#define MLIR_DIALECT_AMX_ANALYSIS_AMXBINDINGANALYSIS_H_
+
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+namespace amx {
+
+/// A class for analyzing (propagating) tile register binding for each tile
+/// vector.
+class TileBindingAnalysis {
+private:
+  bool isValidAnalysis;
+  DenseMap<Value, int> bindings;
+
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileBindingAnalysis)
+  explicit TileBindingAnalysis(Operation *);
+  bool isValid() const { return isValidAnalysis; }
+  int getBinding(Value val) const {
+    auto iter = bindings.find(val);
+    if (iter == bindings.end())
+      return -1;
+    return iter->second;
+  }
+  void setBinding(Value val, int index) { bindings[val] = index; }
+};
+
+// A class for analyzing tile configuration domination (a.k.a. tile scope).
+class TileScopeAnalysis {
+private:
+  typedef llvm::iterator_range<Block::iterator> BlockSeg;
+  // A list of 2-dim {rows x colBytes} shapes representing tmm register shape,
+  // the length should always be 8.
+  struct PaletteInfo {
+    bool overflow;
+    SmallVector<pair<int, int>, 8> palette;
+    PaletteInfo() {
+      palette.resize(8, {0, 0});
+      clear();
+    }
+    void clear();
+    bool isEmpty(int idx) {
+      return palette[idx].first == 0 && palette[idx].second == 0;
+    }
+    void set(int idx, pair<int, int> shape) { palette[idx] = shape; }
+    void merge(const PaletteInfo &rhs);
+    bool isConflict(const PaletteInfo &rhs);
+  };
+  struct TileScope {
+    // The BlockSeg here is inclusive (containing `end` Op).
+    BlockSeg seg;
+    PaletteInfo pi;
+    TileScope() { clear(); }
+    void clear() { pi.clear(); }
+  };
+
+  bool isValidAnalysis;
+  // Storing parallel Ops that would break tile context & scope.
+  DenseSet<Operation *> parallelOps;
+  // Storing needed palette info for each concerned Op.
+  DenseMap<Operation *, PaletteInfo> neededPalette;
+  // Storing the usage scope for each concerned tile Op.
+  // The BlockSeg here is inclusive (containing `end` Op).
+  DenseMap<Operation *, BlockSeg> tileUsage;
+  // Storing final tile scope results for injecting tilecfg/tilerelease.
+  SmallVector<TileScope, 10> tileScopes;
+
+  void addParallelOp(Operation *op) { parallelOps.insert(op); }
+  bool isParallelOp(Operation *op) {
+    return parallelOps.find(op) == parallelOps.end();
+  }
+
+  void setTileUsage(Operation *op, BlockSeg seg) { tileUsage[op] = seg; }
+
+  PaletteInfo collectRegionPalette(Region &region);
+  PaletteInfo collectPalette(Operation *op);
+  // The two functions below are the leaf functions of the recursive
+  // collection; they actually insert PaletteInfo into the map storage.
+  PaletteInfo collectPaletteForScf(Operation *op);
+  PaletteInfo collectPaletteForTile(Operation *op);
+  std::optional<PaletteInfo> getPalette(Operation *op);
+  std::optional<PaletteInfo> getPalette(BlockSeg seg);
+
+  void doTileScope(Block &block);
+  // The input BlockSeg here is exclusive (not containing `end` Op).
+  void doTileScope(BlockSeg seg);
+  void doTileScope(Operation *op);
+
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileScopeAnalysis)
+  explicit TileScopeAnalysis(Operation *);
+  bool isValid() const { return isValidAnalysis; }
+};
+
+} // namespace amx
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AMX_ANALYSIS_AMXBINDINGANALYSIS_H_
diff --git a/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
new file mode 100644
index 00000000000000..8ed992ac498fd9
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
@@ -0,0 +1,487 @@
+//===- AMXBindingAnalysis.cpp - Binding info analysis for Intel AMX -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the implementations of two analyses:
+// 1. TileBindingAnalysis
+// 2. TileScopeAnalysis
+// For more details, please refer to the definitions in the header.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h"
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/BuiltinOps.h"
+
+#include <cassert>
+#include <optional>
+#include <utility>
+
+#define DEBUG_TYPE "amx-binding-analysis"
+
+using namespace mlir;
+using namespace mlir::amx;
+
+// Returns true if `op` is one of the AMX tile Ops handled by the analyses in
+// this file. A null `op` (for example the result of getDefiningOp() on a
+// block argument) is reported as "not a tile Op" instead of tripping the
+// non-null assertion inside llvm::isa.
+static bool isTileOp(Operation *op) {
+  return op &&
+         llvm::isa<TileZeroOp, TileLoadOp, TileMulFOp, TileMulIOp, TileStoreOp>(
+             op);
+}
+
+// Returns true when the lhs, rhs and acc operands of the tile multiplication
+// `op` (which must be an `Op`) are all produced directly by tile Ops.
+template <typename Op>
+static bool TileMulCheck(Operation *op) {
+  auto tileMul = dyn_cast_or_null<Op>(op);
+  assert(tileMul && "expected a tile multiplication Op");
+
+  // getDefiningOp() yields null for block arguments; such operands do not
+  // come from a tile Op, and null must not be handed to isTileOp().
+  auto isProducedByTileOp = [](Value operand) {
+    Operation *producer = operand.getDefiningOp();
+    return producer && isTileOp(producer);
+  };
+  return isProducedByTileOp(tileMul.getLhs()) &&
+         isProducedByTileOp(tileMul.getRhs()) &&
+         isProducedByTileOp(tileMul.getAcc());
+}
+
+// Mixing tile Ops with ordinary vector Ops is not supported: a tile Op is
+// acceptable only if every tile operand it consumes is produced directly by
+// another tile Op. Non-tile Ops are never acceptable.
+static bool isAcceptableTileOp(Operation *op) {
+  if (!isTileOp(op))
+    return false;
+
+  if (llvm::isa<TileMulFOp>(op))
+    return TileMulCheck<TileMulFOp>(op);
+  if (llvm::isa<TileMulIOp>(op))
+    return TileMulCheck<TileMulIOp>(op);
+  if (auto tileStore = dyn_cast_or_null<TileStoreOp>(op)) {
+    // A stored value that is a block argument has no defining Op; reject it
+    // here rather than passing null on to isTileOp().
+    Operation *producer = tileStore.getVal().getDefiningOp();
+    return producer && isTileOp(producer);
+  }
+  return true;
+}
+
+// Binds the result of `op` (an `Op` exposing getDstRegIndex() and getRes()) to
+// the register index carried by the Op. Returns false when no index is set.
+template <typename Op>
+static bool TileDstPropagate(TileBindingAnalysis *analysis, Operation *op) {
+  auto dstOp = dyn_cast_or_null<Op>(op);
+  assert(dstOp);
+  std::optional<int8_t> regIndex = dstOp.getDstRegIndex();
+  if (regIndex) {
+    analysis->setBinding(dstOp.getRes(), *regIndex);
+    return true;
+  }
+  return false;
+}
+
+// Gives the result of the tile multiplication `op` the same binding as its
+// accumulator operand. Returns false when getBinding() reports a negative
+// index for the accumulator, in which case nothing is recorded.
+template <typename Op>
+static bool TileMulPropagate(TileBindingAnalysis *analysis, Operation *op) {
+  auto mulOp = dyn_cast_or_null<Op>(op);
+  assert(mulOp);
+  auto accBinding = analysis->getBinding(mulOp.getAcc());
+  const bool isBound = accBinding >= 0;
+  if (isBound)
+    analysis->setBinding(mulOp.getRes(), accBinding);
+  return isBound;
+}
+
+// Walks the function `root` in pre-order and propagates a register binding to
+// the result of every tile Op. The analysis is marked invalid when `root` is
+// not a func::FuncOp, when a tile Op is not acceptable, or when a binding
+// cannot be propagated.
+TileBindingAnalysis::TileBindingAnalysis(Operation *root) {
+  func::FuncOp func = dyn_cast_or_null<func::FuncOp>(root);
+  isValidAnalysis = static_cast<bool>(func);
+  if (!isValidAnalysis)
+    return;
+
+  func->walk<WalkOrder::PreOrder>([this](Operation *op) {
+    // Once the analysis has failed, the remaining Ops are ignored.
+    if (!isValidAnalysis || !isTileOp(op))
+      return;
+
+    bool succeeded = isAcceptableTileOp(op);
+    if (succeeded) {
+      // Tile Ops that match none of the cases below need no propagation.
+      if (llvm::isa<TileZeroOp>(op))
+        succeeded = TileDstPropagate<TileZeroOp>(this, op);
+      else if (llvm::isa<TileLoadOp>(op))
+        succeeded = TileDstPropagate<TileLoadOp>(this, op);
+      else if (llvm::isa<TileMulFOp>(op))
+        succeeded = TileMulPropagate<TileMulFOp>(this, op);
+      else if (llvm::isa<TileMulIOp>(op))
+        succeeded = TileMulPropagate<TileMulIOp>(this, op);
+    }
+    if (!succeeded)
+      isValidAnalysis = false;
+  });
+}
+
+// Resets the palette: clears the overflow flag and sets all 8 entries back to
+// {0, 0}.
+void TileScopeAnalysis::PaletteInfo::clear() {
+  for (int regIdx = 0; regIdx != 8; ++regIdx)
+    palette[regIdx] = {0, 0};
+  overflow = false;
+}
+
+// Merges `rhs` into this palette. An overflow on either side makes the result
+// overflowed. Otherwise entries that are set only in `rhs` are copied over;
+// the first entry that is set on both sides with different shapes marks the
+// result as overflowed and stops the merge.
+void TileScopeAnalysis::PaletteInfo::merge(const PaletteInfo &rhs) {
+  overflow = overflow || rhs.overflow;
+  if (overflow)
+    return;
+
+  for (int regIdx = 0; regIdx != 8; ++regIdx) {
+    // Nothing to take over from an entry that `rhs` does not use.
+    if (rhs.isEmpty(regIdx))
+      continue;
+    if (isEmpty(regIdx)) {
+      palette[regIdx] = rhs.palette[regIdx];
+      continue;
+    }
+    const bool sameShape =
+        palette[regIdx].first == rhs.palette[regIdx].first &&
+        palette[regIdx].second == rhs.palette[regIdx].second;
+    if (!sameShape) {
+      overflow = true;
+      return;
+    }
+  }
+}
+
+// Returns true if `rhs` cannot share a tile configuration with this palette:
+// either side has overflowed, or some register index is set on both sides
+// with different shapes. Returns false otherwise.
+bool TileScopeAnalysis::PaletteInfo::isConflict(const PaletteInfo &rhs) {
+  if (overflow || rhs.overflow)
+    return true;
+  for (int idx = 0; idx < 8; idx++) {
+    // Only entries used by both palettes can disagree.
+    if (isEmpty(idx) || rhs.isEmpty(idx))
+      continue;
+    if (palette[idx].first != rhs.palette[idx].first ||
+        palette[idx].second != rhs.palette[idx].second)
+      return true;
+  }
+  // No conflicting entry was found. Without this return the function would
+  // flow off its end, which is undefined behavior for a non-void function.
+  return false;
+}
+
+// Only these scf control flow Ops are handled for now.
+static bool isConcernedControlFlowOp(Operation *op) {
+  return llvm::isa<scf::ExecuteRegionOp, scf::ForOp, scf::ForallOp, scf::IfOp,
+                   scf::IndexSwitchOp, scf::ParallelOp, scf::WhileOp>(op);
+}
+
+// isTileOp() is already defined earlier in this file; repeating the definition
+// here would be a redefinition error, so the existing one is reused.
+
+// Builds the tile scopes of the function `root`:
+//   0. marks scf.forall / scf.parallel Ops and their ancestors as parallel,
+//   1. collects the palette needed by every concerned Op,
+//   2. computes the usage segment of every tile Op,
+//   3. greedily groups the Ops of the function body into tile scopes.
+// The analysis is invalid if `root` is not a func::FuncOp or if a step fails.
+TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
+  isValidAnalysis = false;
+  func::FuncOp func = dyn_cast_or_null<func::FuncOp>(root);
+  if (!func)
+    return;
+
+  isValidAnalysis = true;
+  // A function without a body has nothing to scope.
+  if (func.getBody().empty())
+    return;
+
+  // 0. First walk to mark parallel Ops. `root` has to be captured explicitly
+  // because the lambda compares against it.
+  func->walk<WalkOrder::PostOrder>([this, root](Operation *op) {
+    if (!llvm::isa<scf::ForallOp, scf::ParallelOp>(op))
+      return;
+    // Mark the Op and every ancestor below `root`. The climb stops at an
+    // ancestor that is already in the set, because the climb that inserted
+    // it also inserted all of its ancestors.
+    while (op != root && !parallelOps.contains(op)) {
+      addParallelOp(op);
+      op = op->getParentOp();
+    }
+  });
+
+  // 1. Second walk to collect needed palette for each concerned Op.
+  collectPalette(root);
+  if (!isValidAnalysis)
+    return;
+
+  // 2. Third walk to analyse usage scope for each tile Op.
+  func->walk<WalkOrder::PreOrder>([this](Operation *op) {
+    if (!isValidAnalysis || !isTileOp(op))
+      return;
+    // The usage segment runs from `op` to the last Op of the same block that
+    // uses one of its results, directly or from inside a nested region. An Op
+    // without users forms a segment by itself.
+    Block *block = op->getBlock();
+    Operation *lastUser = op;
+    for (Operation *user : op->getUsers()) {
+      // getUsers() follows the use list, which is not in program order, so
+      // every user is hoisted into `block` and compared by position.
+      Operation *ancestor = block->findAncestorOpInBlock(*user);
+      if (!ancestor) {
+        // The user is not nested under the block that defines the tile.
+        isValidAnalysis = false;
+        return;
+      }
+      if (lastUser->isBeforeInBlock(ancestor))
+        lastUser = ancestor;
+    }
+    setTileUsage(op, BlockSeg(Block::iterator(op), Block::iterator(lastUser)));
+  });
+  if (!isValidAnalysis)
+    return;
+
+  // 3. Tile scoping for each segmented region in a recursive manner.
+  doTileScope(func.getBody().front());
+}
+
+// Returns the merged palette needed by all Ops of `block`. The return type is
+// spelled with its enclosing class because PaletteInfo is a nested type and
+// the class scope is not yet open at this point of the declaration.
+TileScopeAnalysis::PaletteInfo
+TileScopeAnalysis::collectRegionPalette(Block &block) {
+  PaletteInfo pi;
+  for (Operation &op : block)
+    pi.merge(collectPalette(&op));
+  return pi;
+}
+
+// Returns the palette needed by `op`, descending into nested regions where
+// necessary. Returns an empty palette when the analysis is already invalid,
+// when `op` is null, or when `op` is neither a function, a concerned control
+// flow Op nor a tile Op.
+TileScopeAnalysis::PaletteInfo
+TileScopeAnalysis::collectPalette(Operation *op) {
+  if (!isValidAnalysis || !op)
+    return PaletteInfo();
+  if (auto func = dyn_cast<func::FuncOp>(op)) {
+    // No need to store PaletteInfo for func. A function without a body needs
+    // no palette at all.
+    if (func.getBody().empty())
+      return PaletteInfo();
+    return collectRegionPalette(func.getBody().front());
+  }
+
+  // Reuse the result of an earlier collection.
+  auto iter = neededPalette.find(op);
+  if (iter != neededPalette.end())
+    return iter->second;
+
+  // For now, we only concern certain control flow Ops and tile Ops.
+  if (isConcernedControlFlowOp(op))
+    return collectPaletteForScf(op);
+  if (isTileOp(op))
+    return collectPaletteForTile(op);
+  return PaletteInfo();
+}
+
+// Computes the palette needed by the control flow Op `op` as the merge of the
+// palettes of its regions, stores it in `neededPalette` and returns it.
+// Returns an empty palette, without storing anything, for any other Op.
+TileScopeAnalysis::PaletteInfo
+TileScopeAnalysis::collectPaletteForScf(Operation *op) {
+  if (!isConcernedControlFlowOp(op))
+    return PaletteInfo();
+
+  PaletteInfo pi;
+  // Every concerned Op is handled the same way: scf.if contributes its then
+  // and else regions, scf.index_switch its default and case regions,
+  // scf.while its before and after regions, and the remaining Ops their
+  // single body region.
+  for (Region &region : op->getRegions()) {
+    // A region without blocks (such as an absent else branch) needs nothing.
+    if (region.empty())
+      continue;
+    // Only the first block of each region is inspected.
+    pi.merge(collectRegionPalette(region.front()));
+  }
+  neededPalette[op] = pi;
+  return pi;
+}
+
+// Converts the 2-D tile vector type `type` into its palette shape
+// {rows, colBytes}. The accepted element types are i8, bf16, i32 and f32; any
+// other element type triggers the assertion below.
+static inline std::pair<int, int> getPaletteShape(VectorType type) {
+  ArrayRef<int64_t> shape = type.getShape();
+  assert(shape.size() == 2 && "Tile vector must be 2-dimensional");
+  auto elementType = type.getElementType();
+  // Initialized so that a build without assertions never reads an
+  // indeterminate value on the invalid-type path.
+  int typeSize = 0;
+  if (elementType.isInteger(8))
+    typeSize = 1;
+  else if (elementType.isBF16())
+    typeSize = 2;
+  else if (elementType.isInteger(32) || elementType.isF32())
+    typeSize = 4;
+  else
+    assert(false && "Invalid type for palette");
+
+  // Palette shape is { rows, colBytes }. The casts are explicit because
+  // brace-initializing an int from an int64_t is a narrowing conversion.
+  return {static_cast<int>(shape[0]), static_cast<int>(shape[1] * typeSize)};
+}
+
+// Computes the palette needed by the tile Op `op` from the register indices
+// it carries, stores it in `neededPalette` and returns it. If a register
+// index is missing, the analysis is marked invalid and an empty palette is
+// returned without storing anything. Non-tile Ops yield an empty palette.
+TileScopeAnalysis::PaletteInfo
+TileScopeAnalysis::collectPaletteForTile(Operation *op) {
+  if (!isTileOp(op))
+    return PaletteInfo();
+
+  PaletteInfo pi;
+  // Records the shape of one register; fails when the index is absent.
+  auto bindRegister = [&pi](auto regIndex, VectorType type) {
+    if (!regIndex)
+      return false;
+    pi.set(*regIndex, getPaletteShape(type));
+    return true;
+  };
+  // Records lhs, rhs and acc of a tile multiplication. All three indices are
+  // read from `mulOp` itself and checked before anything is recorded.
+  auto bindMul = [&pi](auto mulOp) {
+    auto lhsIndex = mulOp.getLhsRegIndex();
+    auto rhsIndex = mulOp.getRhsRegIndex();
+    auto accIndex = mulOp.getAccRegIndex();
+    if (!lhsIndex || !rhsIndex || !accIndex)
+      return false;
+    pi.set(*lhsIndex, getPaletteShape(mulOp.getLhsVectorType()));
+    pi.set(*rhsIndex, getPaletteShape(mulOp.getRhsVectorType()));
+    pi.set(*accIndex, getPaletteShape(mulOp.getAccVectorType()));
+    return true;
+  };
+
+  bool hasAllIndices = true;
+  if (auto tileLoadOp = dyn_cast<TileLoadOp>(op))
+    hasAllIndices =
+        bindRegister(tileLoadOp.getDstRegIndex(), tileLoadOp.getVectorType());
+  else if (auto tileMulFOp = dyn_cast<TileMulFOp>(op))
+    hasAllIndices = bindMul(tileMulFOp);
+  else if (auto tileMulIOp = dyn_cast<TileMulIOp>(op))
+    hasAllIndices = bindMul(tileMulIOp);
+  else if (auto tileStoreOp = dyn_cast<TileStoreOp>(op))
+    hasAllIndices =
+        bindRegister(tileStoreOp.getSrcRegIndex(), tileStoreOp.getVectorType());
+  else if (auto tileZeroOp = dyn_cast<TileZeroOp>(op))
+    hasAllIndices =
+        bindRegister(tileZeroOp.getDstRegIndex(), tileZeroOp.getVectorType());
+
+  if (!hasAllIndices) {
+    isValidAnalysis = false;
+    return PaletteInfo();
+  }
+  neededPalette[op] = pi;
+  return pi;
+}
+
+// Looks up the palette previously stored for `op` in `neededPalette`.
+// Returns std::nullopt if nothing was stored for it.
+std::optional<TileScopeAnalysis::PaletteInfo>
+TileScopeAnalysis::getPalette(Operation *op) {
+  auto iter = neededPalette.find(op);
+  if (iter == neededPalette.end())
+    return std::nullopt;
+  return iter->second;
+}
+
+// Merges the stored palettes of all Ops of `seg`. Returns std::nullopt if none
+// of them has a stored palette.
+// `seg` is treated as inclusive, i.e. the Op at `seg.end()` is part of it;
+// the caller must therefore pass a segment whose end is a real Op.
+std::optional<TileScopeAnalysis::PaletteInfo>
+TileScopeAnalysis::getPalette(BlockSeg seg) {
+  std::optional<PaletteInfo> merged;
+  // The loop exits after visiting `seg.end()` rather than in front of it.
+  for (Block::iterator it = seg.begin();; ++it) {
+    if (std::optional<PaletteInfo> opPalette = getPalette(&*it)) {
+      if (!merged)
+        merged.emplace();
+      merged->merge(*opPalette);
+    }
+    if (it == seg.end())
+      break;
+  }
+  return merged;
+}
+
+// Runs tile scoping over all Ops of `block`.
+void TileScopeAnalysis::doTileScope(Block &block) {
+  BlockSeg wholeBlock(block.begin(), block.end());
+  doTileScope(wholeBlock);
+}
+
+// Greedily partitions the exclusive segment `seg` (not containing its `end`
+// Op) into tile scopes and appends them to `tileScopes`.
+//  - Ops recorded in `parallelOps` split the segment; the pieces and the
+//    parallel Ops themselves are handled recursively.
+//  - In a parallel-free segment, consecutive Ops are merged into the current
+//    scope as long as their palettes do not conflict.
+// The analysis is marked invalid when a tile Op has no usage segment or when
+// the palette of its usage segment overflows.
+void TileScopeAnalysis::doTileScope(BlockSeg seg) {
+  if (!isValidAnalysis)
+    return;
+  if (seg.empty())
+    return;
+  SmallVector<BlockSeg, 3> blockSegs;
+  SmallVector<Operation *, 3> paraOps;
+  auto currBegin = seg.begin();
+  for (auto probe = seg.begin(); probe != seg.end(); probe++) {
+    Operation *op = &(*probe);
+    // Membership in `parallelOps` is what decides the split, so the set is
+    // queried directly.
+    if (parallelOps.contains(op)) {
+      blockSegs.push_back(BlockSeg(currBegin, probe));
+      paraOps.push_back(op);
+      currBegin = probe;
+      currBegin++;
+    }
+  }
+  if (!paraOps.empty()) {
+    assert(blockSegs.size() == paraOps.size());
+    for (size_t idx = 0; idx < paraOps.size(); idx++) {
+      doTileScope(blockSegs[idx]);
+      doTileScope(paraOps[idx]);
+    }
+    doTileScope(BlockSeg(currBegin, seg.end()));
+    return;
+  }
+
+  // Do tile scope on parallel-free BlockSeg.
+  TileScope currScope;
+  std::optional<Block::iterator> currSegStart;
+  Block::iterator currIter = seg.begin();
+
+  // Closes the scope accumulated so far, if it covers at least one Op in
+  // front of `currIter`, and starts over with an empty scope. The stored
+  // segment is inclusive, hence the step back from `currIter`.
+  // NOTE(review): a scope is emitted even when its palette is still empty
+  // (e.g. a segment without any tile Op) - confirm consumers expect that.
+  auto tryAddPreviousScope = [&]() {
+    if (currSegStart && *currSegStart != currIter) {
+      auto prevIter = currIter;
+      prevIter--;
+      currScope.seg = BlockSeg(*currSegStart, prevIter);
+      tileScopes.push_back(currScope);
+      currScope.clear();
+      currSegStart = std::nullopt;
+    }
+  };
+
+  // Traverse BlockSeg and greedily do tile scoping without look ahead.
+  while (currIter != seg.end()) {
+    Operation *currOp = &(*currIter);
+    if (!currSegStart)
+      currSegStart = currIter;
+
+    Block::iterator nextIterIfMerge;
+    std::optional<PaletteInfo> pi;
+    if (isConcernedControlFlowOp(currOp)) {
+      pi = getPalette(currOp);
+      nextIterIfMerge = currIter;
+      nextIterIfMerge++;
+    } else if (isTileOp(currOp)) {
+      auto iter = tileUsage.find(currOp);
+      if (iter == tileUsage.end()) {
+        isValidAnalysis = false;
+        return;
+      }
+      pi = getPalette(iter->second);
+      if (pi && pi->overflow) {
+        // This means the binding info in tile Ops exceeds the hardware
+        // capability.
+        isValidAnalysis = false;
+        return;
+      }
+      // The usage segment is inclusive, so merging resumes behind its end.
+      // NOTE(review): nothing here checks that the usage segment ends inside
+      // `seg`; if it can reach past `seg.end()` the loop condition below is
+      // never met - confirm this cannot happen.
+      nextIterIfMerge = iter->second.end();
+      nextIterIfMerge++;
+    }
+    // Ops without palette information simply extend the current scope.
+    if (!pi) {
+      currIter++;
+      continue;
+    }
+
+    if (pi->overflow) {
+      // Only scf Ops could go through this possibility. Such an Op cannot be
+      // covered by a single scope, so scopes are formed inside it instead.
+      tryAddPreviousScope();
+      doTileScope(currOp);
+      // The next scope must start behind this Op, not at it; otherwise it
+      // would enclose the scopes that were just created inside the Op.
+      currSegStart = std::nullopt;
+      currIter++;
+    } else {
+      if (currScope.pi.isConflict(*pi)) {
+        tryAddPreviousScope();
+        currScope.pi = *pi;
+        currSegStart = currIter;
+        currIter++;
+      } else {
+        currScope.pi.merge(*pi);
+        currIter = nextIterIfMerge;
+      }
+    }
+  }
+
+  tryAddPreviousScope();
+}
+
+// Collects tile scopes inside a single control flow Op by scoping the first
+// block of each of its regions. Tile Ops and Ops that are not concerned
+// control flow Ops are ignored.
+void TileScopeAnalysis::doTileScope(Operation *op) {
+  // This func is not for tile Ops.
+  if (isTileOp(op))
+    return;
+  // Ops that invoke this func are either parallelOps or scfOps with overflowed
+  // paletteInfo, and neither of them can form a tile scope by itself, so we
+  // omit checking self-formed tile scope in this func.
+  if (!isConcernedControlFlowOp(op))
+    return;
+
+  // Visiting the regions generically covers the then/else regions of scf.if,
+  // the default region and every case region of scf.index_switch, the before
+  // and after regions of scf.while, and the single body of the other Ops.
+  for (Region &region : op->getRegions()) {
+    // A region without blocks (such as an absent else branch) has no Ops.
+    if (region.empty())
+      continue;
+    doTileScope(region.front());
+  }
+}
diff --git a/mlir/lib/Dialect/AMX/Analysis/CMakeLists.txt b/mlir/lib/Dialect/AMX/Analysis/CMakeLists.txt
new file mode 100644
index 00000000000000..008fda4357a6b3
--- /dev/null
+++ b/mlir/lib/Dialect/AMX/Analysis/CMakeLists.txt
@@ -0,0 +1,10 @@
+# AMXBindingAnalysis.cpp uses func and scf Ops, so the library links against
+# those dialects in addition to the AMX dialect.
+add_mlir_dialect_library(MLIRAMXAnalysis
+  AMXBindingAnalysis.cpp
+
+  DEPENDS
+  MLIRAMXConversionsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRAMXDialect
+  MLIRFuncDialect
+  MLIRIR
+  MLIRSCFDialect
+  )
diff --git a/mlir/lib/Dialect/AMX/CMakeLists.txt b/mlir/lib/Dialect/AMX/CMakeLists.txt
index 9f57627c321fb0..cd0e2051f32095 100644
--- a/mlir/lib/Dialect/AMX/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMX/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
+add_subdirectory(Analysis)
diff --git a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
index a0c1c10f3c1701..351f788959a6de 100644
--- a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRAMXTransforms
 
   DEPENDS
   MLIRAMXConversionsIncGen
+  MLIRAMXAnalysis
 
   LINK_LIBS PUBLIC
   MLIRAMXDialect
diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
index 54a609fa783e04..3a181ed79d94ce 100644
--- a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -20,6 +20,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h"
 #include "mlir/Dialect/AMX/Passes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
@@ -36,566 +37,6 @@ namespace amx {
 #define GEN_PASS_DEF_ENABLEAMXTILEBINDING
 #include "mlir/Dialect/AMX/Passes.h.inc"
 
-//===----------------------------------------------------------------------===//
-// Analysis
-//===----------------------------------------------------------------------===//
-
-/// A class for analyzing (propagating) tile register binding for each tile
-/// vector.
-class TileBindingAnalysis {
-private:
-  bool isValidAnalysis;
-  DenseMap<Value, int> bindings;
-
-public:
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileBindingAnalysis)
-  explicit TileBindingAnalysis(Operation *);
-  bool isValid() const { return isValidAnalysis; }
-  int getBinding(Value val) const {
-    auto iter = bindings.find(val);
-    if (iter == bindings.end())
-      return -1;
-    return iter->second;
-  }
-  void setBinding(Value val, int index) { bindings[val] = index; }
-};
-
-static bool isTileOp(Operation *op) {
-  return llvm::isa<TileZeroOp>(op) || llvm::isa<TileLoadOp>(op) ||
-         llvm::isa<TileMulFOp>(op) || llvm::isa<TileMulIOp>(op) ||
-         llvm::isa<TileStoreOp>(op);
-}
-
-template <typename Op>
-static bool TileMulCheck(Operation *op) {
-  auto tile_mul = dyn_cast_or_null<Op>(op);
-  assert(tile_mul);
-
-  auto lhsOp = tile_mul.getLhs().getDefiningOp();
-  auto rhsOp = tile_mul.getRhs().getDefiningOp();
-  auto accOp = tile_mul.getAcc().getDefiningOp();
-  if (!isTileOp(lhsOp) || !isTileOp(rhsOp) || !isTileOp(accOp))
-    return false;
-  return true;
-}
-
-// Not allow mixed use of tile Ops and normal vector Ops, any mixing is
-// considered unacceptable.
-static bool isAcceptableTileOp(Operation *op) {
-  if (!isTileOp(op))
-    return false;
-
-  if (llvm::isa<TileMulFOp>(op)) {
-    return TileMulCheck<TileMulFOp>(op);
-  } else if (llvm::isa<TileMulIOp>(op)) {
-    return TileMulCheck<TileMulIOp>(op);
-  } else if (auto tileStore = dyn_cast_or_null<TileStoreOp>(op)) {
-    auto valOp = tileStore.getVal().getDefiningOp();
-    if (!isTileOp(valOp))
-      return false;
-  }
-  return true;
-}
-
-template <typename Op>
-static bool TileDstPropagate(TileBindingAnalysis *analysis, Operation *op) {
-  auto tileDst = dyn_cast_or_null<Op>(op);
-  assert(tileDst);
-  std::optional<int8_t> tmmIndex = tileDst.getDstRegIndex();
-  if (!tmmIndex) {
-    return false;
-  }
-  analysis->setBinding(tileDst.getRes(), *tmmIndex);
-  return true;
-}
-
-template <typename Op>
-static bool TileMulPropagate(TileBindingAnalysis *analysis, Operation *op) {
-  auto tileMul = dyn_cast_or_null<Op>(op);
-  assert(tileMul);
-  auto accVal = tileMul.getAcc();
-  auto accIndex = analysis->getBinding(accVal);
-  if (accIndex < 0)
-    return false;
-
-  analysis->setBinding(tileMul.getRes(), accIndex);
-  return true;
-}
-
-TileBindingAnalysis::TileBindingAnalysis(Operation *root) {
-  isValidAnalysis = false;
-  func::FuncOp func = dyn_cast_or_null<func::FuncOp>(root);
-  if (!func)
-    return;
-
-  isValidAnalysis = true;
-  func->walk<WalkOrder::PreOrder>([this](Operation *op) {
-    if (!isValidAnalysis)
-      return;
-    if (!isTileOp(op))
-      return;
-    if (!isAcceptableTileOp(op)) {
-      isValidAnalysis = false;
-      return;
-    }
-
-    if (llvm::isa<TileZeroOp>(op)) {
-      if (!TileDstPropagate<TileZeroOp>(this, op)) {
-        isValidAnalysis = false;
-        return;
-      }
-    } else if (llvm::isa<TileLoadOp>(op)) {
-      if (!TileDstPropagate<TileLoadOp>(this, op)) {
-        isValidAnalysis = false;
-        return;
-      }
-    } else if (llvm::isa<TileMulFOp>(op)) {
-      if (!TileMulPropagate<TileMulFOp>(this, op)) {
-        isValidAnalysis = false;
-        return;
-      }
-    } else if (llvm::isa<TileMulIOp>(op)) {
-      if (!TileMulPropagate<TileMulIOp>(this, op)) {
-        isValidAnalysis = false;
-        return;
-      }
-    }
-  });
-}
-
-// A class for analyzing tile configuration domination (a.k.a. tile scope).
-class TileScopeAnalysis {
-private:
-  typedef llvm::iterator_range<Block::iterator> BlockSeg;
-  // A list of 2-dim {rows x colBytes} shapes representing tmm register shape,
-  // the length should always be 8.
-  struct PaletteInfo {
-    bool overflow;
-    SmallVector<pair<int, int>, 8> palette;
-    PaletteInfo() {
-      palette.resize(8, {0, 0});
-      clear();
-    }
-    void clear();
-    bool isEmpty(int idx) {
-      return palette[idx].first == 0 && palette[idx].second == 0;
-    }
-    void set(int idx, pair<int, int> shape) { palette[idx] = shape; }
-    void merge(const PaletteInfo &rhs);
-    bool isConflict(const PaletteInfo &rhs);
-  };
-  struct TileScope {
-    // The BlockSeg here is inclusive (containing `end` Op).
-    BlockSeg seg;
-    PaletteInfo pi;
-    TileScope() { clear(); }
-    void clear() { pi.clear(); }
-  };
-
-  bool isValidAnalysis;
-  // Storing parallel Ops that would break tile context & scope.
-  DenseSet<Operation *> parallelOps;
-  // Storing needed palette info for each concerned Op.
-  DenseMap<Operation *, PaletteInfo> neededPalette;
-  // Storing the usage scope for each concerned tile Op.
-  // The BlockSeg here is inclusive (containing `end` Op).
-  DenseMap<Operation *, BlockSeg> tileUsage;
-  // Storing final tile scope results for injecting tilecfg/tilerelease.
-  SmallVector<TileScope, 10> tileScopes;
-
-  void addParallelOp(Operation *op) { parallelOps.insert(op); }
-  bool isParallelOp(Operation *op) {
-    return parallelOps.find(op) == parallelOps.end();
-  }
-
-  void setTileUsage(Operation *op, BlockSeg seg) { tileUsage[op] = seg; }
-
-  PaletteInfo collectRegionPalette(Region &region);
-  PaletteInfo collectPalette(Operation *op);
-  // Below two functions are the leaf functinos of recursive collection, will
-  // actually insert PaletteInfo into map storage.
-  PaletteInfo collectPaletteForScf(Operation *op);
-  PaletteInfo collectPaletteForTile(Operation *op);
-  std::optional<PaletteInfo> getPalette(Operation *op);
-  std::optional<PaletteInfo> getPalette(BlockSeg seg);
-
-  void doTileScope(Block &block);
-  // The input BlockSeg here is exclusive (not containing `end` Op).
-  void doTileScope(BlockSeg seg);
-  void doTileScope(Operation *op);
-
-public:
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileScopeAnalysis)
-  explicit TileScopeAnalysis(Operation *);
-  bool isValid() const { return isValidAnalysis; }
-};
-
-void TileScopeAnalysis::PaletteInfo::clear() {
-  overflow = false;
-  for (int idx = 0; idx < 8; idx++)
-    palette[idx] = {0, 0};
-}
-
-void TileScopeAnalysis::PaletteInfo::merge(const PaletteInfo &rhs) {
-  if (overflow || rhs.overflow) {
-    overflow = true;
-    return;
-  }
-  for (int idx = 0; idx < 8; idx++) {
-    if (!isEmpty(idx) && !rhs.isEmpty(idx)) {
-      if (palette[idx].first != rhs.palette[idx].first ||
-          palette[idx].second != rhs.palette[idx].second) {
-        overflow = true;
-        break;
-      }
-    } else if (!rhs.isEmpty(idx)) {
-      palette[idx] = rhs.palette[idx];
-    }
-  }
-}
-
-bool TileScopeAnalysis::PaletteInfo::isConflict(const PaletteInfo &rhs) {
-  if (overflow || rhs.overflow) {
-    return true;
-  }
-  for (int idx = 0; idx < 8; idx++) {
-    if (!isEmpty(idx) && !rhs.isEmpty(idx)) {
-      if (palette[idx].first != rhs.palette[idx].first ||
-          palette[idx].second != rhs.palette[idx].second) {
-        return true;
-      }
-    }
-  }
-}
-
-// Currently we only operate on scf Ops.
-static bool isConcernedControlFlowOp(Operation *op) {
-  return llvm::isa<scf::ExecuteRegionOp>(op) || llvm::isa<scf::ForOp>(op) ||
-         llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::IfOp>(op) ||
-         llvm::isa<scf::IndexSwitchOp>(op) || llvm::isa<scf::ParallelOp>(op) ||
-         llvm::isa<scf::WhileOp>(op);
-}
-
-static bool isTileOp(Operation *op) {
-  return llvm::isa<TileZeroOp>(op) || llvm::isa<TileLoadOp>(op) ||
-         llvm::isa<TileMulFOp>(op) || llvm::isa<TileMulIOp>(op) ||
-         llvm::isa<TileStoreOp>(op);
-}
-
-TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
-  isValidAnalysis = false;
-  func::FuncOp func = dyn_cast_or_null<func::FuncOp>(root);
-  if (!func)
-    return;
-
-  isValidAnalysis = true;
-  // 0. First walk to mark parallel Ops.
-  func->walk<WalkOrder::PostOrder>([this](Operation *op) {
-    if (!isParallelOp(op))
-      return;
-
-    if (llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op)) {
-      while (op != root) {
-        addParallelOp(op);
-        op = op->getParentOp();
-      }
-    }
-  });
-
-  // 1. Second walk to collect needed palette for each concerned Op.
-  collectNeededPalette(root);
-
-  // 2. Third walk to analyse usage scope for each tile Op.
-  func->walk<WalkOrder::PreOrder>([this](Operation *op) {
-    if (!isValidAnalysis)
-      return;
-    if (!isTileOp(op))
-      return;
-    Operation *lastUser = nullptr;
-    for (auto user : op->getUsers())
-      lastUser = user;
-    while (lastUser && op->getBlock() != lastUser->getBlock()) {
-      lastUser = lastUser->getParentOp();
-      if (!lastUser)
-        isValidAnalysis = false;
-    }
-    setTileUsage(op, BlockSeg(Block::iterator(op), Block::iterator(lastUser)));
-  });
-  if (!isValidAnalysis)
-    return;
-
-  // 3. Tile scoping for each segmented region in a recursive manner.
-  doTileScope(func.getRegion(0).front());
-}
-
-PaletteInfo TileScopeAnalysis::collectRegionPalette(Block &block) {
-  PaletteInfo pi;
-  for (auto op : block.getOps())
-    pi.merge(collectPalette(op));
-  return pi;
-}
-
-PaletteInfo TileScopeAnalysis::collectPalette(Operation *op) {
-  if (!isValidAnalysis)
-    return PaletteInfo();
-  if (auto func = dyn_cast_or_null<func::FuncOp>(root))
-    // No need to store PaletteInfo for func.
-    return collectRegionPalette(func.getRegion(0).front());
-
-  auto iter = neededPalette.find(op);
-  if (iter != neededPalette.end())
-    return iter->second;
-
-  // For now, we only concern certain control flow Ops and tile Ops.
-  if (isConcernedControlFlowOp(op))
-    return collectPaletteForScf(op);
-  if (isTileOp(op))
-    return collectPaletteForTile(op);
-  return PaletteInfo();
-}
-
-PaletteInfo TileScopeAnalysis::collectPaletteForScf(Operation *op) {
-  if (!isConcerendScfOp(op))
-    return PaletteInfo();
-
-  PaletteInfo pi;
-  if (llvm::isa<scf::ExecuteRegionOp>(op) || llvm::isa<scf::ForOp>(op) ||
-      llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op)) {
-    pi = collectNeededPalette(op->getRegion(0).front());
-  } else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
-    auto thenPalette = collectRegionPalette(ifOp.getThenRegion().front());
-    auto elsePalette = collectRegionPalette(ifOp.getElseRegion().front());
-    pi.merge(thenPalette);
-    pi.merge(elsePalette);
-  } else if (auto indexOp = dyn_cast<scf::IndexSwitchOp>(op)) {
-    pi = collectRegionPalette(indexOp.getDefaultRegion().front());
-    for (auto &caseRegion : indexOp.getCaseRegions()) {
-      pi.merge(collectRegionPalette(caseRegion.front()));
-    }
-  } else if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
-    auto beforePalette = collectRegionPalette(whileOp.getRegion(0).front());
-    auto afterPalette = collectRegionPalette(whileOp.getRegion(1).front());
-    pi.merge(beforePalette);
-    pi.merge(afterPalette);
-  }
-  neededPalette[op] = pi;
-  return pi;
-}
-
-static inline pair<int, int> getPaletteShape(VectorType type) {
-  ArrayRef<int64_t> shape = type.getShape();
-  auto elementType = type.getElementType();
-  int typeSize;
-  if (elementType.isInteger(8))
-    typeSize = 1;
-  else if (elementType.isBF16())
-    typeSize = 2;
-  else if (elementType.isInteger(32) || elementType.isF32())
-    typeSize = 4;
-  else
-    assert(false && "Invalid type for palette");
-
-  // Palette shape is { rows, colBytes }.
-  return {shape[0], shape[1] * typeSize};
-}
-
-PaletteInfo TileScopeAnalysis::collectPaletteForTile(Operation *op) {
-  if (!isTileOp(op))
-    return PaletteInfo();
-
-#define PROCESS_UNARY_TILE_OP(op, method)                                      \
-  auto index = op.method();                                                    \
-  if (!index) {                                                                \
-    isValidAnalysis = false;                                                   \
-    return PaletteInfo();                                                      \
-  }                                                                            \
-  pi.set(*index, getPaletteShape(op.getVectorType()));
-
-#define PROCESS_TRINARY_TILE_OP(op)                                            \
-  auto lhsIndex = tileMulFOp.getLhsRegIndex();                                 \
-  auto rhsIndex = tileMulFOp.getRhsRegIndex();                                 \
-  auto accIndex = tileMulFOp.getAccRegIndex();                                 \
-  if (!lhsIndex || !rhsIndex || !accIndex) {                                   \
-    isValidAnalysis = false;                                                   \
-    return PaletteInfo();                                                      \
-  }                                                                            \
-  pi.set(*lhsIndex, getPaletteShape(op.getLhsVectorType()));                   \
-  pi.set(*rhsIndex, getPaletteShape(op.getRhsVectorType()));                   \
-  pi.set(*accIndex, getPaletteShape(op.getAccVectorType()));
-
-  PaletteInfo pi;
-  if (auto tileLoadOp = dyn_cast<TileLoadOp>(op)) {
-    PROCESS_UNARY_TILE_OP(tileLoadOp, getDstRegIndex);
-  } else if (auto tileMulFOp = dyn_cast<TileMulFOp>(op)) {
-    PROCESS_TRINARY_TILE_OP(tileMulFOp);
-  } else if (auto tileMulIOp = dyn_cast<TileMulIOp>(op)) {
-    PROCESS_TRINARY_TILE_OP(tileMulIOp);
-  } else if (auto tileStoreOp = dyn_cast<TileStoreOp>(op)) {
-    PROCESS_UNARY_TILE_OP(tileStoreOp, getSrcRegIndex);
-  } else if (auto tileZeroOp = dyn_cast<TileZeroOp>(op)) {
-    PROCESS_UNARY_TILE_OP(tileZeroOp, getDstRegIndex);
-  }
-  neededPalette[op] = pi;
-  return pi;
-}
-
-std::optional<PaletteInfo> TileScopeAnalysis::getPalette(Operation *op) {
-  auto iter = neededPalette.find(op);
-  if (iter == neededPalette.end()) {
-    return std::null_opt;
-  }
-  return iter->second;
-}
-
-std::optional<PaletteInfo> TileScopeAnalysis::getPalette(BlockSeg seg) {
-  bool hasPaletteInfo = false;
-  PaletteInfo pi;
-  for (Operation &opIns : seg) {
-    auto *op = &opIns;
-    auto tmpPi = getPalette(&opIns);
-    if (tmpPi) {
-      hasPaletteInfo = true;
-      pi.merge(*tmpPi);
-    }
-  }
-  return hasPaletteInfo ? pi : std::null_opt;
-}
-
-void TileScopeAnalysis::doTileScope(Block &block) {
-  doTileScope(BlockSeg(block.begin(), block.end()));
-}
-
-void TileScopeAnalysis::doTileScope(BlockSeg seg) {
-  if (!isValidAnalysis)
-    return;
-  if (seg.empty())
-    return;
-  SmallVector<BlockSeg, 3> blockSegs;
-  SmallVector<Operation *, 3> paraOps;
-  auto currBegin = seg.begin();
-  for (auto probe = seg.begin(); probe != seg.end(); probe++) {
-    Operation *op = &(*probe);
-    if (isParallelOp(op) {
-      blockSegs.push_back(BlockSeg(currBegin, probe));
-      paraOps.push_back(op);
-      currBegin = probe;
-      currBegin++;
-    }
-  }
-  if (breakers.size()) {
-    assert(blockSegs.size() == paraOps.size());
-    for (int idx = 0; idx < paraOps.size(); idx++) {
-      doTileScope(blockSegs[idx]);
-      doTileScope(paraOps[idx]);
-    }
-    doTileScope(BlockSeg(currBegin, seg.end()));
-    return;
-  }
-
-  // Do tile scope on parallel-free BlockSeg.
-  TileScope currScope;
-  std::optional<Block::iterator> currSegStart;
-  Block::iterator currIter = seg.begin();
-  // Traverse BlockSeg and greedily do tile scoping without look ahead.
-  while (currIter != seg.end()) {
-    Operation *currOp = &(*currIter);
-    if (!currSegStart)
-      currSegStart = currIter;
-
-    Block::iterator nextIterIfMerge;
-    std::optional<PaletteInfo> pi = std::null_opt;
-    if (isConcernedControlFlowOp(currOp)) {
-      pi = getPalette(currOp);
-      nextIterIfMerge = currIter;
-      nextIterIfMerge++;
-    } else if (isTileOp(currOp)) {
-      auto iter = tileUsage.find(currOp);
-      if (iter == tileUsage.end()) {
-        isValidAnalysis = false;
-        return;
-      }
-      pi = getPalette(iter->second);
-      if (pi && pi->overflow) {
-        // This means the binding info in tile Ops exceeds the hardware
-        // capability.
-        isValidAnalysis = false;
-        return;
-      }
-      nextIterIfMerge = iter->second.end();
-      nextIterIfMerge++;
-    }
-    if (!pi) {
-      currIter++;
-      continue;
-    }
-
-#define TRY_ADD_PREVIOUS_SCOPE()                                               \
-  if (currSegStart && *currSegStart != currIter) {                             \
-    auto prevIter = currIter;                                                  \
-    prevIter--;                                                                \
-    currScope.seg = BlockSeg(*currSegStart, prevIter);                         \
-    tileScopes.push_back(currScope);                                           \
-    currScope.clear();                                                         \
-    currSegStart = std::null_opt;                                              \
-  }
-
-    if (pi->overflow) {
-      // Only scf Ops could go through this possibility.
-      TRY_ADD_PREVIOUS_SCOPE();
-      doTileScope(currOp);
-      currIter++;
-    } else {
-      if (currScope.pi.isConflict(*pi)) {
-        TRY_ADD_PREVIOUS_SCOPE();
-        currScope.pi = *pi;
-        currSegStart = currIter;
-        currIter++;
-      } else {
-        currScope.pi.merge(*pi);
-        currIter = nextIterIfMerge;
-      }
-    }
-  }
-
-  TRY_ADD_PREVIOUS_SCOPE();
-}
-
-void TileScopeAnalysis::doTileScope(Operation *op) {
-  // This func try to collect tile scopes for a single control flow Op
-  // This func is not for tile Ops.
-  if (isTileOp(op))
-    return;
-  // Ops that invoke this func are either parallelOps or scfOps with overflowed
-  // paletteInfo, and neither of them can form a tile scope by itself, so we
-  // omit checking self-formed tile scope in this func.
-  if (llvm::isa<scf::ExecuteRegionOp>(op) || llvm::isa<scf::ForOp>(op) ||
-      llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op)) {
-    auto &block = op->getRegion(0).front();
-    doTileScope(BlockSeg(block.begin(), block.end()));
-  } else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
-    auto &ifBlock = op->getThenRegion().front();
-    auto &elseBlock = op->getElseRegion().front();
-    doTileScope(BlockSeg(ifBlock.begin(), ifBlock.end()));
-    doTileScope(BlockSeg(elseBlock.begin(), elseBlock.end()));
-  } else if (auto indexOp = dyn_cast<scf::IndexSwitchOp>(op)) {
-    auto &defaultBlock = indexOp.getDefaultRegion().front();
-    doTileScope(BlockSeg(defaultBlock.begin(), defaultBlock.end()));
-    for (auto &caseRegion : indexOp.getCaseRegions()) {
-      auto &caseBlock = indexOp.getDefaultRegion().front();
-      doTileScope(BlockSeg(caseBlock.begin(), caseBlock.end()));
-    }
-  } else if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
-    auto &beforeBlock = whileOp.getRegion(0).front();
-    auto &afterBlock = whileOp.getRegion(1).front();
-    doTileScope(BlockSeg(beforeBlock.begin(), beforeBlock.end()));
-    doTileScope(BlockSeg(afterBlock.begin(), afterBlock.end()));
-  }
-}
-
-//===----------------------------------------------------------------------===//
-// Pass
-//===----------------------------------------------------------------------===//
-
 class TileStoreBindingRewriter : public OpRewritePattern<TileStoreOp> {
 private:
   TileBindingAnalysis &analysis;
@@ -830,6 +271,7 @@ struct EnableAMXTileBindingPass
       builder.setInsertionPointAfter(end);
       builder.create<amx::x86_amx_tilerelease_plain>(loc);
     }
+    markAnalysesPreserved<TileScopeAnalysis>();
   }
 };
 
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 154987303be673..a85a480593aa38 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h"
 #include "mlir/Dialect/AMX/Transforms.h"
 
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"

>From c6ef334d8f84a017bea0f4ff70f22901ea3f263b Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Tue, 2 Jul 2024 01:36:21 -0700
Subject: [PATCH 09/17] fix tileload & tilestore lowering

---
 mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index a85a480593aa38..27754901f55d5b 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -137,7 +137,7 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
       auto dstRegIndex = op.getDstRegIndex();
       assert(dstRegIndex && "Incomplete operation attribute for tile binding");
       rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64_plain>(
-          op, ptr, stride, *dstRegIndex);
+          op, *dstRegIndex, ptr, stride);
       return success();
     }
 
@@ -181,7 +181,7 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
       auto srcRegIndex = op.getSrcRegIndex();
       assert(srcRegIndex && "Incomplete operation attribute for tile binding");
       rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64_plain>(
-          op, ptr, stride, *srcRegIndex);
+          op, *srcRegIndex, ptr, stride);
       return success();
     }
 

>From 0a0f62e47b09834d439596067baf6482afe7ab53 Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Wed, 3 Jul 2024 02:16:54 -0700
Subject: [PATCH 10/17] [TBD] fix some compile issues

---
 .../Dialect/AMX/Analysis/AMXBindingAnalysis.h |  21 +--
 mlir/include/mlir/Dialect/AMX/Transforms.h    |  10 +-
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |   3 +-
 .../AMX/Analysis/AMXBindingAnalysis.cpp       | 133 +++++++++++-------
 .../AMX/Transforms/EnableAMXTileBinding.cpp   |  50 ++-----
 .../AMX/Transforms/LegalizeForLLVMExport.cpp  |   5 +-
 6 files changed, 122 insertions(+), 100 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h b/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
index 05ff7f3d4a74cb..38e868e910b114 100644
--- a/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
+++ b/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
@@ -35,6 +35,9 @@ class TileBindingAnalysis {
   bool isValidAnalysis;
   DenseMap<Value, int> bindings;
 
+  // Ensure that tile operations are not wrapped by out-of-scope operations.
+  bool isViableTileOps(Operation *root);
+
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileBindingAnalysis)
   explicit TileBindingAnalysis(Operation *);
@@ -56,24 +59,24 @@ class TileScopeAnalysis {
   // the length should always be 8.
   struct PaletteInfo {
     bool overflow;
-    SmallVector<pair<int, int>, 8> palette;
+    SmallVector<std::pair<int, int>, 8> palette;
     PaletteInfo() {
       palette.resize(8, {0, 0});
       clear();
     }
     void clear();
-    bool isEmpty(int idx) {
+    bool isEmpty(int idx) const {
       return palette[idx].first == 0 && palette[idx].second == 0;
     }
-    void set(int idx, pair<int, int> shape) { palette[idx] = shape; }
+    void set(int idx, std::pair<int, int> shape) { palette[idx] = shape; }
     void merge(const PaletteInfo &rhs);
-    bool isConflict(const PaletteInfo &rhs);
+    bool isConflict(const PaletteInfo &rhs) const;
   };
   struct TileScope {
     // The BlockSeg here is inclusive (containing `end` Op).
     BlockSeg seg;
     PaletteInfo pi;
-    TileScope() { clear(); }
+    TileScope() : seg(Block::iterator(), Block::iterator()) { clear(); }
     void clear() { pi.clear(); }
   };
 
@@ -93,11 +96,13 @@ class TileScopeAnalysis {
     return parallelOps.find(op) == parallelOps.end();
   }
 
-  void setTileUsage(Operation *op, BlockSeg seg) { tileUsage[op] = seg; }
+  void setTileUsage(Operation *op, BlockSeg seg) {
+    tileUsage[op] = std::move(seg);
+  }
 
-  PaletteInfo collectRegionPalette(Region &region);
+  PaletteInfo collectBlockPalette(Block &block);
   PaletteInfo collectPalette(Operation *op);
-  // Below two functions are the leaf functinos of recursive collection, will
+  // Below two functions are the leaf functions of recursive collection, will
   // actually insert PaletteInfo into map storage.
   PaletteInfo collectPaletteForScf(Operation *op);
   PaletteInfo collectPaletteForTile(Operation *op);
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index 19608dff6f160f..881c4b460862c2 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -9,6 +9,10 @@
 #ifndef MLIR_DIALECT_AMX_TRANSFORMS_H
 #define MLIR_DIALECT_AMX_TRANSFORMS_H
 
+#include <optional>
+
+#include "mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h"
+
 namespace mlir {
 
 class LLVMConversionTarget;
@@ -17,8 +21,10 @@ class RewritePatternSet;
 
 /// Collect a set of patterns to lower AMX ops to ops that map to LLVM
 /// intrinsics.
-void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
-                                              RewritePatternSet &patterns);
+void populateAMXLegalizeForLLVMExportPatterns(
+    LLVMTypeConverter &converter,
+    std::optional<amx::TileScopeAnalysis> &analysis,
+    RewritePatternSet &patterns);
 
 /// Configure the target to support lowering AMX ops to ops that map to LLVM
 /// intrinsics.
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index c3fd9399a6c9e6..8b8b80a22ae041 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h"
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
@@ -105,7 +106,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
     populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
   }
   if (amx) {
-    auto &analysis = getCachedAnalysis<TileScopeAnalysis>();
+    auto &analysis = getCachedAnalysis<amx::TileScopeAnalysis>();
     configureAMXLegalizeForExportTarget(target);
     populateAMXLegalizeForLLVMExportPatterns(converter, analysis, patterns);
   }
diff --git a/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
index 8ed992ac498fd9..2c70704ba32597 100644
--- a/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
+++ b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
@@ -21,12 +21,23 @@
 
 #define DEBUG_TYPE "amx-binding-analysis"
 
+namespace mlir {
+namespace amx {
+
 static bool isTileOp(Operation *op) {
   return llvm::isa<TileZeroOp>(op) || llvm::isa<TileLoadOp>(op) ||
          llvm::isa<TileMulFOp>(op) || llvm::isa<TileMulIOp>(op) ||
          llvm::isa<TileStoreOp>(op);
 }
 
+// Currently we only operate on scf Ops.
+static bool isConcernedControlFlowOp(Operation *op) {
+  return llvm::isa<scf::ExecuteRegionOp>(op) || llvm::isa<scf::ForOp>(op) ||
+         llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::IfOp>(op) ||
+         llvm::isa<scf::IndexSwitchOp>(op) || llvm::isa<scf::ParallelOp>(op) ||
+         llvm::isa<scf::WhileOp>(op);
+}
+
 template <typename Op>
 static bool TileMulCheck(Operation *op) {
   auto tile_mul = dyn_cast_or_null<Op>(op);
@@ -83,11 +94,38 @@ static bool TileMulPropagate(TileBindingAnalysis *analysis, Operation *op) {
   return true;
 }
 
+bool TileBindingAnalysis::isViableTileOps(Operation *root) {
+  auto func = dyn_cast<func::FuncOp>(root);
+  if (!func)
+    return false;
+
+  bool isViable = true;
+  func->walk<WalkOrder::PreOrder>([&](Operation *op) {
+    if (!isViable)
+      return;
+    if (!isTileOp(op))
+      return;
+    auto probe = op->getParentOp();
+    while (probe != root) {
+      if (!isConcernedControlFlowOp(probe)) {
+        isViable = false;
+        break;
+      }
+      probe = probe->getParentOp();
+    }
+  });
+  return isViable;
+}
+
 TileBindingAnalysis::TileBindingAnalysis(Operation *root) {
   isValidAnalysis = false;
   func::FuncOp func = dyn_cast_or_null<func::FuncOp>(root);
   if (!func)
     return;
+  // Ensure that tile operations are not wrapped by out-of-scope Ops, else
+  // cannot do enabling.
+  if (!isViableTileOps(root))
+    return;
 
   isValidAnalysis = true;
   func->walk<WalkOrder::PreOrder>([this](Operation *op) {
@@ -148,7 +186,7 @@ void TileScopeAnalysis::PaletteInfo::merge(const PaletteInfo &rhs) {
   }
 }
 
-bool TileScopeAnalysis::PaletteInfo::isConflict(const PaletteInfo &rhs) {
+bool TileScopeAnalysis::PaletteInfo::isConflict(const PaletteInfo &rhs) const {
   if (overflow || rhs.overflow) {
     return true;
   }
@@ -160,20 +198,7 @@ bool TileScopeAnalysis::PaletteInfo::isConflict(const PaletteInfo &rhs) {
       }
     }
   }
-}
-
-// Currently we only operate on scf Ops.
-static bool isConcernedControlFlowOp(Operation *op) {
-  return llvm::isa<scf::ExecuteRegionOp>(op) || llvm::isa<scf::ForOp>(op) ||
-         llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::IfOp>(op) ||
-         llvm::isa<scf::IndexSwitchOp>(op) || llvm::isa<scf::ParallelOp>(op) ||
-         llvm::isa<scf::WhileOp>(op);
-}
-
-static bool isTileOp(Operation *op) {
-  return llvm::isa<TileZeroOp>(op) || llvm::isa<TileLoadOp>(op) ||
-         llvm::isa<TileMulFOp>(op) || llvm::isa<TileMulIOp>(op) ||
-         llvm::isa<TileStoreOp>(op);
+  return false;
 }
 
 TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
@@ -184,7 +209,7 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
 
   isValidAnalysis = true;
   // 0. First walk to mark parallel Ops.
-  func->walk<WalkOrder::PostOrder>([this](Operation *op) {
+  func->walk<WalkOrder::PostOrder>([=](Operation *op) {
     if (!isParallelOp(op))
       return;
 
@@ -197,10 +222,10 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
   });
 
   // 1. Second walk to collect needed palette for each concerned Op.
-  collectNeededPalette(root);
+  collectPalette(root);
 
   // 2. Third walk to analyse usage scope for each tile Op.
-  func->walk<WalkOrder::PreOrder>([this](Operation *op) {
+  func->walk<WalkOrder::PreOrder>([=](Operation *op) {
     if (!isValidAnalysis)
       return;
     if (!isTileOp(op))
@@ -219,22 +244,28 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
     return;
 
   // 3. Tile scoping for each segmented region in a recursive manner.
-  doTileScope(func.getRegion(0).front());
+  doTileScope(root->getRegion(0).front());
 }
 
-PaletteInfo TileScopeAnalysis::collectRegionPalette(Block &block) {
+TileScopeAnalysis::PaletteInfo
+TileScopeAnalysis::collectBlockPalette(Block &block) {
   PaletteInfo pi;
-  for (auto op : block.getOps())
-    pi.merge(collectPalette(op));
+  auto beginIter = block.begin();
+  auto endIter = block.end();
+  for (auto iter = beginIter; iter != endIter; iter++) {
+    auto &opRef = *iter;
+    pi.merge(collectPalette(&opRef));
+  }
   return pi;
 }
 
-PaletteInfo TileScopeAnalysis::collectPalette(Operation *op) {
+TileScopeAnalysis::PaletteInfo
+TileScopeAnalysis::collectPalette(Operation *op) {
   if (!isValidAnalysis)
     return PaletteInfo();
-  if (auto func = dyn_cast_or_null<func::FuncOp>(root))
+  if (auto func = dyn_cast_or_null<func::FuncOp>(op))
     // No need to store PaletteInfo for func.
-    return collectRegionPalette(func.getRegion(0).front());
+    return collectBlockPalette(op->getRegion(0).front());
 
   auto iter = neededPalette.find(op);
   if (iter != neededPalette.end())
@@ -248,27 +279,28 @@ PaletteInfo TileScopeAnalysis::collectPalette(Operation *op) {
   return PaletteInfo();
 }
 
-PaletteInfo TileScopeAnalysis::collectPaletteForScf(Operation *op) {
-  if (!isConcerendScfOp(op))
+TileScopeAnalysis::PaletteInfo
+TileScopeAnalysis::collectPaletteForScf(Operation *op) {
+  if (!isConcernedControlFlowOp(op))
     return PaletteInfo();
 
   PaletteInfo pi;
   if (llvm::isa<scf::ExecuteRegionOp>(op) || llvm::isa<scf::ForOp>(op) ||
       llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op)) {
-    pi = collectNeededPalette(op->getRegion(0).front());
+    pi = collectBlockPalette(op->getRegion(0).front());
   } else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
-    auto thenPalette = collectRegionPalette(ifOp.getThenRegion().front());
-    auto elsePalette = collectRegionPalette(ifOp.getElseRegion().front());
+    auto thenPalette = collectBlockPalette(ifOp.getThenRegion().front());
+    auto elsePalette = collectBlockPalette(ifOp.getElseRegion().front());
     pi.merge(thenPalette);
     pi.merge(elsePalette);
   } else if (auto indexOp = dyn_cast<scf::IndexSwitchOp>(op)) {
-    pi = collectRegionPalette(indexOp.getDefaultRegion().front());
+    pi = collectBlockPalette(indexOp.getDefaultRegion().front());
     for (auto &caseRegion : indexOp.getCaseRegions()) {
-      pi.merge(collectRegionPalette(caseRegion.front()));
+      pi.merge(collectBlockPalette(caseRegion.front()));
     }
   } else if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
-    auto beforePalette = collectRegionPalette(whileOp.getRegion(0).front());
-    auto afterPalette = collectRegionPalette(whileOp.getRegion(1).front());
+    auto beforePalette = collectBlockPalette(whileOp.getRegion(0).front());
+    auto afterPalette = collectBlockPalette(whileOp.getRegion(1).front());
     pi.merge(beforePalette);
     pi.merge(afterPalette);
   }
@@ -276,7 +308,7 @@ PaletteInfo TileScopeAnalysis::collectPaletteForScf(Operation *op) {
   return pi;
 }
 
-static inline pair<int, int> getPaletteShape(VectorType type) {
+static inline std::pair<int, int> getPaletteShape(VectorType type) {
   ArrayRef<int64_t> shape = type.getShape();
   auto elementType = type.getElementType();
   int typeSize;
@@ -293,7 +325,8 @@ static inline pair<int, int> getPaletteShape(VectorType type) {
   return {shape[0], shape[1] * typeSize};
 }
 
-PaletteInfo TileScopeAnalysis::collectPaletteForTile(Operation *op) {
+TileScopeAnalysis::PaletteInfo
+TileScopeAnalysis::collectPaletteForTile(Operation *op) {
   if (!isTileOp(op))
     return PaletteInfo();
 
@@ -315,7 +348,7 @@ PaletteInfo TileScopeAnalysis::collectPaletteForTile(Operation *op) {
   }                                                                            \
   pi.set(*lhsIndex, getPaletteShape(op.getLhsVectorType()));                   \
   pi.set(*rhsIndex, getPaletteShape(op.getRhsVectorType()));                   \
-  pi.set(*accIndex, getPaletteShape(op.getAccVectorType()));
+  pi.set(*accIndex, getPaletteShape(op.getVectorType()));
 
   PaletteInfo pi;
   if (auto tileLoadOp = dyn_cast<TileLoadOp>(op)) {
@@ -333,26 +366,27 @@ PaletteInfo TileScopeAnalysis::collectPaletteForTile(Operation *op) {
   return pi;
 }
 
-std::optional<PaletteInfo> TileScopeAnalysis::getPalette(Operation *op) {
+std::optional<TileScopeAnalysis::PaletteInfo>
+TileScopeAnalysis::getPalette(Operation *op) {
   auto iter = neededPalette.find(op);
   if (iter == neededPalette.end()) {
-    return std::null_opt;
+    return std::nullopt;
   }
   return iter->second;
 }
 
-std::optional<PaletteInfo> TileScopeAnalysis::getPalette(BlockSeg seg) {
+std::optional<TileScopeAnalysis::PaletteInfo>
+TileScopeAnalysis::getPalette(BlockSeg seg) {
   bool hasPaletteInfo = false;
   PaletteInfo pi;
   for (Operation &opIns : seg) {
-    auto *op = &opIns;
     auto tmpPi = getPalette(&opIns);
     if (tmpPi) {
       hasPaletteInfo = true;
       pi.merge(*tmpPi);
     }
   }
-  return hasPaletteInfo ? pi : std::null_opt;
+  return hasPaletteInfo ? std::optional<PaletteInfo>{pi} : std::nullopt;
 }
 
 void TileScopeAnalysis::doTileScope(Block &block) {
@@ -369,14 +403,14 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
   auto currBegin = seg.begin();
   for (auto probe = seg.begin(); probe != seg.end(); probe++) {
     Operation *op = &(*probe);
-    if (isParallelOp(op) {
+    if (isParallelOp(op)) {
       blockSegs.push_back(BlockSeg(currBegin, probe));
       paraOps.push_back(op);
       currBegin = probe;
       currBegin++;
     }
   }
-  if (breakers.size()) {
+  if (paraOps.size()) {
     assert(blockSegs.size() == paraOps.size());
     for (int idx = 0; idx < paraOps.size(); idx++) {
       doTileScope(blockSegs[idx]);
@@ -397,7 +431,7 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
       currSegStart = currIter;
 
     Block::iterator nextIterIfMerge;
-    std::optional<PaletteInfo> pi = std::null_opt;
+    std::optional<PaletteInfo> pi = std::nullopt;
     if (isConcernedControlFlowOp(currOp)) {
       pi = getPalette(currOp);
       nextIterIfMerge = currIter;
@@ -430,7 +464,7 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
     currScope.seg = BlockSeg(*currSegStart, prevIter);                         \
     tileScopes.push_back(currScope);                                           \
     currScope.clear();                                                         \
-    currSegStart = std::null_opt;                                              \
+    currSegStart = std::nullopt;                                               \
   }
 
     if (pi->overflow) {
@@ -467,8 +501,8 @@ void TileScopeAnalysis::doTileScope(Operation *op) {
     auto &block = op->getRegion(0).front();
     doTileScope(BlockSeg(block.begin(), block.end()));
   } else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
-    auto &ifBlock = op->getThenRegion().front();
-    auto &elseBlock = op->getElseRegion().front();
+    auto &ifBlock = ifOp.getThenRegion().front();
+    auto &elseBlock = ifOp.getElseRegion().front();
     doTileScope(BlockSeg(ifBlock.begin(), ifBlock.end()));
     doTileScope(BlockSeg(elseBlock.begin(), elseBlock.end()));
   } else if (auto indexOp = dyn_cast<scf::IndexSwitchOp>(op)) {
@@ -485,3 +519,6 @@ void TileScopeAnalysis::doTileScope(Operation *op) {
     doTileScope(BlockSeg(afterBlock.begin(), afterBlock.end()));
   }
 }
+
+} // namespace amx
+} // namespace mlir
diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
index 3a181ed79d94ce..484dadc92618cd 100644
--- a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/AMX/AMXDialect.h"
 #include "mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h"
 #include "mlir/Dialect/AMX/Passes.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Support/LogicalResult.h"
@@ -142,38 +143,15 @@ static inline void uint8ArrayToHex(std::string &out, uint8_t array[],
                                    int size) {
   llvm::raw_string_ostream os(out);
   for (int index = 0; index < size; index++) {
-    os << format_hex_no_prefix(array[index], 2, true);
+    os << llvm::format_hex_no_prefix(array[index], 2, true);
   }
 }
 
 struct EnableAMXTileBindingPass
     : public impl::EnableAMXTileBindingBase<EnableAMXTileBindingPass> {
 private:
-  bool isViableTileOps() {
-    Operation *root = getOperation();
-    auto func = dyn_cast<func::FuncOp>(root);
-    if (!func)
-      return false;
-
-    bool isViable = true;
-    func->walk<WalkOrder::PreOrder>([this](Operation *op) {
-      if (!isViable)
-        return;
-      if (!isTileOp(op))
-        return;
-      auto probe = op->getParentOp();
-      while (probe != root) {
-        if (!isConcernedControlFlowOp(probe)) {
-          isViable = false;
-          break;
-        }
-        probe = probe->getParentOp();
-      }
-    });
-    return isViable;
-  }
-
-  LLVM::GlobalOp getOrCreateGlobalPalette(const PaletteInfo &pi) {
+  LLVM::GlobalOp
+  getOrCreateGlobalPalette(const TileScopeAnalysis::PaletteInfo &pi) {
     assert(!pi.overflow && "Expecting valid palette");
 // Pack struct so it can fit into a single 64-byte cache line.
 #pragma pack(push, 1)
@@ -187,7 +165,7 @@ struct EnableAMXTileBindingPass
 #pragma pack(pop)
 
     size_t paletteArraySize = 64;
-    uint8_t *paletteAsArray = &paletteConfig;
+    auto paletteAsArray = reinterpret_cast<uint8_t *>(&paletteConfig);
     memset(paletteAsArray, 0x0, paletteArraySize);
     // Intel AMX: The only legal non-INIT value for palette_id is 1.
     // TODO(haixin): fetch from CPUID ?
@@ -199,14 +177,15 @@ struct EnableAMXTileBindingPass
     }
 
     std::string paletteSymName = "g_intel_amx_palette_";
-    uintArrayToHex(paletteSymName, paletteAsArray, paletteArraySize);
+    uint8ArrayToHex(paletteSymName, paletteAsArray, paletteArraySize);
 
+    ModuleOp module = getOperation()->template getParentOfType<ModuleOp>();
     if ((global = module.lookupSymbol<LLVM::GlobalOp>(paletteSymName)))
       return global;
     // Create a global symbol containing palette config.
-    ModuleOp moduleOp = getOperation()->template getParentOfType<ModuleOp>();
-    OpBuilder builder(moduleOp);
-    builder.setInsertionPointToStart(moduleOp.getBody());
+    auto ctx = module->getContext();
+    OpBuilder builder(module);
+    builder.setInsertionPointToStart(module.getBody());
 
     SmallVector<uint8_t> elementVals;
     for (size_t index = 0; index < paletteArraySize; index++)
@@ -218,18 +197,13 @@ struct EnableAMXTileBindingPass
     auto arrayTy =
         LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
     auto global = builder.create<LLVM::GlobalOp>(
-        moduleOp.getLoc(), arrayType, /*isConstant*/ true,
-        LLVM::Linkage::Private, paletteSymName, dataAttr, /*alignment=*/64);
+        module.getLoc(), arrayType, /*isConstant*/ true, LLVM::Linkage::Private,
+        paletteSymName, dataAttr, /*alignment=*/64);
     return global;
   }
 
 public:
   void runOnOperation() override {
-    // Ensure that tile Ops are not wrapped by out-of-scope Ops, else cannot do
-    // enabling.
-    if (!isViableTileOps())
-      return;
-
     // 0. Get AnalyseInfo for each concerned Value (Does not allow mixed used of
     // tmul & normal vector operations).
     TileBindingAnalysis &bindingAna = getAnalysis<TileBindingAnalysis>();
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 27754901f55d5b..9fd116b38ff740 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -6,7 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h"
 #include "mlir/Dialect/AMX/Transforms.h"
 
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
@@ -212,7 +211,7 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
                   ConversionPatternRewriter &rewriter) const override {
     if (enablingAnalysis && enablingAnalysis->isValid()) {
       // Routine for lowering tile Ops with binding info.
-      auto lhsRegIndex = op.getSrcRegIndex();
+      auto lhsRegIndex = op.getLhsRegIndex();
       auto rhsRegIndex = op.getRhsRegIndex();
       auto accRegIndex = op.getAccRegIndex();
 
@@ -259,7 +258,7 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
 
     if (enablingAnalysis && enablingAnalysis->isValid()) {
       // Routine for lowering tile Ops with binding info.
-      auto lhsRegIndex = op.getSrcRegIndex();
+      auto lhsRegIndex = op.getLhsRegIndex();
       auto rhsRegIndex = op.getRhsRegIndex();
       auto accRegIndex = op.getAccRegIndex();
 

>From aed71c4c0459d4a3e7d97d00449529b507d86a85 Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Fri, 5 Jul 2024 02:46:38 -0700
Subject: [PATCH 11/17] fix compile issue

---
 mlir/include/mlir/Dialect/AMX/AMX.td          |  6 +--
 .../Dialect/AMX/Analysis/AMXBindingAnalysis.h | 11 ++--
 mlir/include/mlir/Dialect/AMX/Transforms.h    |  2 +-
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |  2 +-
 .../AMX/Analysis/AMXBindingAnalysis.cpp       |  5 +-
 .../AMX/Transforms/EnableAMXTileBinding.cpp   | 17 +++---
 .../AMX/Transforms/LegalizeForLLVMExport.cpp  | 53 +++++++++++--------
 7 files changed, 58 insertions(+), 38 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index 0c77b8b6ad90fc..0ca220d83ed4fd 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -403,21 +403,21 @@ def LLVM_x86_amx_tdpbssd_plain : AMX_IntrOpBase<"tdpbssd", 0,
 		 Arg<I8Attr, "Index of rhs tmm registers">:$rhs_index)>;
 
 // Dot product of i8 tiles into i32 tile (with sign/zero extension).
-def LLVM_x86_amx_tdpbsud_plain : AMX_IntrOpBase<"tdpbsud", 1,
+def LLVM_x86_amx_tdpbsud_plain : AMX_IntrOpBase<"tdpbsud", 0,
 		[0, 1, 2], ["dst_index", "lhs_index", "rhs_index"]>,
   Arguments<(ins Arg<I8Attr, "Index of dst tmm registers">:$dst_index,
                  Arg<I8Attr, "Index of lhs tmm registers">:$lhs_index,
 		 Arg<I8Attr, "Index of rhs tmm registers">:$rhs_index)>;
 
 // Dot product of i8 tiles into i32 tile (with zero/sign extension).
-def LLVM_x86_amx_tdpbusd_plain : AMX_IntrOpBase<"tdpbusd", 1,
+def LLVM_x86_amx_tdpbusd_plain : AMX_IntrOpBase<"tdpbusd", 0,
 		[0, 1, 2], ["dst_index", "lhs_index", "rhs_index"]>,
   Arguments<(ins Arg<I8Attr, "Index of dst tmm registers">:$dst_index,
                  Arg<I8Attr, "Index of lhs tmm registers">:$lhs_index,
 		 Arg<I8Attr, "Index of rhs tmm registers">:$rhs_index)>;
 
 // Dot product of i8 tiles into i32 tile (with zero/zero extension).
-def LLVM_x86_amx_tdpbuud_plain : AMX_IntrOpBase<"tdpbuud", 1,
+def LLVM_x86_amx_tdpbuud_plain : AMX_IntrOpBase<"tdpbuud", 0,
 		[0, 1, 2], ["dst_index", "lhs_index", "rhs_index"]>,
   Arguments<(ins Arg<I8Attr, "Index of dst tmm registers">:$dst_index,
                  Arg<I8Attr, "Index of lhs tmm registers">:$lhs_index,
diff --git a/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h b/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
index 38e868e910b114..e41be602d6cb19 100644
--- a/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
+++ b/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
@@ -53,8 +53,7 @@ class TileBindingAnalysis {
 
 // A class for analyzing tile configuration domination (a.k.a. tile scope).
 class TileScopeAnalysis {
-private:
-  typedef llvm::iterator_range<Block::iterator> BlockSeg;
+public:
   // A list of 2-dim {rows x colBytes} shapes representing tmm register shape,
   // the length should always be 8.
   struct PaletteInfo {
@@ -72,6 +71,9 @@ class TileScopeAnalysis {
     void merge(const PaletteInfo &rhs);
     bool isConflict(const PaletteInfo &rhs) const;
   };
+
+private:
+  typedef llvm::iterator_range<Block::iterator> BlockSeg;
   struct TileScope {
     // The BlockSeg here is inclusive (containing `end` Op).
     BlockSeg seg;
@@ -97,7 +99,7 @@ class TileScopeAnalysis {
   }
 
   void setTileUsage(Operation *op, BlockSeg seg) {
-    tileUsage[op] = std::move(seg);
+    tileUsage.insert({op, std::move(seg)});
   }
 
   PaletteInfo collectBlockPalette(Block &block);
@@ -117,6 +119,9 @@ class TileScopeAnalysis {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TileScopeAnalysis)
   explicit TileScopeAnalysis(Operation *);
+  const SmallVector<TileScope, 10> &getTileScopes() const {
+    return tileScopes;
+  };
   bool isValid() const { return isValidAnalysis; }
 };
 
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index 881c4b460862c2..3379dcc7b491b9 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -23,7 +23,7 @@ class RewritePatternSet;
 /// intrinsics.
 void populateAMXLegalizeForLLVMExportPatterns(
     LLVMTypeConverter &converter,
-    std::optional<amx::TileScopeAnalysis> &analysis,
+    std::optional<std::reference_wrapper<amx::TileScopeAnalysis>> &analysis,
     RewritePatternSet &patterns);
 
 /// Configure the target to support lowering AMX ops to ops that map to LLVM
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 8b8b80a22ae041..0329a5fa4cb648 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -106,7 +106,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
     populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
   }
   if (amx) {
-    auto &analysis = getCachedAnalysis<amx::TileScopeAnalysis>();
+    auto analysis = getCachedAnalysis<amx::TileScopeAnalysis>();
     configureAMXLegalizeForExportTarget(target);
     populateAMXLegalizeForLLVMExportPatterns(converter, analysis, patterns);
   }
diff --git a/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
index 2c70704ba32597..519304df6e0295 100644
--- a/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
+++ b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "llvm/Support/Format.h"
 
 #define DEBUG_TYPE "amx-binding-analysis"
 
@@ -412,7 +413,7 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
   }
   if (paraOps.size()) {
     assert(blockSegs.size() == paraOps.size());
-    for (int idx = 0; idx < paraOps.size(); idx++) {
+    for (size_t idx = 0; idx < paraOps.size(); idx++) {
       doTileScope(blockSegs[idx]);
       doTileScope(paraOps[idx]);
     }
@@ -509,7 +510,7 @@ void TileScopeAnalysis::doTileScope(Operation *op) {
     auto &defaultBlock = indexOp.getDefaultRegion().front();
     doTileScope(BlockSeg(defaultBlock.begin(), defaultBlock.end()));
     for (auto &caseRegion : indexOp.getCaseRegions()) {
-      auto &caseBlock = indexOp.getDefaultRegion().front();
+      auto &caseBlock = caseRegion.front();
       doTileScope(BlockSeg(caseBlock.begin(), caseBlock.end()));
     }
   } else if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
index 484dadc92618cd..17736c01daa79c 100644
--- a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -27,6 +27,9 @@
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/FormattedStream.h"
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 
@@ -180,7 +183,7 @@ struct EnableAMXTileBindingPass
     uint8ArrayToHex(paletteSymName, paletteAsArray, paletteArraySize);
 
     ModuleOp module = getOperation()->template getParentOfType<ModuleOp>();
-    if ((global = module.lookupSymbol<LLVM::GlobalOp>(paletteSymName)))
+    if (auto global = module.lookupSymbol<LLVM::GlobalOp>(paletteSymName))
       return global;
     // Create a global symbol containing palette config.
     auto ctx = module->getContext();
@@ -194,7 +197,7 @@ struct EnableAMXTileBindingPass
         {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
     auto dataAttr =
         DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
-    auto arrayTy =
+    auto arrayType =
         LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
     auto global = builder.create<LLVM::GlobalOp>(
         module.getLoc(), arrayType, /*isConstant*/ true, LLVM::Linkage::Private,
@@ -212,9 +215,9 @@ struct EnableAMXTileBindingPass
 
     // 1. Set propagated binding info to AMX Ops.
     RewritePatternSet patterns(&getContext());
-    patterns.add<TileStoreBindingRewriter>(&getContext(), analysis);
-    patterns.add<TileMulFBindingRewriter>(&getContext(), analysis);
-    patterns.add<TileMulIBindingRewriter>(&getContext(), analysis);
+    patterns.add<TileStoreBindingRewriter>(&getContext(), bindingAna);
+    patterns.add<TileMulFBindingRewriter>(&getContext(), bindingAna);
+    patterns.add<TileMulIBindingRewriter>(&getContext(), bindingAna);
     FrozenRewritePatternSet patternSet(std::move(patterns));
 
     if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
@@ -227,13 +230,13 @@ struct EnableAMXTileBindingPass
 
     // 3. Insert tile config/release according to tile scopes.
     OpBuilder builder(getOperation());
-    for (auto &scope : tileScopes) {
+    for (auto &scope : scopeAna.getTileScopes()) {
       assert(!scope.pi.overflow && "Expecting legal AMX palette info");
       auto paletteGlobal = getOrCreateGlobalPalette(scope.pi);
       assert(paletteGlobal && "Failed to create global palette");
 
       Operation *begin = &(*scope.seg.begin());
-      Loc loc = begin->getLoc();
+      Location loc = begin->getLoc();
 
       builder.setInsertionPoint(begin);
       Value paletteGlobalPtr =
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 9fd116b38ff740..c7f67b49ef0a53 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -75,19 +75,21 @@ Value getStride(ConversionPatternRewriter &rewriter,
 
 struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
 private:
-  const std::optional<TileScopeAnalysis> &enablingAnalysis;
+  const std::optional<std::reference_wrapper<TileScopeAnalysis>>
+      &enablingAnalysis;
 
 public:
   using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern;
-  TileZeroConversion(const LLVMTypeConverter &typeConverter,
-                     const std::optional<TileScopeAnalysis> &analysis)
+  TileZeroConversion(
+      const LLVMTypeConverter &typeConverter,
+      const std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis)
       : ConvertOpToLLVMPattern<TileZeroOp>(typeConverter),
         enablingAnalysis(analysis) {}
 
   LogicalResult
   matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (enablingAnalysis && enablingAnalysis->isValid()) {
+    if (enablingAnalysis && enablingAnalysis->get().isValid()) {
       // Routine for lowering tile Ops with binding info.
       auto dstRegIndex = op.getDstRegIndex();
       assert(dstRegIndex && "Incomplete operation attribute for tile binding");
@@ -110,12 +112,14 @@ struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
 
 struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
 private:
-  const std::optional<TileScopeAnalysis> &enablingAnalysis;
+  const std::optional<std::reference_wrapper<TileScopeAnalysis>>
+      &enablingAnalysis;
 
 public:
   using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern;
-  TileLoadConversion(const LLVMTypeConverter &typeConverter,
-                     const std::optional<TileScopeAnalysis> &analysis)
+  TileLoadConversion(
+      const LLVMTypeConverter &typeConverter,
+      const std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis)
       : ConvertOpToLLVMPattern<TileLoadOp>(typeConverter),
         enablingAnalysis(analysis) {}
 
@@ -131,7 +135,7 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
     Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
 
-    if (enablingAnalysis && enablingAnalysis->isValid()) {
+    if (enablingAnalysis && enablingAnalysis->get().isValid()) {
       // Routine for lowering tile Ops with binding info.
       auto dstRegIndex = op.getDstRegIndex();
       assert(dstRegIndex && "Incomplete operation attribute for tile binding");
@@ -154,12 +158,14 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
 
 struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
 private:
-  const std::optional<TileScopeAnalysis> &enablingAnalysis;
+  const std::optional<std::reference_wrapper<TileScopeAnalysis>>
+      &enablingAnalysis;
 
 public:
   using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;
-  TileStoreConversion(const LLVMTypeConverter &typeConverter,
-                      const std::optional<TileScopeAnalysis> &analysis)
+  TileStoreConversion(
+      const LLVMTypeConverter &typeConverter,
+      const std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis)
       : ConvertOpToLLVMPattern<TileStoreOp>(typeConverter),
         enablingAnalysis(analysis) {}
 
@@ -175,7 +181,7 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
     Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
 
-    if (enablingAnalysis && enablingAnalysis->isValid()) {
+    if (enablingAnalysis && enablingAnalysis->get().isValid()) {
       // Routine for lowering tile Ops with binding info.
       auto srcRegIndex = op.getSrcRegIndex();
       assert(srcRegIndex && "Incomplete operation attribute for tile binding");
@@ -197,19 +203,21 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
 
 struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
 private:
-  const std::optional<TileScopeAnalysis> &enablingAnalysis;
+  const std::optional<std::reference_wrapper<TileScopeAnalysis>>
+      &enablingAnalysis;
 
 public:
   using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern;
-  TileMulFConversion(const LLVMTypeConverter &typeConverter,
-                     const std::optional<TileScopeAnalysis> &analysis)
+  TileMulFConversion(
+      const LLVMTypeConverter &typeConverter,
+      const std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis)
       : ConvertOpToLLVMPattern<TileMulFOp>(typeConverter),
         enablingAnalysis(analysis) {}
 
   LogicalResult
   matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (enablingAnalysis && enablingAnalysis->isValid()) {
+    if (enablingAnalysis && enablingAnalysis->get().isValid()) {
       // Routine for lowering tile Ops with binding info.
       auto lhsRegIndex = op.getLhsRegIndex();
       auto rhsRegIndex = op.getRhsRegIndex();
@@ -241,12 +249,14 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
 
 struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
 private:
-  const std::optional<TileScopeAnalysis> &enablingAnalysis;
+  const std::optional<std::reference_wrapper<TileScopeAnalysis>>
+      &enablingAnalysis;
 
 public:
   using ConvertOpToLLVMPattern<TileMulIOp>::ConvertOpToLLVMPattern;
-  TileMulIConversion(const LLVMTypeConverter &typeConverter,
-                     const std::optional<TileScopeAnalysis> &analysis)
+  TileMulIConversion(
+      const LLVMTypeConverter &typeConverter,
+      const std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis)
       : ConvertOpToLLVMPattern<TileMulIOp>(typeConverter),
         enablingAnalysis(analysis) {}
 
@@ -256,7 +266,7 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
     bool zexta = op.getIsZextLhs();
     bool zextb = op.getIsZextRhs();
 
-    if (enablingAnalysis && enablingAnalysis->isValid()) {
+    if (enablingAnalysis && enablingAnalysis->get().isValid()) {
       // Routine for lowering tile Ops with binding info.
       auto lhsRegIndex = op.getLhsRegIndex();
       auto rhsRegIndex = op.getRhsRegIndex();
@@ -312,7 +322,8 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
 } // namespace
 
 void mlir::populateAMXLegalizeForLLVMExportPatterns(
-    LLVMTypeConverter &converter, std::optional<TileScopeAnalysis> &analysis,
+    LLVMTypeConverter &converter,
+    std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis,
     RewritePatternSet &patterns) {
   patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
                TileMulFConversion, TileMulIConversion>(converter, analysis);

>From 626064068aaa81bc3142acbd0fe6c913737e7970 Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Thu, 1 Aug 2024 00:01:08 -0700
Subject: [PATCH 12/17] [WIP] preliminary bug fixes

---
 mlir/include/mlir/Dialect/AMX/AMX.td          |  2 +-
 .../Dialect/AMX/Analysis/AMXBindingAnalysis.h |  2 +-
 mlir/include/mlir/InitAllPasses.h             |  2 +
 .../AMX/Analysis/AMXBindingAnalysis.cpp       | 11 +++++
 .../lib/Dialect/AMX/Transforms/CMakeLists.txt |  1 +
 .../AMX/Transforms/EnableAMXTileBinding.cpp   | 12 +++++-
 .../AMX/Transforms/LegalizeForLLVMExport.cpp  | 41 +++++++++++--------
 7 files changed, 52 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index 0ca220d83ed4fd..c8300317c471aa 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -92,7 +92,7 @@ class AMX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
 
 
 def TileRegisterIndexAttr : OptionalAttr<
-	ConfinedAttr<SI8Attr, [IntMinValue<0>, IntMaxValue<7>]>>;
+	ConfinedAttr<I8Attr, [IntMinValue<0>, IntMaxValue<7>]>>;
 
 //
 // Tile reset.
diff --git a/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h b/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
index e41be602d6cb19..4ddeff63d980a4 100644
--- a/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
+++ b/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
@@ -95,7 +95,7 @@ class TileScopeAnalysis {
 
   void addParallelOp(Operation *op) { parallelOps.insert(op); }
   bool isParallelOp(Operation *op) {
-    return parallelOps.find(op) == parallelOps.end();
+    return parallelOps.find(op) != parallelOps.end();
   }
 
   void setTileUsage(Operation *op, BlockSeg seg) {
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index fedd7737f9ea45..827f220af21152 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -16,6 +16,7 @@
 
 #include "mlir/Conversion/Passes.h"
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+#include "mlir/Dialect/AMX/Passes.h"
 #include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
@@ -93,6 +94,7 @@ inline void registerAllPasses() {
   arm_sve::registerArmSVEPasses();
   emitc::registerEmitCPasses();
   xegpu::registerXeGPUPasses();
+  amx::registerAMXPasses();
 
   // Dialect pipelines
   bufferization::registerBufferizationPipelines();
diff --git a/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
index 519304df6e0295..26ba27ba389054 100644
--- a/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
+++ b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/Format.h"
 
 #define DEBUG_TYPE "amx-binding-analysis"
@@ -217,6 +218,7 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
     if (llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op)) {
       while (op != root) {
         addParallelOp(op);
+        LLVM_DEBUG(llvm::dbgs() << ">>>>> add parallel op: " << op << "\n");
         op = op->getParentOp();
       }
     }
@@ -305,6 +307,7 @@ TileScopeAnalysis::collectPaletteForScf(Operation *op) {
     pi.merge(beforePalette);
     pi.merge(afterPalette);
   }
+  LLVM_DEBUG(llvm::dbgs() << ">>>>> set needed palette for op: " << op << "\n");
   neededPalette[op] = pi;
   return pi;
 }
@@ -363,6 +366,7 @@ TileScopeAnalysis::collectPaletteForTile(Operation *op) {
   } else if (auto tileZeroOp = dyn_cast<TileZeroOp>(op)) {
     PROCESS_UNARY_TILE_OP(tileZeroOp, getDstRegIndex);
   }
+  LLVM_DEBUG(llvm::dbgs() << ">>>>> set needed palette for op: " << op << "\n");
   neededPalette[op] = pi;
   return pi;
 }
@@ -397,8 +401,10 @@ void TileScopeAnalysis::doTileScope(Block &block) {
 void TileScopeAnalysis::doTileScope(BlockSeg seg) {
   if (!isValidAnalysis)
     return;
+  LLVM_DEBUG(llvm::dbgs() << ">>>>> doTileScope A\n");
   if (seg.empty())
     return;
+  LLVM_DEBUG(llvm::dbgs() << ">>>>> doTileScope B\n");
   SmallVector<BlockSeg, 3> blockSegs;
   SmallVector<Operation *, 3> paraOps;
   auto currBegin = seg.begin();
@@ -412,6 +418,7 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
     }
   }
   if (paraOps.size()) {
+    LLVM_DEBUG(llvm::dbgs() << ">>>>> doTileScope BB\n");
     assert(blockSegs.size() == paraOps.size());
     for (size_t idx = 0; idx < paraOps.size(); idx++) {
       doTileScope(blockSegs[idx]);
@@ -420,6 +427,7 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
     doTileScope(BlockSeg(currBegin, seg.end()));
     return;
   }
+  LLVM_DEBUG(llvm::dbgs() << ">>>>> doTileScope C\n");
 
   // Do tile scope on parallel-free BlockSeg.
   TileScope currScope;
@@ -428,6 +436,8 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
   // Traverse BlockSeg and greedily do tile scoping without look ahead.
   while (currIter != seg.end()) {
     Operation *currOp = &(*currIter);
+    LLVM_DEBUG(llvm::dbgs()
+               << ">>>>> doTileScope Checking op: " << *currOp << "\n");
     if (!currSegStart)
       currSegStart = currIter;
 
@@ -486,6 +496,7 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
     }
   }
 
+  LLVM_DEBUG(llvm::dbgs() << ">>>>> out of doTileScope\n");
   TRY_ADD_PREVIOUS_SCOPE();
 }
 
diff --git a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
index 351f788959a6de..1b8feb7f4a544c 100644
--- a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRAMXTransforms
 
   DEPENDS
   MLIRAMXConversionsIncGen
+  MLIRAMXTransformsIncGen
   MLIRAMXAnalysis
 
   LINK_LIBS PUBLIC
diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
index 17736c01daa79c..2a37751d387e91 100644
--- a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -213,6 +213,7 @@ struct EnableAMXTileBindingPass
     if (!bindingAna.isValid())
       return;
 
+    LLVM_DEBUG(llvm::dbgs() << ">>> After binding analysis\n");
     // 1. Set propagated binding info to AMX Ops.
     RewritePatternSet patterns(&getContext());
     patterns.add<TileStoreBindingRewriter>(&getContext(), bindingAna);
@@ -220,17 +221,26 @@ struct EnableAMXTileBindingPass
     patterns.add<TileMulIBindingRewriter>(&getContext(), bindingAna);
     FrozenRewritePatternSet patternSet(std::move(patterns));
 
-    if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
+    GreedyRewriteConfig config;
+    config.strictMode = GreedyRewriteStrictness::ExistingOps;
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), patternSet, config)))
       return;
 
+    LLVM_DEBUG(llvm::dbgs() << ">>> After propagating binding info\n");
+
     // 2. Analyse tile scopes & expand them maximally.
     TileScopeAnalysis &scopeAna = getAnalysis<TileScopeAnalysis>();
     if (!scopeAna.isValid())
       return;
 
+    LLVM_DEBUG(llvm::dbgs() << ">>> After tile scope analysis\n");
+
     // 3. Insert tile config/release according to tile scopes.
     OpBuilder builder(getOperation());
     for (auto &scope : scopeAna.getTileScopes()) {
+      LLVM_DEBUG(llvm::dbgs() << ">>> Processing tile scope: "
+                              << *scope.seg.begin() << "\n");
       assert(!scope.pi.overflow && "Expecting legal AMX palette info");
       auto paletteGlobal = getOrCreateGlobalPalette(scope.pi);
       assert(paletteGlobal && "Failed to create global palette");
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index c7f67b49ef0a53..afff80640bddd4 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -90,11 +90,12 @@ struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
   matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (enablingAnalysis && enablingAnalysis->get().isValid()) {
+      rewriter.setInsertionPoint(op);
       // Routine for lowering tile Ops with binding info.
       auto dstRegIndex = op.getDstRegIndex();
       assert(dstRegIndex && "Incomplete operation attribute for tile binding");
-      rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero_plain>(op,
-                                                               *dstRegIndex);
+      rewriter.create<amx::x86_amx_tilezero_plain>(op.getLoc(), *dstRegIndex);
+      rewriter.eraseOp(op);
       return success();
     }
 
@@ -136,11 +137,13 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
                                      adaptor.getIndices(), rewriter);
 
     if (enablingAnalysis && enablingAnalysis->get().isValid()) {
+      rewriter.setInsertionPoint(op);
       // Routine for lowering tile Ops with binding info.
       auto dstRegIndex = op.getDstRegIndex();
       assert(dstRegIndex && "Incomplete operation attribute for tile binding");
-      rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64_plain>(
-          op, *dstRegIndex, ptr, stride);
+      rewriter.create<amx::x86_amx_tileloadd64_plain>(op.getLoc(), *dstRegIndex,
+                                                      ptr, stride);
+      rewriter.eraseOp(op);
       return success();
     }
 
@@ -182,11 +185,13 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
                                      adaptor.getIndices(), rewriter);
 
     if (enablingAnalysis && enablingAnalysis->get().isValid()) {
+      rewriter.setInsertionPoint(op);
       // Routine for lowering tile Ops with binding info.
       auto srcRegIndex = op.getSrcRegIndex();
       assert(srcRegIndex && "Incomplete operation attribute for tile binding");
-      rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64_plain>(
-          op, *srcRegIndex, ptr, stride);
+      rewriter.create<amx::x86_amx_tilestored64_plain>(
+          op.getLoc(), *srcRegIndex, ptr, stride);
+      rewriter.eraseOp(op);
       return success();
     }
 
@@ -218,6 +223,7 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
   matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (enablingAnalysis && enablingAnalysis->get().isValid()) {
+      rewriter.setInsertionPoint(op);
       // Routine for lowering tile Ops with binding info.
       auto lhsRegIndex = op.getLhsRegIndex();
       auto rhsRegIndex = op.getRhsRegIndex();
@@ -225,8 +231,9 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
 
       assert(lhsRegIndex && rhsRegIndex && accRegIndex &&
              "Incomplete operation attribute for tile binding");
-      rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps_plain>(
-          op, *accRegIndex, *lhsRegIndex, *rhsRegIndex);
+      rewriter.create<amx::x86_amx_tdpbf16ps_plain>(op.getLoc(), *accRegIndex,
+                                                    *lhsRegIndex, *rhsRegIndex);
+      rewriter.eraseOp(op);
       return success();
     }
 
@@ -267,6 +274,7 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
     bool zextb = op.getIsZextRhs();
 
     if (enablingAnalysis && enablingAnalysis->get().isValid()) {
+      rewriter.setInsertionPoint(op);
       // Routine for lowering tile Ops with binding info.
       auto lhsRegIndex = op.getLhsRegIndex();
       auto rhsRegIndex = op.getRhsRegIndex();
@@ -275,17 +283,18 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
       assert(lhsRegIndex && rhsRegIndex && accRegIndex &&
              "Incomplete operation attribute for tile binding");
       if (zexta && zextb)
-        rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud_plain>(
-            op, *accRegIndex, *lhsRegIndex, *rhsRegIndex);
+        rewriter.create<amx::x86_amx_tdpbuud_plain>(op.getLoc(), *accRegIndex,
+                                                    *lhsRegIndex, *rhsRegIndex);
       else if (zexta && !zextb)
-        rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbusd_plain>(
-            op, *accRegIndex, *lhsRegIndex, *rhsRegIndex);
+        rewriter.create<amx::x86_amx_tdpbusd_plain>(op.getLoc(), *accRegIndex,
+                                                    *lhsRegIndex, *rhsRegIndex);
       else if (!zexta && zextb)
-        rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbsud_plain>(
-            op, *accRegIndex, *lhsRegIndex, *rhsRegIndex);
+        rewriter.create<amx::x86_amx_tdpbsud_plain>(op.getLoc(), *accRegIndex,
+                                                    *lhsRegIndex, *rhsRegIndex);
       else
-        rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbssd_plain>(
-            op, *accRegIndex, *lhsRegIndex, *rhsRegIndex);
+        rewriter.create<amx::x86_amx_tdpbssd_plain>(op.getLoc(), *accRegIndex,
+                                                    *lhsRegIndex, *rhsRegIndex);
+      rewriter.eraseOp(op);
       return success();
     }
 

>From 4438ffcfc3f9584d80d23eae3baaa48b3da9f5d2 Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Mon, 5 Aug 2024 02:00:17 -0700
Subject: [PATCH 13/17] successfully run through a simple case

---
 .../Dialect/AMX/Analysis/AMXBindingAnalysis.h |  8 ++--
 mlir/include/mlir/Dialect/AMX/Passes.td       |  2 +-
 .../AMX/Analysis/AMXBindingAnalysis.cpp       | 45 ++++++++++++-------
 .../AMX/Transforms/EnableAMXTileBinding.cpp   |  8 ++--
 4 files changed, 37 insertions(+), 26 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h b/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
index 4ddeff63d980a4..12c9c8678148ab 100644
--- a/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
+++ b/mlir/include/mlir/Dialect/AMX/Analysis/AMXBindingAnalysis.h
@@ -84,7 +84,7 @@ class TileScopeAnalysis {
 
   bool isValidAnalysis;
   // Storing parallel Ops that would break tile context & scope.
-  DenseSet<Operation *> parallelOps;
+  DenseSet<Operation *> breakOps;
   // Storing needed palette info for each concerned Op.
   DenseMap<Operation *, PaletteInfo> neededPalette;
   // Storing the usage scope for each concerned tile Op.
@@ -93,10 +93,8 @@ class TileScopeAnalysis {
   // Storing final tile scope results for injecting tilecfg/tilerelease.
   SmallVector<TileScope, 10> tileScopes;
 
-  void addParallelOp(Operation *op) { parallelOps.insert(op); }
-  bool isParallelOp(Operation *op) {
-    return parallelOps.find(op) != parallelOps.end();
-  }
+  void addBreakOp(Operation *op) { breakOps.insert(op); }
+  bool isBreakOp(Operation *op) { return breakOps.find(op) != breakOps.end(); }
 
   void setTileUsage(Operation *op, BlockSeg seg) {
     tileUsage.insert({op, std::move(seg)});
diff --git a/mlir/include/mlir/Dialect/AMX/Passes.td b/mlir/include/mlir/Dialect/AMX/Passes.td
index f3c7800b7d8730..518450da1be99f 100644
--- a/mlir/include/mlir/Dialect/AMX/Passes.td
+++ b/mlir/include/mlir/Dialect/AMX/Passes.td
@@ -19,7 +19,7 @@ def EnableAMXTileBinding
     by propagating specified binding information 
     and automatically configuring harware context
   }];
-  let dependentDialects = ["func::FuncDialect"];
+  let dependentDialects = ["func::FuncDialect", "amx::AMXDialect", "LLVM::LLVMDialect"];
 }
 
 #endif // MLIR_DIALECT_AMX_PASSES_TD
diff --git a/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
index 26ba27ba389054..68e26db1b04770 100644
--- a/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
+++ b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
@@ -212,13 +212,14 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
   isValidAnalysis = true;
   // 0. First walk to mark parallel Ops.
   func->walk<WalkOrder::PostOrder>([=](Operation *op) {
-    if (!isParallelOp(op))
+    if (isBreakOp(op))
       return;
 
-    if (llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op)) {
+    if (llvm::isa<func::CallOp>(op) || llvm::isa<func::ReturnOp>(op) ||
+        llvm::isa<scf::ForallOp>(op) || llvm::isa<scf::ParallelOp>(op)) {
       while (op != root) {
-        addParallelOp(op);
-        LLVM_DEBUG(llvm::dbgs() << ">>>>> add parallel op: " << op << "\n");
+        addBreakOp(op);
+        LLVM_DEBUG(llvm::dbgs() << ">>>>> add break op: " << *op << "\n");
         op = op->getParentOp();
       }
     }
@@ -233,15 +234,25 @@ TileScopeAnalysis::TileScopeAnalysis(Operation *root) {
       return;
     if (!isTileOp(op))
       return;
-    Operation *lastUser = nullptr;
-    for (auto user : op->getUsers())
-      lastUser = user;
-    while (lastUser && op->getBlock() != lastUser->getBlock()) {
-      lastUser = lastUser->getParentOp();
-      if (!lastUser)
-        isValidAnalysis = false;
+    if (llvm::isa<TileStoreOp>(op)) {
+      LLVM_DEBUG(llvm::dbgs() << ">>>>> setTileUsage: " << *op << " ;;; " << *op
+                              << " ;;; " << *op << "\n");
+      setTileUsage(op, BlockSeg(Block::iterator(op), Block::iterator(op)));
+
+    } else {
+      Operation *lastUser = nullptr;
+      for (auto user : op->getUsers())
+        lastUser = user;
+      while (lastUser && op->getBlock() != lastUser->getBlock()) {
+        lastUser = lastUser->getParentOp();
+        if (!lastUser)
+          isValidAnalysis = false;
+      }
+      LLVM_DEBUG(llvm::dbgs() << ">>>>> setTileUsage: " << *op << " ;;; " << *op
+                              << " ;;; " << *lastUser << "\n");
+      setTileUsage(op,
+                   BlockSeg(Block::iterator(op), Block::iterator(lastUser)));
     }
-    setTileUsage(op, BlockSeg(Block::iterator(op), Block::iterator(lastUser)));
   });
   if (!isValidAnalysis)
     return;
@@ -410,7 +421,7 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
   auto currBegin = seg.begin();
   for (auto probe = seg.begin(); probe != seg.end(); probe++) {
     Operation *op = &(*probe);
-    if (isParallelOp(op)) {
+    if (isBreakOp(op)) {
       blockSegs.push_back(BlockSeg(currBegin, probe));
       paraOps.push_back(op);
       currBegin = probe;
@@ -448,19 +459,19 @@ void TileScopeAnalysis::doTileScope(BlockSeg seg) {
       nextIterIfMerge = currIter;
       nextIterIfMerge++;
     } else if (isTileOp(currOp)) {
-      auto iter = tileUsage.find(currOp);
-      if (iter == tileUsage.end()) {
+      auto usageIter = tileUsage.find(currOp);
+      if (usageIter == tileUsage.end()) {
         isValidAnalysis = false;
         return;
       }
-      pi = getPalette(iter->second);
+      pi = getPalette(usageIter->second);
       if (pi && pi->overflow) {
         // This means the binding info in tile Ops exceeds the hardware
         // capability.
         isValidAnalysis = false;
         return;
       }
-      nextIterIfMerge = iter->second.end();
+      nextIterIfMerge = usageIter->second.end();
       nextIterIfMerge++;
     }
     if (!pi) {
diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
index 2a37751d387e91..7ef0e0e449891e 100644
--- a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -180,7 +180,8 @@ struct EnableAMXTileBindingPass
     }
 
     std::string paletteSymName = "g_intel_amx_palette_";
-    uint8ArrayToHex(paletteSymName, paletteAsArray, paletteArraySize);
+    uint8ArrayToHex(paletteSymName, paletteAsArray, 2);
+    uint8ArrayToHex(paletteSymName, paletteAsArray + 16, 48);
 
     ModuleOp module = getOperation()->template getParentOfType<ModuleOp>();
     if (auto global = module.lookupSymbol<LLVM::GlobalOp>(paletteSymName))
@@ -239,8 +240,9 @@ struct EnableAMXTileBindingPass
     // 3. Insert tile config/release according to tile scopes.
     OpBuilder builder(getOperation());
     for (auto &scope : scopeAna.getTileScopes()) {
-      LLVM_DEBUG(llvm::dbgs() << ">>> Processing tile scope: "
-                              << *scope.seg.begin() << "\n");
+      LLVM_DEBUG(llvm::dbgs()
+                 << ">>> Processing tile scope: " << *scope.seg.begin()
+                 << " ;;; " << *scope.seg.end() << "\n");
       assert(!scope.pi.overflow && "Expecting legal AMX palette info");
       auto paletteGlobal = getOrCreateGlobalPalette(scope.pi);
       assert(paletteGlobal && "Failed to create global palette");

>From 39d349fcb57e208fb46b8d012756b8695daafbdd Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Tue, 13 Aug 2024 00:24:14 -0700
Subject: [PATCH 14/17] fix conversion to LLVM IR

---
 .../AMX/Transforms/LegalizeForLLVMExport.cpp  | 34 +++++++++++++++----
 1 file changed, 28 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index afff80640bddd4..fe9f8cebda4941 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -94,8 +94,14 @@ struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
       // Routine for lowering tile Ops with binding info.
       auto dstRegIndex = op.getDstRegIndex();
       assert(dstRegIndex && "Incomplete operation attribute for tile binding");
-      rewriter.create<amx::x86_amx_tilezero_plain>(op.getLoc(), *dstRegIndex);
-      rewriter.eraseOp(op);
+
+      Location loc = op.getLoc();
+      Value dstIndex = rewriter.create<LLVM::ConstantOp>(
+          loc, IntegerType::get(rewriter.getContext(), 8), *dstRegIndex);
+
+      rewriter.create<amx::x86_amx_tilezero_plain>(loc, *dstRegIndex);
+      rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+          op, op.getRes().getType(), dstIndex);
       return success();
     }
 
@@ -141,9 +147,15 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
       // Routine for lowering tile Ops with binding info.
       auto dstRegIndex = op.getDstRegIndex();
       assert(dstRegIndex && "Incomplete operation attribute for tile binding");
+
+      Location loc = op.getLoc();
+      Value dstIndex = rewriter.create<LLVM::ConstantOp>(
+          loc, IntegerType::get(rewriter.getContext(), 8), *dstRegIndex);
+
       rewriter.create<amx::x86_amx_tileloadd64_plain>(op.getLoc(), *dstRegIndex,
                                                       ptr, stride);
-      rewriter.eraseOp(op);
+      rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+          op, op.getRes().getType(), dstIndex);
       return success();
     }
 
@@ -231,9 +243,14 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
 
       assert(lhsRegIndex && rhsRegIndex && accRegIndex &&
              "Incomplete operation attribute for tile binding");
-      rewriter.create<amx::x86_amx_tdpbf16ps_plain>(op.getLoc(), *accRegIndex,
+      Location loc = op.getLoc();
+      Value accIndex = rewriter.create<LLVM::ConstantOp>(
+          loc, IntegerType::get(rewriter.getContext(), 8), *accRegIndex);
+
+      rewriter.create<amx::x86_amx_tdpbf16ps_plain>(loc, *accRegIndex,
                                                     *lhsRegIndex, *rhsRegIndex);
-      rewriter.eraseOp(op);
+      rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+          op, op.getRes().getType(), accIndex);
       return success();
     }
 
@@ -282,6 +299,10 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
 
       assert(lhsRegIndex && rhsRegIndex && accRegIndex &&
              "Incomplete operation attribute for tile binding");
+      Location loc = op.getLoc();
+      Value accIndex = rewriter.create<LLVM::ConstantOp>(
+          loc, IntegerType::get(rewriter.getContext(), 8), *accRegIndex);
+
       if (zexta && zextb)
         rewriter.create<amx::x86_amx_tdpbuud_plain>(op.getLoc(), *accRegIndex,
                                                     *lhsRegIndex, *rhsRegIndex);
@@ -294,7 +315,8 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
       else
         rewriter.create<amx::x86_amx_tdpbssd_plain>(op.getLoc(), *accRegIndex,
                                                     *lhsRegIndex, *rhsRegIndex);
-      rewriter.eraseOp(op);
+      rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+          op, op.getRes().getType(), accIndex);
       return success();
     }
 

>From ae47af275a2bef42da27c622005ce44e19fbbc07 Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Fri, 16 Aug 2024 02:46:19 -0700
Subject: [PATCH 15/17] fix parameter passing

---
 mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index fe9f8cebda4941..d8cccbfcbae592 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -76,7 +76,7 @@ Value getStride(ConversionPatternRewriter &rewriter,
 struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
 private:
   const std::optional<std::reference_wrapper<TileScopeAnalysis>>
-      &enablingAnalysis;
+      enablingAnalysis;
 
 public:
   using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern;
@@ -120,7 +120,7 @@ struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
 struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
 private:
   const std::optional<std::reference_wrapper<TileScopeAnalysis>>
-      &enablingAnalysis;
+      enablingAnalysis;
 
 public:
   using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern;
@@ -174,7 +174,7 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
 struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
 private:
   const std::optional<std::reference_wrapper<TileScopeAnalysis>>
-      &enablingAnalysis;
+      enablingAnalysis;
 
 public:
   using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;
@@ -221,7 +221,7 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
 struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
 private:
   const std::optional<std::reference_wrapper<TileScopeAnalysis>>
-      &enablingAnalysis;
+      enablingAnalysis;
 
 public:
   using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern;

>From 37b27986fc49f3dd19d271508304a5edc3e5210c Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Thu, 22 Aug 2024 02:36:32 -0700
Subject: [PATCH 16/17] remove dependency between enable-amx-binding and
 lowering

---
 mlir/include/mlir/Dialect/AMX/Transforms.h    |   6 +-
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |   3 +-
 .../AMX/Transforms/EnableAMXTileBinding.cpp   |   1 -
 .../AMX/Transforms/LegalizeForLLVMExport.cpp  | 131 ++++++------------
 4 files changed, 49 insertions(+), 92 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index 3379dcc7b491b9..0ce96ede533cf9 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -21,10 +21,8 @@ class RewritePatternSet;
 
 /// Collect a set of patterns to lower AMX ops to ops that map to LLVM
 /// intrinsics.
-void populateAMXLegalizeForLLVMExportPatterns(
-    LLVMTypeConverter &converter,
-    std::optional<std::reference_wrapper<amx::TileScopeAnalysis>> &analysis,
-    RewritePatternSet &patterns);
+void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
+                                              RewritePatternSet &patterns);
 
 /// Configure the target to support lowering AMX ops to ops that map to LLVM
 /// intrinsics.
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 0329a5fa4cb648..89cfca5d9aba37 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -106,9 +106,8 @@ void LowerVectorToLLVMPass::runOnOperation() {
     populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
   }
   if (amx) {
-    auto analysis = getCachedAnalysis<amx::TileScopeAnalysis>();
     configureAMXLegalizeForExportTarget(target);
-    populateAMXLegalizeForLLVMExportPatterns(converter, analysis, patterns);
+    populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
   }
   if (x86Vector) {
     configureX86VectorLegalizeForExportTarget(target);
diff --git a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
index 7ef0e0e449891e..a271ec5dab57e1 100644
--- a/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/EnableAMXTileBinding.cpp
@@ -260,7 +260,6 @@ struct EnableAMXTileBindingPass
       builder.setInsertionPointAfter(end);
       builder.create<amx::x86_amx_tilerelease_plain>(loc);
     }
-    markAnalysesPreserved<TileScopeAnalysis>();
   }
 };
 
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index d8cccbfcbae592..e4c04b56b1ceda 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -74,61 +74,46 @@ Value getStride(ConversionPatternRewriter &rewriter,
 }
 
 struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
-private:
-  const std::optional<std::reference_wrapper<TileScopeAnalysis>>
-      enablingAnalysis;
-
 public:
   using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern;
-  TileZeroConversion(
-      const LLVMTypeConverter &typeConverter,
-      const std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis)
-      : ConvertOpToLLVMPattern<TileZeroOp>(typeConverter),
-        enablingAnalysis(analysis) {}
+  TileZeroConversion(const LLVMTypeConverter &typeConverter)
+      : ConvertOpToLLVMPattern<TileZeroOp>(typeConverter) {}
 
   LogicalResult
   matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (enablingAnalysis && enablingAnalysis->get().isValid()) {
-      rewriter.setInsertionPoint(op);
-      // Routine for lowering tile Ops with binding info.
       auto dstRegIndex = op.getDstRegIndex();
-      assert(dstRegIndex && "Incomplete operation attribute for tile binding");
-
-      Location loc = op.getLoc();
-      Value dstIndex = rewriter.create<LLVM::ConstantOp>(
-          loc, IntegerType::get(rewriter.getContext(), 8), *dstRegIndex);
-
-      rewriter.create<amx::x86_amx_tilezero_plain>(loc, *dstRegIndex);
-      rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
-          op, op.getRes().getType(), dstIndex);
+      if (dstRegIndex) {
+        // Routine for lowering tile Ops with binding info.
+        rewriter.setInsertionPoint(op);
+
+        Location loc = op.getLoc();
+        Value dstIndex = rewriter.create<LLVM::ConstantOp>(
+            loc, IntegerType::get(rewriter.getContext(), 8), *dstRegIndex);
+
+        rewriter.create<amx::x86_amx_tilezero_plain>(loc, *dstRegIndex);
+        rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+            op, op.getRes().getType(), dstIndex);
+        return success();
+      }
+
+      VectorType vType = op.getVectorType();
+      // Determine m x n tile sizes.
+      std::pair<Value, Value> tsz =
+          getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+      // Replace operation with intrinsic.
+      Type resType = typeConverter->convertType(vType);
+      rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
+                                                         tsz.second);
       return success();
-    }
-
-    VectorType vType = op.getVectorType();
-    // Determine m x n tile sizes.
-    std::pair<Value, Value> tsz =
-        getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
-    // Replace operation with intrinsic.
-    Type resType = typeConverter->convertType(vType);
-    rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
-                                                       tsz.second);
-    return success();
   }
 };
 
 struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
-private:
-  const std::optional<std::reference_wrapper<TileScopeAnalysis>>
-      enablingAnalysis;
-
 public:
   using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern;
-  TileLoadConversion(
-      const LLVMTypeConverter &typeConverter,
-      const std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis)
-      : ConvertOpToLLVMPattern<TileLoadOp>(typeConverter),
-        enablingAnalysis(analysis) {}
+  TileLoadConversion(const LLVMTypeConverter &typeConverter)
+      : ConvertOpToLLVMPattern<TileLoadOp>(typeConverter) {}
 
   LogicalResult
   matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
@@ -142,11 +127,10 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
     Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
 
-    if (enablingAnalysis && enablingAnalysis->get().isValid()) {
-      rewriter.setInsertionPoint(op);
+    auto dstRegIndex = op.getDstRegIndex();
+    if (dstRegIndex) {
       // Routine for lowering tile Ops with binding info.
-      auto dstRegIndex = op.getDstRegIndex();
-      assert(dstRegIndex && "Incomplete operation attribute for tile binding");
+      rewriter.setInsertionPoint(op);
 
       Location loc = op.getLoc();
       Value dstIndex = rewriter.create<LLVM::ConstantOp>(
@@ -172,17 +156,10 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
 };
 
 struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
-private:
-  const std::optional<std::reference_wrapper<TileScopeAnalysis>>
-      enablingAnalysis;
-
 public:
   using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;
-  TileStoreConversion(
-      const LLVMTypeConverter &typeConverter,
-      const std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis)
-      : ConvertOpToLLVMPattern<TileStoreOp>(typeConverter),
-        enablingAnalysis(analysis) {}
+  TileStoreConversion(const LLVMTypeConverter &typeConverter)
+      : ConvertOpToLLVMPattern<TileStoreOp>(typeConverter) {}
 
   LogicalResult
   matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
@@ -196,11 +173,10 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
     Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
                                      adaptor.getIndices(), rewriter);
 
-    if (enablingAnalysis && enablingAnalysis->get().isValid()) {
-      rewriter.setInsertionPoint(op);
+    auto srcRegIndex = op.getSrcRegIndex();
+    if (srcRegIndex) {
       // Routine for lowering tile Ops with binding info.
-      auto srcRegIndex = op.getSrcRegIndex();
-      assert(srcRegIndex && "Incomplete operation attribute for tile binding");
+      rewriter.setInsertionPoint(op);
       rewriter.create<amx::x86_amx_tilestored64_plain>(
           op.getLoc(), *srcRegIndex, ptr, stride);
       rewriter.eraseOp(op);
@@ -219,29 +195,22 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
 };
 
 struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
-private:
-  const std::optional<std::reference_wrapper<TileScopeAnalysis>>
-      enablingAnalysis;
-
 public:
   using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern;
-  TileMulFConversion(
-      const LLVMTypeConverter &typeConverter,
-      const std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis)
-      : ConvertOpToLLVMPattern<TileMulFOp>(typeConverter),
-        enablingAnalysis(analysis) {}
+  TileMulFConversion(const LLVMTypeConverter &typeConverter)
+      : ConvertOpToLLVMPattern<TileMulFOp>(typeConverter) {}
 
   LogicalResult
   matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (enablingAnalysis && enablingAnalysis->get().isValid()) {
-      rewriter.setInsertionPoint(op);
+    auto accRegIndex = op.getAccRegIndex();
+    if (accRegIndex) {
       // Routine for lowering tile Ops with binding info.
+      rewriter.setInsertionPoint(op);
       auto lhsRegIndex = op.getLhsRegIndex();
       auto rhsRegIndex = op.getRhsRegIndex();
-      auto accRegIndex = op.getAccRegIndex();
 
-      assert(lhsRegIndex && rhsRegIndex && accRegIndex &&
+      assert(lhsRegIndex && rhsRegIndex &&
              "Incomplete operation attribute for tile binding");
       Location loc = op.getLoc();
       Value accIndex = rewriter.create<LLVM::ConstantOp>(
@@ -272,17 +241,10 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
 };
 
 struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
-private:
-  const std::optional<std::reference_wrapper<TileScopeAnalysis>>
-      &enablingAnalysis;
-
 public:
   using ConvertOpToLLVMPattern<TileMulIOp>::ConvertOpToLLVMPattern;
-  TileMulIConversion(
-      const LLVMTypeConverter &typeConverter,
-      const std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis)
-      : ConvertOpToLLVMPattern<TileMulIOp>(typeConverter),
-        enablingAnalysis(analysis) {}
+  TileMulIConversion(const LLVMTypeConverter &typeConverter)
+      : ConvertOpToLLVMPattern<TileMulIOp>(typeConverter) {}
 
   LogicalResult
   matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
@@ -290,14 +252,14 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
     bool zexta = op.getIsZextLhs();
     bool zextb = op.getIsZextRhs();
 
-    if (enablingAnalysis && enablingAnalysis->get().isValid()) {
+    auto accRegIndex = op.getAccRegIndex();
+    if (accRegIndex) {
       rewriter.setInsertionPoint(op);
       // Routine for lowering tile Ops with binding info.
       auto lhsRegIndex = op.getLhsRegIndex();
       auto rhsRegIndex = op.getRhsRegIndex();
-      auto accRegIndex = op.getAccRegIndex();
 
-      assert(lhsRegIndex && rhsRegIndex && accRegIndex &&
+      assert(lhsRegIndex && rhsRegIndex &&
              "Incomplete operation attribute for tile binding");
       Location loc = op.getLoc();
       Value accIndex = rewriter.create<LLVM::ConstantOp>(
@@ -354,10 +316,9 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
 
 void mlir::populateAMXLegalizeForLLVMExportPatterns(
     LLVMTypeConverter &converter,
-    std::optional<std::reference_wrapper<TileScopeAnalysis>> &analysis,
     RewritePatternSet &patterns) {
   patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
-               TileMulFConversion, TileMulIConversion>(converter, analysis);
+               TileMulFConversion, TileMulIConversion>(converter);
 }
 
 void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {

>From d910874aca188488dd0b2a7770c23c963f580cf4 Mon Sep 17 00:00:00 2001
From: "Huang, Haixin" <haixin.huang at intel.com>
Date: Tue, 3 Sep 2024 00:30:35 -0700
Subject: [PATCH 17/17] tile processing macro minor fix

---
 mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
index 68e26db1b04770..06154b05c0e414 100644
--- a/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
+++ b/mlir/lib/Dialect/AMX/Analysis/AMXBindingAnalysis.cpp
@@ -354,9 +354,9 @@ TileScopeAnalysis::collectPaletteForTile(Operation *op) {
   pi.set(*index, getPaletteShape(op.getVectorType()));
 
 #define PROCESS_TRINARY_TILE_OP(op)                                            \
-  auto lhsIndex = tileMulFOp.getLhsRegIndex();                                 \
-  auto rhsIndex = tileMulFOp.getRhsRegIndex();                                 \
-  auto accIndex = tileMulFOp.getAccRegIndex();                                 \
+  auto lhsIndex = op.getLhsRegIndex();                                         \
+  auto rhsIndex = op.getRhsRegIndex();                                         \
+  auto accIndex = op.getAccRegIndex();                                         \
   if (!lhsIndex || !rhsIndex || !accIndex) {                                   \
     isValidAnalysis = false;                                                   \
     return PaletteInfo();                                                      \



More information about the Mlir-commits mailing list