[llvm] [WIP][X86][AMX] Support AMX constant (PR #92280)

Phoebe Wang via llvm-commits llvm-commits at lists.llvm.org
Wed May 15 08:13:34 PDT 2024


https://github.com/phoebewang created https://github.com/llvm/llvm-project/pull/92280

LLVM bugpoint needs this to reduce test cases.

>From 0f759fa4dfdace7831de2691d88c8d570705d893 Mon Sep 17 00:00:00 2001
From: Phoebe Wang <phoebe.wang at intel.com>
Date: Wed, 15 May 2024 23:09:56 +0800
Subject: [PATCH] [WIP][X86][AMX] Support AMX constant

LLVM bugpoint needs this to reduce test cases.
---
 llvm/include/llvm-c/Core.h                    |  1 +
 llvm/include/llvm/IR/Constants.h              | 21 ++++++++++++++++++
 llvm/include/llvm/IR/Value.def                |  1 +
 .../SelectionDAG/SelectionDAGBuilder.cpp      |  6 +++++
 llvm/lib/IR/AsmWriter.cpp                     |  3 ++-
 llvm/lib/IR/Constants.cpp                     | 22 ++++++++++++++++++-
 llvm/lib/IR/LLVMContextImpl.h                 |  2 ++
 llvm/lib/IR/Verifier.cpp                      |  4 ++--
 llvm/lib/Target/X86/X86ISelLowering.cpp       | 12 ++++++++++
 9 files changed, 68 insertions(+), 4 deletions(-)

diff --git a/llvm/include/llvm-c/Core.h b/llvm/include/llvm-c/Core.h
index 9d09546513f0e..6a05b683f3ed8 100644
--- a/llvm/include/llvm-c/Core.h
+++ b/llvm/include/llvm-c/Core.h
@@ -286,6 +286,7 @@ typedef enum {
   LLVMInstructionValueKind,
   LLVMPoisonValueValueKind,
   LLVMConstantTargetNoneValueKind,
+  LLVMConstantAMXNoneValueKind,
 } LLVMValueKind;
 
 typedef enum {
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 9ec81903f09c9..9d5bb9ea80411 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -884,6 +884,27 @@ class ConstantTargetNone final : public ConstantData {
   }
 };
 
+/// The zero value ("zeroinitializer") of the x86_amx type, uniqued per type.
+class ConstantAMXNone final : public ConstantData {
+  friend class Constant;
+
+  explicit ConstantAMXNone(Type *T)
+      : ConstantData(T, Value::ConstantAMXNoneVal) {}
+
+  void destroyConstantImpl();
+
+public:
+  ConstantAMXNone(const ConstantAMXNone &) = delete;
+
+  /// Return the unique zero constant of type \p T (expected to be x86_amx).
+  static ConstantAMXNone *get(Type *T);
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool classof(const Value *V) {
+    return V->getValueID() == ConstantAMXNoneVal;
+  }
+};
+
 /// The address of a basic block.
 ///
 class BlockAddress final : public Constant {
diff --git a/llvm/include/llvm/IR/Value.def b/llvm/include/llvm/IR/Value.def
index 61f7a87666d09..189dfb8404061 100644
--- a/llvm/include/llvm/IR/Value.def
+++ b/llvm/include/llvm/IR/Value.def
@@ -97,6 +97,7 @@ HANDLE_CONSTANT(ConstantInt)
 HANDLE_CONSTANT(ConstantFP)
 HANDLE_CONSTANT(ConstantTargetNone)
 HANDLE_CONSTANT(ConstantPointerNull)
+HANDLE_CONSTANT(ConstantAMXNone)
 HANDLE_CONSTANT(ConstantTokenNone)
 
 HANDLE_CONSTANT_MARKER(ConstantFirstVal, Function)
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index ca352da5d36eb..704f46ad5e7ac 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -1885,6 +1885,12 @@ SDValue SelectionDAGBuilder::getValueImpl(const Value *V) {
                          DAG.getConstant(0, getCurSDLoc(), MVT::nxv16i1));
     }
 
+    if (VT == MVT::x86amx) {
+      assert(C->isNullValue() && "Only zeroinitializer x86_amx is supported!");
+      return DAG.getNode(ISD::BITCAST, getCurSDLoc(), VT,
+                         DAG.getConstant(0, getCurSDLoc(), MVT::v256i32));
+    }
+
     VectorType *VecTy = cast<VectorType>(V->getType());
 
     // Now that we know the number and type of the elements, get that number of
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 941f6a7a7d823..711813c6cb490 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1564,7 +1564,8 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
     return;
   }
 
-  if (isa<ConstantAggregateZero>(CV) || isa<ConstantTargetNone>(CV)) {
+  if (isa<ConstantAggregateZero>(CV) || isa<ConstantTargetNone>(CV) ||
+      isa<ConstantAMXNone>(CV)) {
     Out << "zeroinitializer";
     return;
   }
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index db442c54125a7..f12639b3b0005 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -101,7 +101,8 @@ bool Constant::isNullValue() const {
   // constant zero is zero for aggregates, cpnull is null for pointers, none for
   // tokens.
   return isa<ConstantAggregateZero>(this) || isa<ConstantPointerNull>(this) ||
-         isa<ConstantTokenNone>(this) || isa<ConstantTargetNone>(this);
+         isa<ConstantTokenNone>(this) || isa<ConstantTargetNone>(this) ||
+         isa<ConstantAMXNone>(this);
 }
 
 bool Constant::isAllOnesValue() const {
@@ -391,6 +392,8 @@ Constant *Constant::getNullValue(Type *Ty) {
     return ConstantTokenNone::get(Ty->getContext());
   case Type::TargetExtTyID:
     return ConstantTargetNone::get(cast<TargetExtType>(Ty));
+  case Type::X86_AMXTyID:
+    return ConstantAMXNone::get(Ty);
   default:
     // Function, Label, or Opaque type?
     llvm_unreachable("Cannot create a null constant of that type!");
@@ -1805,6 +1808,23 @@ void ConstantTargetNone::destroyConstantImpl() {
   getContext().pImpl->CTNConstants.erase(getType());
 }
 
+//---- ConstantAMXNone::get() implementation.
+//
+
+ConstantAMXNone *ConstantAMXNone::get(Type *Ty) {
+  assert(Ty->isX86_AMXTy() && "ConstantAMXNone must have x86_amx type!");
+  std::unique_ptr<ConstantAMXNone> &Entry =
+      Ty->getContext().pImpl->CAMXConstants[Ty];
+  if (!Entry)
+    Entry.reset(new ConstantAMXNone(Ty));
+  return Entry.get();
+}
+
+/// Remove the constant from the constant table.
+void ConstantAMXNone::destroyConstantImpl() {
+  getContext().pImpl->CAMXConstants.erase(getType());
+}
+
 UndefValue *UndefValue::get(Type *Ty) {
   std::unique_ptr<UndefValue> &Entry = Ty->getContext().pImpl->UVConstants[Ty];
   if (!Entry)
diff --git a/llvm/lib/IR/LLVMContextImpl.h b/llvm/lib/IR/LLVMContextImpl.h
index 399fe0dad26c7..907d1a800ff04 100644
--- a/llvm/lib/IR/LLVMContextImpl.h
+++ b/llvm/lib/IR/LLVMContextImpl.h
@@ -1549,6 +1549,8 @@ class LLVMContextImpl {
 
   DenseMap<TargetExtType *, std::unique_ptr<ConstantTargetNone>> CTNConstants;
 
+  DenseMap<Type *, std::unique_ptr<ConstantAMXNone>> CAMXConstants;
+
   DenseMap<Type *, std::unique_ptr<UndefValue>> UVConstants;
 
   DenseMap<Type *, std::unique_ptr<PoisonValue>> PVConstants;
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 50f8d6ec84201..70d7198e62c4c 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -5265,9 +5265,9 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
   for (Value *V : Call.args()) {
     if (auto *MD = dyn_cast<MetadataAsValue>(V))
       visitMetadataAsValue(*MD, Call.getCaller());
     if (auto *Const = dyn_cast<Constant>(V))
-      Check(!Const->getType()->isX86_AMXTy(),
-            "const x86_amx is not allowed in argument!");
+      Check(!Const->getType()->isX86_AMXTy() || isa<ConstantAMXNone>(Const),
+            "non-zeroinitializer const x86_amx is not allowed in argument!");
   }
 
   switch (ID) {
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a57c10e784d9c..1fa1d2fb79dc2 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -43482,6 +43482,18 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
   // vxi1 types.
   if (DCI.isBeforeLegalize()) {
     SDLoc dl(N);
+    // A bitcast of an all-zeros vector to x86_amx becomes tilezero.
+    if (VT == MVT::x86amx && ISD::isBuildVectorAllZeros(N0.getNode())) {
+      SDValue Intrin =
+          DAG.getTargetConstant(Intrinsic::x86_tilezero_internal, dl,
+                                TLI.getPointerTy(DAG.getDataLayout()));
+      // FIXME: We need to rebuild the Row and Col from its user.
+      SDValue Row = DAG.getConstant(8, dl, MVT::i16);
+      SDValue Col = DAG.getConstant(8, dl, MVT::i16);
+      return DAG.getNode(ISD::INTRINSIC_W_CHAIN, dl, {MVT::x86amx, MVT::Other},
+                         {DAG.getEntryNode(), Intrin, Row, Col});
+    }
+
     if (SDValue V = combineBitcastvxi1(DAG, VT, N0, dl, Subtarget))
       return V;
 



More information about the llvm-commits mailing list