[llvm] [WIP][X86][AMX] Support AMX constant (PR #92280)
Phoebe Wang via llvm-commits
llvm-commits at lists.llvm.org
Wed May 15 08:13:34 PDT 2024
https://github.com/phoebewang created https://github.com/llvm/llvm-project/pull/92280
LLVM bugpoint needs this to reduce test cases.
>From 0f759fa4dfdace7831de2691d88c8d570705d893 Mon Sep 17 00:00:00 2001
From: Phoebe Wang <phoebe.wang at intel.com>
Date: Wed, 15 May 2024 23:09:56 +0800
Subject: [PATCH] [WIP][X86][AMX] Support AMX constant
LLVM bugpoint needs this to reduce test cases.
---
llvm/include/llvm-c/Core.h | 1 +
llvm/include/llvm/IR/Constants.h | 21 ++++++++++++++++++
llvm/include/llvm/IR/Value.def | 1 +
.../SelectionDAG/SelectionDAGBuilder.cpp | 6 +++++
llvm/lib/IR/AsmWriter.cpp | 3 ++-
llvm/lib/IR/Constants.cpp | 22 ++++++++++++++++++-
llvm/lib/IR/LLVMContextImpl.h | 2 ++
llvm/lib/IR/Verifier.cpp | 4 ++--
llvm/lib/Target/X86/X86ISelLowering.cpp | 12 ++++++++++
9 files changed, 68 insertions(+), 4 deletions(-)
diff --git a/llvm/include/llvm-c/Core.h b/llvm/include/llvm-c/Core.h
index 9d09546513f0e..6a05b683f3ed8 100644
--- a/llvm/include/llvm-c/Core.h
+++ b/llvm/include/llvm-c/Core.h
@@ -286,6 +286,7 @@ typedef enum {
LLVMInstructionValueKind,
LLVMPoisonValueValueKind,
LLVMConstantTargetNoneValueKind,
+ LLVMConstantAMXNoneValueKind,
} LLVMValueKind;
typedef enum {
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 9ec81903f09c9..9d5bb9ea80411 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -884,6 +884,27 @@ class ConstantTargetNone final : public ConstantData {
}
};
+/// A constant AMX type default initializer
+class ConstantAMXNone final : public ConstantData {
+ friend class Constant;
+
+ explicit ConstantAMXNone(Type *T)
+ : ConstantData(T, Value::ConstantAMXNoneVal) {}
+
+ void destroyConstantImpl();
+
+public:
+ ConstantAMXNone(const ConstantAMXNone &) = delete;
+
+ /// Static factory methods - Return objects of the specified value.
+ static ConstantAMXNone *get(Type *T);
+
+ /// Methods for support type inquiry through isa, cast, and dyn_cast.
+ static bool classof(const Value *V) {
+ return V->getValueID() == ConstantAMXNoneVal;
+ }
+};
+
/// The address of a basic block.
///
class BlockAddress final : public Constant {
diff --git a/llvm/include/llvm/IR/Value.def b/llvm/include/llvm/IR/Value.def
index 61f7a87666d09..189dfb8404061 100644
--- a/llvm/include/llvm/IR/Value.def
+++ b/llvm/include/llvm/IR/Value.def
@@ -97,6 +97,7 @@ HANDLE_CONSTANT(ConstantInt)
HANDLE_CONSTANT(ConstantFP)
HANDLE_CONSTANT(ConstantTargetNone)
HANDLE_CONSTANT(ConstantPointerNull)
+HANDLE_CONSTANT(ConstantAMXNone)
HANDLE_CONSTANT(ConstantTokenNone)
HANDLE_CONSTANT_MARKER(ConstantFirstVal, Function)
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index ca352da5d36eb..704f46ad5e7ac 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -1885,6 +1885,12 @@ SDValue SelectionDAGBuilder::getValueImpl(const Value *V) {
DAG.getConstant(0, getCurSDLoc(), MVT::nxv16i1));
}
+ if (VT == MVT::x86amx) {
+ assert(C->isNullValue() && "Can only zero this target type!");
+ return DAG.getNode(ISD::BITCAST, getCurSDLoc(), VT,
+ DAG.getConstant(0, getCurSDLoc(), MVT::v256i32));
+ }
+
VectorType *VecTy = cast<VectorType>(V->getType());
// Now that we know the number and type of the elements, get that number of
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 941f6a7a7d823..711813c6cb490 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1564,7 +1564,8 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
return;
}
- if (isa<ConstantAggregateZero>(CV) || isa<ConstantTargetNone>(CV)) {
+ if (isa<ConstantAggregateZero>(CV) || isa<ConstantTargetNone>(CV) ||
+ isa<ConstantAMXNone>(CV)) {
Out << "zeroinitializer";
return;
}
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index db442c54125a7..f12639b3b0005 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -101,7 +101,8 @@ bool Constant::isNullValue() const {
// constant zero is zero for aggregates, cpnull is null for pointers, none for
// tokens.
return isa<ConstantAggregateZero>(this) || isa<ConstantPointerNull>(this) ||
- isa<ConstantTokenNone>(this) || isa<ConstantTargetNone>(this);
+ isa<ConstantTokenNone>(this) || isa<ConstantTargetNone>(this) ||
+ isa<ConstantAMXNone>(this);
}
bool Constant::isAllOnesValue() const {
@@ -391,6 +392,8 @@ Constant *Constant::getNullValue(Type *Ty) {
return ConstantTokenNone::get(Ty->getContext());
case Type::TargetExtTyID:
return ConstantTargetNone::get(cast<TargetExtType>(Ty));
+ case Type::X86_AMXTyID:
+ return ConstantAMXNone::get(Ty);
default:
// Function, Label, or Opaque type?
llvm_unreachable("Cannot create a null constant of that type!");
@@ -1805,6 +1808,23 @@ void ConstantTargetNone::destroyConstantImpl() {
getContext().pImpl->CTNConstants.erase(getType());
}
+//---- ConstantAMXNone::get() implementation.
+//
+
+ConstantAMXNone *ConstantAMXNone::get(Type *Ty) {
+ std::unique_ptr<ConstantAMXNone> &Entry =
+ Ty->getContext().pImpl->CAMXConstants[Ty];
+ if (!Entry)
+ Entry.reset(new ConstantAMXNone(Ty));
+
+ return Entry.get();
+}
+
+/// Remove the constant from the constant table.
+void ConstantAMXNone::destroyConstantImpl() {
+ getContext().pImpl->CAMXConstants.erase(getType());
+}
+
UndefValue *UndefValue::get(Type *Ty) {
std::unique_ptr<UndefValue> &Entry = Ty->getContext().pImpl->UVConstants[Ty];
if (!Entry)
diff --git a/llvm/lib/IR/LLVMContextImpl.h b/llvm/lib/IR/LLVMContextImpl.h
index 399fe0dad26c7..907d1a800ff04 100644
--- a/llvm/lib/IR/LLVMContextImpl.h
+++ b/llvm/lib/IR/LLVMContextImpl.h
@@ -1549,6 +1549,8 @@ class LLVMContextImpl {
DenseMap<TargetExtType *, std::unique_ptr<ConstantTargetNone>> CTNConstants;
+ DenseMap<Type *, std::unique_ptr<ConstantAMXNone>> CAMXConstants;
+
DenseMap<Type *, std::unique_ptr<UndefValue>> UVConstants;
DenseMap<Type *, std::unique_ptr<PoisonValue>> PVConstants;
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 50f8d6ec84201..70d7198e62c4c 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -5265,9 +5265,9 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
for (Value *V : Call.args()) {
if (auto *MD = dyn_cast<MetadataAsValue>(V))
visitMetadataAsValue(*MD, Call.getCaller());
- if (auto *Const = dyn_cast<Constant>(V))
+ /*if (auto *Const = dyn_cast<Constant>(V))
Check(!Const->getType()->isX86_AMXTy(),
- "const x86_amx is not allowed in argument!");
+ "const x86_amx is not allowed in argument!");*/
}
switch (ID) {
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a57c10e784d9c..1fa1d2fb79dc2 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -43482,6 +43482,18 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
// vxi1 types.
if (DCI.isBeforeLegalize()) {
SDLoc dl(N);
+
+ if (VT == MVT::x86amx) {
+ SDValue Intrin =
+ DAG.getTargetConstant(Intrinsic::x86_tilezero_internal, dl,
+ TLI.getPointerTy(DAG.getDataLayout()));
+ // FIXME: We need to rebuild the Row and Col from its user.
+ SDValue Row = DAG.getConstant(8, dl, MVT::i16);
+ SDValue Col = DAG.getConstant(8, dl, MVT::i16);
+ return DAG.getNode(ISD::INTRINSIC_W_CHAIN, dl, {MVT::x86amx, MVT::Other},
+ {DAG.getEntryNode(), Intrin, Row, Col});
+ }
+
if (SDValue V = combineBitcastvxi1(DAG, VT, N0, dl, Subtarget))
return V;
More information about the llvm-commits
mailing list