[Mlir-commits] [mlir] [mlir][ArmSME] Support vertical layout in load and store ops (PR #66758)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 19 03:24:20 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
<details>
<summary>Changes</summary>
In SME, a ZA tile slice is a one-dimensional set of horizontally or vertically contiguous elements within a ZA tile. Currently, the load and store ops only support horizontal tile slices. This patch adds a tile slice layout attribute to the load and store ops to support both horizontal and vertical tile slices.
When lowering from the Vector dialect, horizontal layout is the default.
---
Patch is 98.44 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66758.diff
14 Files Affected:
- (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h (+5)
- (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td (+77-39)
- (modified) mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt (+7)
- (modified) mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp (+9-9)
- (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+8-6)
- (modified) mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp (+12)
- (modified) mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt (+1)
- (modified) mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp (+71-22)
- (modified) mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir (+31-11)
- (modified) mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir (+3-3)
- (added) mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir (+401)
- (modified) mlir/test/Dialect/ArmSME/roundtrip.mlir (+450-90)
- (modified) mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir (+8-8)
- (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir (+110)
``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index d1ed02abfd5c552..f947fc8fe1631b8 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -21,6 +21,11 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
+
#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h.inc"
#define GET_OP_CLASSES
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 7f02e723f3d91c2..1a4984f3bd6ba27 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -14,6 +14,7 @@
#ifndef ARMSME_OPS
#define ARMSME_OPS
+include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
@@ -36,6 +37,7 @@ def ArmSME_Dialect : Dialect {
https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
}];
let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect"];
+ let useDefaultAttributePrinterParser = 1;
}
//===----------------------------------------------------------------------===//
@@ -83,6 +85,24 @@ def TileElementWidthMatchesTileID : TypesMatchWith<
"::llvm::cast<VectorType>($_self).getElementType())"
".getWidth())">;
+//===----------------------------------------------------------------------===//
+// ArmSME attr definitions
+//===----------------------------------------------------------------------===//
+
+def TileSliceLayout : I32EnumAttr<"TileSliceLayout", "Layout of a tile slice", [
+ I32EnumAttrCase<"Horizontal", 0, "hor">,
+ I32EnumAttrCase<"Vertical", 1, "ver">,
+]> {
+ let cppNamespace = "::mlir::arm_sme";
+ let genSpecializedAttr = 0;
+}
+
+/// An attribute that specifies the layout of a tile slice in a tile.
+def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
+ "layout"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
//===----------------------------------------------------------------------===//
// ArmSME op definitions
//===----------------------------------------------------------------------===//
@@ -239,27 +259,33 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
let summary = "Tile load operation";
let description = [{
Loads a 2D SME "virtual tile" from memory defined by a base and indices,
- with the shape defined by the 2D scalable vector type of the result tile.
- The slice of memory must be contiguous. The memref must be either rank 1 or
- rank 2 with dynamic dimensions, since the operation is scalable, and the
- element type must be a scalar that matches the element type of the result.
+ with the shape defined by the 2D scalable vector type of the result tile. A
+ tile slice layout attribute specifies whether the slices of the tile being
+ loaded are horizontal or vertical. The slice of memory must be contiguous.
+ The memref must be either rank 1 or rank 2 with dynamic dimensions, since
+ the operation is scalable, and the element type must be a scalar that
+ matches the element type of the result.
+
+ The default tile slice layout when lowering from higher-level dialects is
+ horizontal.
- Example 1: Load an 8-bit element ZA tile from memory (ZA0.B).
+ Example 1: Load an 8-bit element ZA tile with horizontal layout from memory (ZA0.B).
```mlir
- %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %tile = arm_sme.tile_load <hor>, %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
```
- Example 2: Load a FP 32-bit element ZA tile from memory.
+ Example 2: Load a FP 32-bit element ZA tile with vertical layout from memory.
```mlir
- %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %tile = arm_sme.tile_load <ver>, %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
```
- Example 3: Load a 128-bit element ZA tile from memory.
+ Example 3: Load a 128-bit element ZA tile with horizontal layout from memory.
```mlir
- %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+ %tile = arm_sme.tile_load <hor>, %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
```
}];
let arguments = (ins
+ ArmSME_TileSliceLayoutAttr:$layout,
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
Variadic<Index>:$indices);
let results = (outs SMETile:$result);
@@ -274,7 +300,8 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
}];
let assemblyFormat =
- "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
+ "$layout `,` $base `[` $indices `]` attr-dict "
+ "`:` type($base) `,` type($result)";
}
def TileStoreOp : ArmSME_Op<"tile_store"> {
@@ -282,27 +309,32 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
let description = [{
Stores a 2D SME "virtual tile" to memory defined by a base and indices,
with the shape defined by the 2D scalable vector type of the tile being
- stored. The slice of memory must be contiguous. The memref must be either
- rank 1 or rank 2 with dynamic dimensions, since the operation is scalable,
- and the element type must be a scalar that matches the element type of the
- result.
+ stored. A tile slice layout attribute specifies whether the slices of the
+ tile being stored are horizontal or vertical. The slice of memory must be
+ contiguous. The memref must be either rank 1 or rank 2 with dynamic
+ dimensions, since the operation is scalable, and the element type must be a
+ scalar that matches the element type of the result.
+
+ The default tile slice layout when lowering from higher-level dialects is
+ horizontal.
- Example 1: Store an 8-bit element ZA tile to memory (ZA0.B).
+ Example 1: Store an 8-bit element ZA tile with horizontal layout to memory (ZA0.B).
```mlir
- arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
+ arm_sme.tile_store %tile, <hor>, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
```
- Example 2: Store a FP 32-bit element ZA tile to memory.
+ Example 2: Store a FP 32-bit element ZA tile with vertical layout to memory.
```mlir
- arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+ arm_sme.tile_store %tile, <ver>, %base[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
```
- Example 3: Store a 128-bit element ZA tile to memory.
+ Example 3: Store a 128-bit element ZA tile with horizontal layout to memory.
```mlir
- arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
+ arm_sme.tile_store %tile, <hor>, %base[%c0, %c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
```
}];
let arguments = (ins SMETile:$valueToStore,
+ ArmSME_TileSliceLayoutAttr:$layout,
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices);
let extraClassDeclaration = [{
@@ -314,8 +346,9 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
}
}];
- let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
- "`:` type($base) `,` type($valueToStore)";
+ let assemblyFormat =
+ "$valueToStore `,` $layout `,` $base `[` $indices `]` attr-dict "
+ "`:` type($base) `,` type($valueToStore)";
}
def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
@@ -326,29 +359,32 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
Loads a 1D tile slice from memory into a 2D SME "virtual tile". The tile
slice is defined by the dimension of the 2D scalable vector type pointed by
the index. A tile slice index describes where in the input tile the tile
- slice is loaded to. The updated tile is returned as the result.
+ slice is loaded to. A tile slice layout attribute specifies whether the
+ tile slice being loaded at the given index is horizontal or vertical. The
+ updated tile is returned as the result.
The slice of memory read is defined by a base and indices and must be
contiguous. The memref must be either rank 1 or rank 2, have dynamic
dimensions since the operation is scalable, and the element type must be a
scalar that matches the element type of the result.
- Example 1: Load a vector<[16]xi8> tile slice from memory into tile at given index.
+ Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally at given index.
```mlir
- %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %tile_update = arm_sme.load_tile_slice <hor>, %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
```
- Example 2: Load a vector<[4]xf32> tile slice from memory into tile at given index.
+ Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
```mlir
- %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %tile_update = arm_sme.load_tile_slice <ver>, %base[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
```
- Example 3: Load a vector<[1]xi128> tile slice from memory into tile at given index.
+ Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
```mlir
- %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+ %tile_update = arm_sme.load_tile_slice <ver>, %base[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
```
}];
let arguments = (ins
+ ArmSME_TileSliceLayoutAttr:$layout,
Arg<AnyMemRef, "the reference to load from">:$base,
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index);
let results = (outs SMETile:$result);
@@ -363,7 +399,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
}];
let assemblyFormat = [{
- $base `[` $indices `]` `,` $tile `,` $tile_slice_index
+ $layout `,` $base `[` $indices `]` `,` $tile `,` $tile_slice_index
attr-dict `:` type($base) `,` type($result)
}];
}
@@ -374,29 +410,31 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
slice is defined by the dimension of the 2D scalable vector type pointed by
the index. A tile slice index describes where in the input tile the tile
- slice is stored from.
+ slice is stored from. A tile slice layout attribute specifies whether the
+ tile slice being stored from the given index is horizontal or vertical.
The slice of memory written is defined by a base and indices and must be
contiguous. The memref must be either rank 1 or rank 2, have dynamic
dimensions since the operation is scalable, and the element type must be a
scalar that matches the element type of the input tile.
- Example 1: Store vector<[16]xi8> tile slice from tile at given index to memory.
+ Example 1: Store vector<[16]xi8> horizontal tile slice from tile at given index to memory.
```mlir
- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
```
- Example 2: Store vector<[4]xf32> tile slice from tile at given index to memory.
+ Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
```mlir
- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %base[%c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
```
- Example 3: Store a vector<[1]xi128> tile slice from tile at given index to memory.
+ Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
```mlir
- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %base[%c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
```
}];
let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
+ ArmSME_TileSliceLayoutAttr:$layout,
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices);
let extraClassDeclaration = [{
@@ -409,7 +447,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
}];
let assemblyFormat = [{
- $tile `,` $tile_slice_index `,` $base `[` $indices `]`
+ $tile `,` $tile_slice_index `,` $layout `,` $base `[` $indices `]`
attr-dict `:` type($base) `,` type($tile)
}];
}
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
index d20ee65e62e7dc0..7afd0d014541687 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
@@ -4,3 +4,10 @@ add_mlir_doc(ArmSME ArmSME Dialects/ -gen-dialect-doc -dialect=arm_sme)
set(LLVM_TARGET_DEFINITIONS ArmSME.td)
mlir_tablegen(ArmSMEConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRArmSMEConversionsIncGen)
+
+mlir_tablegen(ArmSMEEnums.h.inc -gen-enum-decls)
+mlir_tablegen(ArmSMEEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(ArmSMEAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=arm_sme)
+mlir_tablegen(ArmSMEAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=arm_sme)
+add_public_tablegen_target(MLIRArmSMEEnumsIncGen)
+add_dependencies(mlir-headers MLIRArmSMEEnumsIncGen)
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 4028a7ad0870b51..86cabe67f2695f1 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -54,7 +54,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
///
/// BEFORE:
/// ```mlir
-/// %tile = arm_sme.tile_load %src[%c0, %c0] :
+/// %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] :
/// memref<?x?xi32>, vector<[4]x[4]xi32>
/// ```
///
@@ -68,7 +68,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
/// %min_svl_s = arith.constant 4 : index
/// %svl_s = arith.muli %min_svl_s, %vscale : index
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
-/// %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
+/// %tile_update = arm_sme.load_tile_slice <hor>, %src[%tile_slice_idx],
/// %tile, %tile_slice_idx : memref<?x?xi32>, vector<[4]x[4]xi32>
/// }
/// ```
@@ -116,9 +116,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
getMemrefIndices(tileLoadOp.getIndices(),
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
numTileSlices, memrefIndices, loc, rewriter);
- rewriter.create<arm_sme::LoadTileSliceOp>(loc, tileType,
- tileLoadOp.getBase(), tile,
- memrefIndices, tileSliceIndex);
+ rewriter.create<arm_sme::LoadTileSliceOp>(
+ loc, tileType, tileLoadOp.getLayout(), tileLoadOp.getBase(), tile,
+ memrefIndices, tileSliceIndex);
rewriter.setInsertionPointAfter(forOp);
@@ -134,7 +134,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
///
/// BEFORE:
/// ```mlir
-/// arm_sme.tile_store %tile, %dest[%c0, %c0]
+/// arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0]
/// : memref<?x?xi32>, vector<[4]x[4]xi32
/// ```
///
@@ -146,8 +146,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
/// %min_svl_s = arith.constant 4 : index
/// %svl_s = arith.muli %min_svl_s, %vscale : index
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
-/// arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx]
-/// : memref<?x?xi32>, vector<[4]x[4]xi32>
+/// arm_sme.store_tile_slice %tile, %tile_slice_idx, <ver>,
+/// %dest[%tile_slice_idx] : memref<?x?xi32>, vector<[4]x[4]xi32>
/// }
/// ```
struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
@@ -184,7 +184,7 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
numTileSlices, memrefIndices, loc, rewriter);
rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
- tileStoreOp.getBase(), memrefIndices);
+ tileStoreOp.getLayout(), tileStoreOp.getBase(), memrefIndices);
return success();
}
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 0a1a087d9c8d6c7..feaec0e035ed9fd 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -65,8 +65,8 @@ namespace {
///
/// is converted to:
///
-/// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
-/// vector<[16]x[16]xi8>
+/// arm_sme.tile_store %vector, <hor>, %source[%c0, %c0]
+/// : memref<?x?xi8>, vector<[16]x[16]xi8>
struct TransferWriteToArmSMELowering
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
@@ -81,8 +81,8 @@ struct TransferWriteToArmSMELowering
return failure();
rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
- writeOp, writeOp.getVector(), writeOp.getSource(),
- writeOp.getIndices());
+ writeOp, writeOp.getVector(), arm_sme::TileSliceLayout::Horizontal,
+ writeOp.getSource(), writeOp.getIndices());
return success();
}
};
@@ -97,7 +97,8 @@ struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
return failure();
rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
- load, load.getVectorType(), load.getBase(), load.getIndices());
+ load, load.getVectorType(), arm_sme::TileSliceLayout::Horizontal,
+ load.getBase(), load.getIndices());
return success();
}
@@ -113,7 +114,8 @@ struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
return failure();
rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
- store, store.getValueToStore(), store.getBase(), store.getIndices());
+ store, store.getValueToStore(), arm_sme::TileSliceLayout::Horizontal,
+ store.getBase(), store.getIndices());
return success();
}
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 750627421215dfb..92fb146691a0beb 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -12,6 +12,8 @@
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::arm_sme;
@@ -22,13 +24,23 @@ using namespace mlir::arm_sme;
#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.cpp.inc"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.cpp.inc"
+
#define GET_OP_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSMETypes.cpp.inc"
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.cpp.inc"
+
void ArmSMEDialect::initialize() {
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.cpp.inc"
+ >();
+
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
diff --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
index 9b6332a478ad...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/66758
More information about the Mlir-commits mailing list