[Mlir-commits] [mlir] [mlir][ArmSME] Switch to an attribute-based tile allocation scheme (PR #73253)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 23 08:43:36 PST 2023


llvmbot wrote:


@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sme

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

This reworks the ArmSME dialect to use attributes for tile allocation. This has a number of advantages and corrects some issues with the previous approach:

 * Tile allocation can now be done ASAP (i.e. immediately after `-convert-vector-to-arm-sme`)
 * SSA form for control flow is now supported (e.g. `scf.for` loops that yield tiles)
 * ArmSME ops can be converted to intrinsics very late (i.e. after lowering to control flow)
 * Tests are simplified by removing constants and casts
 * Avoids correctness issues with representing LLVM `immargs` as MLIR values
    - The tile ID on the SME intrinsics is an `immarg` (so it is required to be a compile-time constant); `immargs` should be mapped to MLIR attributes (this is already the case for intrinsics in the LLVM dialect)
    - Using MLIR values for `immargs` can lead to invalid LLVM IR being generated (and to passes such as `-cse` making incorrect optimizations)

As part of this patch we bid farewell to the following operations:

```mlir
arm_sme.get_tile_id : i32
arm_sme.cast_tile_to_vector : i32 to vector<[4]x[4]xi32>
arm_sme.cast_vector_to_tile : vector<[4]x[4]xi32> to i32
```

These are now replaced with:
```mlir
// Allocates a new tile with (indeterminate) state:
arm_sme.get_tile : vector<[4]x[4]xi32>
// A placeholder operation for lowering ArmSME ops to intrinsics:
arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32>
```

The new tile allocation is driven by operations that implement the `ArmSMETileOpInterface`. This interface marks an operation as needing to be assigned a tile ID, and lets it (conditionally) allocate a new SME tile.

Operations allocate a new tile by implementing...
```c++
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType()
```
...and returning what type of tile the op allocates (ZAB, ZAH, etc).

Operations that don't allocate a tile return `std::nullopt` (which is the default behaviour).
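To make the "allocate or just use" split concrete, here is a minimal, self-contained C++ model of the interface idea. It is not the real MLIR `ArmSMETileOpInterface`: `TileOp`, `GetTileOp`, `TileStoreOp` and `tileTypeForElementWidth` are hypothetical names, and the width-to-tile-type mapping is assumed from the tile ID table in the interface description (8-bit elements use ZAB, 16-bit use ZAH, and so on up to 128-bit using ZAQ).

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for mlir::arm_sme::ArmSMETileType (the real enum is
// generated from TableGen in this patch).
enum class ArmSMETileType { ZAB, ZAH, ZAS, ZAD, ZAQ };

// Simplified model of the interface; NOT the real MLIR interface.
struct TileOp {
  virtual ~TileOp() = default;
  std::optional<int32_t> tileId; // models the "tile_id" attribute
  // Default behaviour: the op uses a tile but does not allocate one.
  virtual std::optional<ArmSMETileType> getAllocatedTileType() const {
    return std::nullopt;
  }
};

// Hypothetical helper: maps an element bit width to the tile type holding it.
std::optional<ArmSMETileType> tileTypeForElementWidth(unsigned bits) {
  switch (bits) {
  case 8: return ArmSMETileType::ZAB;
  case 16: return ArmSMETileType::ZAH;
  case 32: return ArmSMETileType::ZAS;
  case 64: return ArmSMETileType::ZAD;
  case 128: return ArmSMETileType::ZAQ;
  default: return std::nullopt; // not a valid SME tile element width
  }
}

// Models arm_sme.get_tile: always allocates a tile matching its result type.
struct GetTileOp : TileOp {
  unsigned elementBits;
  explicit GetTileOp(unsigned bits) : elementBits(bits) {}
  std::optional<ArmSMETileType> getAllocatedTileType() const override {
    return tileTypeForElementWidth(elementBits);
  }
};

// Models a non-allocating tile user that keeps the default implementation.
struct TileStoreOp : TileOp {};
```

Only ops that override the method act as allocation roots; everything else inherits the `std::nullopt` default, matching the behaviour described above.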

Currently the following ops are defined as allocating:
```mlir
arm_sme.get_tile
arm_sme.zero
arm_sme.tile_load
arm_sme.outerproduct // (if no accumulator is specified)
```
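The rule for which ops act as allocation roots, including the conditional case for `arm_sme.outerproduct`, can be summarised in a small C++ sketch. `OpKind` and `allocatesTile` are hypothetical names used only for illustration; `TileStore` stands in for any tile user that never allocates.

```cpp
#include <cassert>

// Hypothetical op kinds; only the first four appear in the list above.
enum class OpKind { GetTile, Zero, TileLoad, OuterProduct, TileStore };

// Returns true if an op of this kind would be a root for tile allocation.
bool allocatesTile(OpKind kind, bool hasAccumulator = false) {
  switch (kind) {
  case OpKind::GetTile:
  case OpKind::Zero:
  case OpKind::TileLoad:
    return true;
  case OpKind::OuterProduct:
    // With an accumulator, the result is a (transitive) use of the
    // accumulator's tile, so no new tile is needed.
    return !hasAccumulator;
  default:
    return false; // plain tile users do not allocate
  }
}
```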

Allocating operations become the roots for the tile allocation pass, which currently just (naively) assigns all transitive uses of a root operation the same tile ID. However, this is enough to handle current use cases.
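The naive scheme can be sketched over a toy op graph. This is an illustration under stated assumptions, not the code in `TileAllocation.cpp`: `Op`, `numTiles` and `allocateTiles` are hypothetical. The per-type tile counts are taken from the tile ID table in the interface description. Unlike a real allocator, this sketch tracks each tile type independently and ignores the fact that tiles of different types alias within ZA storage.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

enum class ArmSMETileType { ZAB = 0, ZAH, ZAS, ZAD, ZAQ };

// Tiles available per type: ZAB 1, ZAH 2, ZAS 4, ZAD 8, ZAQ 16.
int numTiles(ArmSMETileType t) { return 1 << static_cast<int>(t); }

// Hypothetical, simplified op graph node (not an MLIR type).
struct Op {
  std::optional<ArmSMETileType> allocates; // set for allocating (root) ops
  std::vector<std::size_t> users;          // indices of ops using the result
  std::optional<int> tileId;               // filled in by allocateTiles
};

// Each root takes the next free ID for its tile type, and every transitive
// user of the root is assigned that same ID. Returns false when a tile type
// runs out of IDs.
bool allocateTiles(std::vector<Op> &ops) {
  int nextId[5] = {0, 0, 0, 0, 0};
  for (Op &root : ops) {
    if (!root.allocates || root.tileId)
      continue; // not a root, or already reached from an earlier root
    int &next = nextId[static_cast<int>(*root.allocates)];
    if (next >= numTiles(*root.allocates))
      return false; // out of tiles of this type
    int id = next++;
    // Depth-first walk over the transitive users of the root.
    std::vector<Op *> worklist{&root};
    while (!worklist.empty()) {
      Op *op = worklist.back();
      worklist.pop_back();
      if (op->tileId)
        continue; // already visited
      op->tileId = id;
      for (std::size_t u : op->users)
        worklist.push_back(&ops[u]);
    }
  }
  return true;
}
```

For example, a chain `get_tile -> outerproduct (with accumulator) -> store` ends up with a single tile ID, while a second, independent 32-bit root receives the next ZAS ID.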

Once tile IDs have been allocated, subsequent rewrites can forward the tile IDs to any newly created operations.
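Forwarding mirrors the intent of the `createOpAndForwardTileId` helper added to the interface in this patch: build a replacement op and copy the source op's tile ID onto it, doing nothing if no ID was assigned. The sketch below is a standalone model with hypothetical types (`Op`, `Rewriter`, `createAndForwardTileId`); it does not use the MLIR API and, for simplicity, lets every op carry a tile ID, whereas the real helper only forwards to ops that implement the interface.

```cpp
#include <cassert>
#include <deque>
#include <optional>
#include <string>
#include <utility>

// Hypothetical op record standing in for an MLIR operation.
struct Op {
  std::string name;
  std::optional<int> tileId; // models the "tile_id" attribute
};

// Hypothetical rewriter that owns created ops. A deque keeps references to
// earlier elements valid across push_back.
struct Rewriter {
  std::deque<Op> created;
  Op &create(std::string name) {
    created.push_back(Op{std::move(name), std::nullopt});
    return created.back();
  }
};

// Create a new op while lowering `src`, copying src's tile ID (if any).
Op &createAndForwardTileId(Rewriter &rewriter, const Op &src,
                           std::string name) {
  Op &op = rewriter.create(std::move(name));
  if (src.tileId) // a missing ID is simply not forwarded
    op.tileId = src.tileId;
  return op;
}
```

Because the ID travels with each rewrite, ops can be lowered to intrinsics late in the pipeline without re-running allocation.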

---

Patch is 279.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/73253.diff


41 Files Affected:

- (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h (+3-1) 
- (added) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h (+16) 
- (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td (+38-14) 
- (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td (+152-131) 
- (modified) mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt (+6) 
- (modified) mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h (+6-10) 
- (modified) mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp (+99-98) 
- (modified) mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp (+19-28) 
- (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+3-20) 
- (modified) mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp (+2-20) 
- (modified) mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt (+1) 
- (modified) mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp (+60-45) 
- (modified) mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt (-2) 
- (modified) mlir/lib/Dialect/ArmSME/Utils/Utils.cpp (+24-20) 
- (modified) mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir (+4-6) 
- (modified) mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir (+2-4) 
- (removed) mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir (-51) 
- (modified) mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir (+147-105) 
- (modified) mlir/test/Dialect/ArmSME/canonicalize.mlir (+19-21) 
- (modified) mlir/test/Dialect/ArmSME/cse.mlir (+25-11) 
- (modified) mlir/test/Dialect/ArmSME/invalid.mlir (+11-51) 
- (modified) mlir/test/Dialect/ArmSME/roundtrip.mlir (+37-156) 
- (modified) mlir/test/Dialect/ArmSME/tile-allocation.mlir (+185-191) 
- (modified) mlir/test/Dialect/ArmSME/tile-zero-masks.mlir (+16-86) 
- (modified) mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir (+200-210) 
- (modified) mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir (+2-4) 
- (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir (+2-3) 
- (modified) mlir/test/Target/LLVMIR/arm-sme-invalid.mlir (+6-9) 
- (modified) mlir/test/Target/LLVMIR/arm-sme.mlir (+161-170) 
- (modified) mlir/tools/mlir-query/mlir-query.cpp (+3-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index fe1f9062a37ef51..1da8e488a4c4647 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -14,6 +14,8 @@
 #define MLIR_DIALECT_ARMSME_IR_ARMSME_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -22,7 +24,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
-#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
 
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h
new file mode 100644
index 000000000000000..430f3571001c8f4
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h
@@ -0,0 +1,16 @@
+//===- ArmSMEDialect.h - Arm SME Dialect Enums ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_ENUMS_H
+#define MLIR_DIALECT_ARMSME_ENUMS_H
+
+#include "mlir/IR/Dialect.h"
+
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
+
+#endif
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index b75918ebf2f6d9c..2a0167afa8bae9e 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -54,7 +54,10 @@ def MOPVector : ScalableVectorOfRankAndLengthAndType<[1], [16, 8, 4, 2],
   }];
 }
 
-class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
+class ArmSME_IntrOp<string mnemonic,
+                    list<int> immArgPositions = [],
+                    list<string> immArgAttrNames = [],
+                    list<int> overloadedOperands = [],
                     list<Trait> traits = [], int numResults = 0,
                     list<int> overloadedResults = []>
     : LLVM_IntrOpBase<
@@ -64,16 +67,26 @@ class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
           /*list<int> overloadedResults=*/overloadedResults,
           /*list<int> overloadedOperands=*/overloadedOperands,
           /*list<Trait> traits=*/traits,
-          /*int numResults=*/numResults>;
+          /*int numResults=*/numResults,
+          /*bit requiresAccessGroup=*/0,
+          /*bit requiresAliasAnalysis=*/0,
+          /*bit requiresFastmath=*/0,
+          /*list<int> immArgPositions=*/immArgPositions,
+          /*list<string> immArgAttrNames=*/immArgAttrNames>;
 
 // Zero
-def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
-                            Arguments<(ins Arg<I32, "Tile mask">:$tile_mask)>;
+def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero",
+                            /*immArgPositions=*/[0],
+                            /*immArgAttrNames=*/["tile_mask"]>,
+                            Arguments<(ins Arg<I32Attr, "Tile mask">:$tile_mask)>;
 
 // MOP's
 class ArmSME_IntrMopOverloadedOp<string mnemonic>
-    : ArmSME_IntrOp<mnemonic, [4]>,
-      Arguments<(ins Arg<I32, "Virtual tile ID">:$tile_id,
+    : ArmSME_IntrOp<mnemonic,
+      /*immArgPositions=*/[0],
+      /*immArgAttrNames=*/["tile_id"],
+      /*overloadedOperands=*/[4]>,
+      Arguments<(ins Arg<I32Attr, "Virtual tile ID">:$tile_id,
                  Arg<MOPPredicate, "LHS predicate">:$lhs_predicate,
                  Arg<MOPPredicate, "RHS predicate">:$rhs_predicate,
                  Arg<MOPVector, "LHS vector operand">:$lhs_vector,
@@ -92,12 +105,17 @@ def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
 def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
 def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
 
+class ArmSME_IntrLoadStoreOp<string mnemonic>
+    : ArmSME_IntrOp<mnemonic,
+      /*immArgPositions=*/[2],
+      /*immArgAttrNames=*/["tile_id"]>;
+
 // Loads
 class ArmSME_IntrLoadOp<string mnemonic>
-    : ArmSME_IntrOp<mnemonic>,
+    : ArmSME_IntrLoadStoreOp<mnemonic>,
       Arguments<(ins Arg<SVEPredicate, "Vector predicate">:$predicate,
                  Arg<LLVM_AnyPointer, "Load address">:$load_address,
-                 Arg<I32, "Virtual tile ID">:$tile_id,
+                 Arg<I32Attr, "Virtual tile ID">:$tile_id,
                  Arg<I32, "Tile slice">:$tile_slice_index)>;
 
 def LLVM_aarch64_sme_ld1b_horiz : ArmSME_IntrLoadOp<"ld1b.horiz">;
@@ -113,10 +131,10 @@ def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">;
 
 // Stores
 class ArmSME_IntrStoreOp<string mnemonic>
-    : ArmSME_IntrOp<mnemonic>,
+    : ArmSME_IntrLoadStoreOp<mnemonic>,
       Arguments<(ins Arg<SVEPredicate, "Vector predicate">:$predicate,
                  Arg<LLVM_AnyPointer, "Store address", [MemWrite]>:$store_address,
-                 Arg<I32, "Virtual tile ID">:$tile_id,
+                 Arg<I32Attr, "Virtual tile ID">:$tile_id,
                  Arg<I32, "Tile slice">:$tile_slice_index)>;
 
 def LLVM_aarch64_sme_st1b_horiz : ArmSME_IntrStoreOp<"st1b.horiz">;
@@ -138,22 +156,28 @@ def LLVM_aarch64_sme_str
 
 // Vector to tile slice
 class LLVM_aarch64_sme_write<string direction>
-    : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
+    : ArmSME_IntrOp<"write." # direction,
+                    /*immArgPositions=*/[0],
+                    /*immArgAttrNames=*/["tile_id"],
+                    /*overloadedOperands=*/[3],
                     [AllShapesMatch<["predicate", "vector"]>]>,
-      Arguments<(ins Arg<I32, "Virtual tile ID">:$tile_id,
+      Arguments<(ins Arg<I32Attr, "Virtual tile ID">:$tile_id,
                      Arg<I32, "Tile slice">:$tile_slice_index,
                      Arg<SVEPredicate, "Vector predicate">:$predicate,
                      Arg<SVEVector, "Vector operand">:$vector)>;
 
 // Tile slice to vector
 class LLVM_aarch64_sme_read<string direction>
-    : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
+    : ArmSME_IntrOp<"read." # direction,
+                    /*immArgPositions=*/[2],
+                    /*immArgAttrNames=*/["tile_id"],
+                    /*overloadedOperands=*/[],
                     [AllShapesMatch<["vector", "predicate", "res"]>,
                      AllElementTypesMatch<["vector", "res"]>],
                     /*numResults=*/1, /*overloadedResults=*/[0]>,
       Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
                      Arg<SVEPredicate, "Vector predicate">:$predicate,
-                     Arg<I32, "Virtual tile ID">:$tile_id,
+                     Arg<I32Attr, "Virtual tile ID">:$tile_id,
                      Arg<I32, "Tile slice">:$tile_slice_index)>;
 
 def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index ba33a2826e6ca4b..abcc9b649c4a530 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -21,6 +21,99 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 
+//===----------------------------------------------------------------------===//
+// ArmSME op interfaces
+//===----------------------------------------------------------------------===//
+
+def ArmSMETileType : I32EnumAttr<"ArmSMETileType", "Arm SME tile type",
+    [
+      I32EnumAttrCase<"ZAB", 0, "za.b">,
+      I32EnumAttrCase<"ZAH", 1, "za.h">,
+      I32EnumAttrCase<"ZAS", 2, "za.s">,
+      I32EnumAttrCase<"ZAD", 3, "za.d">,
+      I32EnumAttrCase<"ZAQ", 4, "za.q">,
+    ]>{
+  let cppNamespace = "mlir::arm_sme";
+  let genSpecializedAttr = 0;
+}
+
+def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
+  let description = [{
+    An interface for operations that use or allocate Arm SME tiles. These
+    operations need to be assigned a tile ID (an i32 attribute), which specifies
+    which virtual tile within the ZA storage to use. The number of tiles
+    available depends on the type of the tile. This is summarized below:
+
+    | Tile Vector Types                                                       | Possible Tile IDs   |
+    |-------------------------------------------------------------------------|---------------------|
+    | `vector<[16]x[16]xi8>`                                                  | 0                   |
+    | `vector<[8]x[8]xi16>`, `vector<[8]x[8]xf16>`, or `vector<[8]x[8]xbf16>` | 0 and 1             |
+    | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>`                          | 0 to 3 (inclusive)  |
+    | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>`                          | 0 to 7 (inclusive)  |
+    | `vector<[1]x[1]xi128>`                                                  | 0 to 15 (inclusive) |
+
+    Operations that allocate new tiles (such as arm_sme.get_tile) are used as
+    the roots for tile allocation, with all operations that (transitively)
+    depend on a root being assigned the same tile ID.
+  }];
+  let methods = [
+    InterfaceMethod<
+      "Sets the tile ID for this operation.",
+      /*returnType=*/"void",
+      /*methodName=*/"setTileId",
+      /*arguments=*/(ins "mlir::IntegerAttr":$tileId),
+      /*methodBody=*/[{}],
+      /*defaultImpl=*/ [{
+        if (!tileId)
+          return;
+        ::mlir::Operation* op = this->getOperation();
+        op->setAttr("tile_id", tileId);
+      }]
+    >,
+    InterfaceMethod<
+      "Returns the (possibly null) tile ID assigned to this operation.",
+      /*returnType=*/"mlir::IntegerAttr",
+      /*methodName=*/"getTileId",
+      /*arguments=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImpl=*/ [{
+        ::mlir::Operation* op = this->getOperation();
+        return op->getAttrOfType<mlir::IntegerAttr>("tile_id");
+      }]
+    >,
+    InterfaceMethod<
+      "The type of tile this operation allocates (or none)",
+      /*returnType=*/"std::optional<::mlir::arm_sme::ArmSMETileType>",
+      /*methodName=*/"getAllocatedTileType",
+      /*arguments=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImpl=*/ [{
+        // Do not allocate a new tile.
+        return std::nullopt;
+      }]
+    >
+  ];
+
+  let extraSharedClassDeclaration = [{
+    // A helper to create a new operation and propagate this operation's tile ID.
+    template<typename T, typename... Args>
+    T createOpAndForwardTileId(::mlir::RewriterBase& rewriter, ::mlir::Location loc, Args &&...args) {
+      auto op = rewriter.create<T>(loc, std::forward<Args>(args)...);
+      if (auto tileOp = ::llvm::dyn_cast<ArmSMETileOpInterface>(op.getOperation()))
+        tileOp.setTileId($_op.getTileId());
+      return op;
+    }
+
+    // A helper to replace this operation and forward any tile ID.
+    template<typename T, typename... Args>
+    T replaceWithAndForwardTileId(::mlir::RewriterBase& rewriter, Args &&...args) {
+      auto newOp = createOpAndForwardTileId<T>(rewriter, $_op.getLoc(), std::forward<Args>(args)...);
+      rewriter.replaceOp($_op, newOp);
+      return newOp;
+    }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ArmSME type definitions
 //===----------------------------------------------------------------------===//
@@ -44,7 +137,8 @@ def nxnxv2f64  : SMETileType<F64,  [2,  2 ], "vector<[2]x[2]xf64>">;
 
 def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
                          nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64],
-                        "a vector type that fits into a SME tile">
+                        "a vector type that fits into a SME tile",
+                        "VectorType">
 {
   let description = [{
     Possible vector types:
@@ -66,40 +160,6 @@ def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
   }];
 }
 
-def TileID : AnyTypeOf<[I8, I16, I32, I64, I128],
-                       "an identifier of a virtual tile (of a size) within the ZA storage">
-{
-  let description = [{
-    The tile ID is an 8, 16, 32, 64, or 128-bit signless integer. The value of
-    the integer indicates the tile to use, and the bit size indicates the size
-    of tile. The number of tiles available and the element types of those depend
-    on the size. This is summarised below:
-
-    | Tile ID Type | Possible Tile IDs   | Tile Vector Types                                                       |
-    |--------------|---------------------|-------------------------------------------------------------------------|
-    | `i8`         | 0                   | `vector<[16]x[16]xi8>`                                                  |
-    | `i16`        | 0 and 1             | `vector<[8]x[8]xi16>`, `vector<[8]x[8]xf16>`, or `vector<[8]x[8]xbf16>` |
-    | `i32`        | 0 to 3 (inclusive)  | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>`                          |
-    | `i64`        | 0 to 7 (inclusive)  | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>`                          |
-    | `i128`       | 0 to 15 (inclusive) | `vector<[1]x[1]xi128>`                                                  |
-  }];
-}
-
-// A type constraint that verifies the bitwidth of the scalar integer returned
-// from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
-def TileElementWidthMatchesTileID : TypesMatchWith<
-  "`tile_id` has the same number of bits as elements in `vector`",
-  "vector", "tile_id",
-  "IntegerType::get("
-      "$_self.getContext(),"
-      "::llvm::isa<IntegerType>(::llvm::cast<VectorType>($_self).getElementType())"
-          "? ::llvm::cast<IntegerType>("
-                  "::llvm::cast<VectorType>($_self).getElementType())"
-                  ".getWidth()"
-          ": ::llvm::cast<FloatType>("
-                  "::llvm::cast<VectorType>($_self).getElementType())"
-                  ".getWidth())">;
-
 class HasMatchingMaskTypeConstraint<string vector, string mask> :
   OptionalTypesMatchWith<
     mask # " has i1 element type and same shape as " # vector,
@@ -162,125 +222,67 @@ def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
 class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
   Op<ArmSME_Dialect, mnemonic, traits> {}
 
-def CastTileToVector : ArmSME_Op<"cast_tile_to_vector", [Pure, TileElementWidthMatchesTileID]> {
-  let summary = "Cast from tile id to 2-d scalable vector type";
+def GetTile : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
+  let summary = "Returns a SME virtual tile";
   let description = [{
-    A `cast_tile_to_vector` operation does a cast from a tile id to a 2-d
-    scalable vector type, which represents an SME "virtual tile". This would
-    normally be used when lowering operations that return "virtual tile" vector
-    types to model the output. This is required to preserve dataflow as SME
-    intrinsics have no return values.
+    Allocates a new SME "virtual tile" within a function. The contents of the
+    tile returned from this operation are undefined.
 
-    Example:
+    Example 1:
 
-    Input:
     ```mlir
-    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
-    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    // Allocate an 8-bit element "virtual tile"
+    %za0_b = arm_sme.get_tile: vector<[16]x[16]xi8>
     ```
 
-    After lowering `vector.load`:
+    Example 2:
+
     ```mlir
-    %tile_id = arm_sme.get_tile_id : i32
-    scf.for %vnum = %c0 to %num_vectors step %c1 {
-      // ...
-      "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-    }
-    %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
-    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    // Allocate two 16-bit element "virtual tiles"
+    %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+    %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
     ```
 
-    In the example above, the `vector.load` can't be replaced with an SME
-    intrinsic that has no outputs since it is used by the `vector.store`.
-    However, by inserting a `cast_tile_to_vector` op after the load intrinsics
-    the `vector.load` can be replaced. This enables "local" rewrites on
-    individual vector ops, rather than "global" rewrites that would have to
-    look at the vector op uses and also lower them.
-
-    Canonicalization will look through `arm_sme.cast_tile_to_vector` and fold
-    the cast away if it comes from a `arm_sme.cast_vector_to_tile`.
-  }];
-  let arguments = (ins TileID:$tile_id);
-  let results = (outs SMETile:$vector);
-  let assemblyFormat =
-    "$tile_id attr-dict `:` type($tile_id) `to` type($vector)";
-  let hasCanonicalizeMethod = 1;
-}
-
-def CastVectorToTile : ArmSME_Op<"cast_vector_to_tile", [Pure, TileElementWidthMatchesTileID]> {
-  let summary = "Cast from 2-d scalable vector type to tile id";
-  let description = [{
-    A `cast_vector_to_tile` operation does a cast from a 2-d scalable vector
-    type, which represents an SME "virtual tile", to a tile id. This is
-    required to preserve dataflow as the SME intrinsics have no return values.
-
-    Example:
-
-    Input:
+    Example 3:
     ```mlir
-    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
-    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    // Allocate a 128-bit element "virtual tile"
+    %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
     ```
+  }];
 
-    After lowering `vector.store`:
-    ```mlir
-    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
-    scf.for %vnum = %c0 to %num_vectors step %c1 {
-      // ...
-      %tile_id = arm_sme.cast_vector_to_tile %tile : (vector<[4]x[4]xi32>) -> i32
-      "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+  let results = (outs SMETile:$tile);
+  let assemblyFormat = "attr-dict `:` type($tile)";
+
+  let extraClassDeclaration = [{
+    VectorType getTileType() {
+      return ::llvm::cast<VectorType>(getTile().getType());
     }
-    ```
 
-    Canonicalization will look through `arm_sme.cast_vector_to_tile` and fold
-    the cast away if it comes from a `arm_sme.cast_tile_to_vector`.
+    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+      return arm_sme::getSMETileType(getTileType());
+    }
   }];
-  let arguments = (ins SMETile:$vector);
-  let results = (outs TileID:$tile_id);
-  let assemblyFormat =
-    "$vector attr-dict `:` type($vector) `to` type($tile_id)";
-  let hasCanonicalizeMethod = 1;
 }
 
-def GetTileID : ArmSME_Op<"get_tile_id"> {
-  let summary = "Returns an SME \"virtual tile\" id";
+def MaterializeSSATile : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
+  let summary = "SME tile placeholder";
   let description = [{
-    A `get_tile_id` operation returns a scalar integer representing an SME
-    "virtual tile" id. The bitwidth of the scalar indicates the element
-    bitwidth of the "virtual tile".
-
-    The scope of a tile id is a function and cannot be passed or returned from
-    functions.
+    A placeholder to preserve dataflow while lowering to SME intrinsics (which
+    do not take or return tile values). This operation is intended to be DCE'd
+    once all ArmSME operations have been lowered.
 
-    Example:
-    ```mlir
-    // Allocate and return an 8-bit element "virtual tile" id
-    %za0_b = arm_sme.get_tile_id : i8
-    ```
-
-    Example:
-    ```
-    // Allocate and return two 16-bit element "virtual tile" ids
-    %za0_h = arm_sme.get_tile_id : i16
-    %za1_h = arm_sme.get_tile_id : i16
-    ```
-
-    Example:
-    ```
-    // Allocate and return an 128-bit element "virtual tile" id
-    %za0_q = arm_sme.get_tile_id : i128
-  ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/73253

