[clang] [llvm] [CIR] X86 vector masked load builtins (PR #169464)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Nov 24 23:10:54 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-llvm-ir
Author: woruyu (woruyu)
### Summary
This PR addresses the masked-load portion of https://github.com/llvm/llvm-project/issues/167752; the masked-store builtins are left for a follow-up.
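For context, this is the kind of builtin the new lowering handles; a minimal usage sketch mirroring the added test (names are illustrative; requires AVX512F and AVX512VL, e.g. `-mavx512vl`):

```c
#include <immintrin.h>

// Lanes whose mask bit is 1 are loaded from __P;
// lanes whose mask bit is 0 keep the corresponding lane of __W.
__m128 load_low_two(__m128 __W, const float *__P) {
  return _mm_mask_loadu_ps(__W, /*__U=*/0x3, __P);
}
```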
---
Full diff: https://github.com/llvm/llvm-project/pull/169464.diff
6 Files Affected:
- (modified) clang/lib/CIR/CodeGen/CIRGenBuilder.h (+28)
- (modified) clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp (+42)
- (added) clang/test/CIR/CodeGen/X86/avx512vl-builtins.c (+18)
- (modified) llvm/include/llvm/IR/Intrinsics.td (+3-2)
- (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+1)
- (modified) llvm/lib/IR/Verifier.cpp (+5-2)
``````````diff
diff --git a/clang/lib/CIR/CodeGen/CIRGenBuilder.h b/clang/lib/CIR/CodeGen/CIRGenBuilder.h
index 85b38120169fd..e65dcf7531bfe 100644
--- a/clang/lib/CIR/CodeGen/CIRGenBuilder.h
+++ b/clang/lib/CIR/CodeGen/CIRGenBuilder.h
@@ -603,6 +603,34 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
addr.getAlignment().getAsAlign().value());
}
+ /// Create a call to a Masked Load intrinsic.
+ /// \p loc - expression location
+ /// \p ty - vector type to load
+ /// \p ptr - base pointer for the load
+ /// \p alignment - alignment of the source location
+ /// \p mask - vector of booleans which indicates what vector lanes should
+ /// be accessed in memory
+ /// \p passThru - pass-through value that is used to fill the masked-off
+ /// lanes of the result
+ mlir::Value createMaskedLoad(mlir::Location loc, mlir::Type ty,
+ mlir::Value ptr, llvm::Align alignment,
+ mlir::Value mask, mlir::Value passThru) {
+
+ assert(mlir::isa<cir::VectorType>(ty) && "Type should be vector");
+ assert(mask && "Mask should not be all-ones (null)");
+
+ if (!passThru)
+ passThru = this->getConstant(loc, cir::PoisonAttr::get(ty));
+
+ mlir::Value ops[] = {ptr, this->getUInt32(int32_t(alignment.value()), loc),
+ mask, passThru};
+
+ return cir::LLVMIntrinsicCallOp::create(
+ *this, loc, getStringAttr("masked.load"), ty, ops)
+ .getResult();
+ }
+
cir::VecShuffleOp
createVecShuffle(mlir::Location loc, mlir::Value vec1, mlir::Value vec2,
llvm::ArrayRef<mlir::Attribute> maskAttrs) {
diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
index 978fee7dbec9d..6a73227a7baf7 100644
--- a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
@@ -33,6 +33,40 @@ static mlir::Value emitIntrinsicCallOp(CIRGenFunction &cgf, const CallExpr *e,
.getResult();
}
+// Convert the mask from an integer type to a vector of i1.
+static mlir::Value getMaskVecValue(CIRGenFunction &cgf, mlir::Value mask,
+ unsigned numElts, mlir::Location loc) {
+ cir::VectorType maskTy =
+ cir::VectorType::get(cgf.getBuilder().getSIntNTy(1),
+ cast<cir::IntType>(mask.getType()).getWidth());
+
+ mlir::Value maskVec = cgf.getBuilder().createBitcast(mask, maskTy);
+
+ // If we have less than 8 elements, then the starting mask was an i8 and
+ // we need to extract down to the right number of elements.
+ if (numElts < 8) {
+ llvm::SmallVector<int64_t, 4> indices;
+ for (unsigned i = 0; i != numElts; ++i)
+ indices.push_back(i);
+ maskVec = cgf.getBuilder().createVecShuffle(loc, maskVec, maskVec, indices);
+ }
+
+ return maskVec;
+}
+
+static mlir::Value emitX86MaskedLoad(CIRGenFunction &cgf,
+ ArrayRef<mlir::Value> ops,
+ llvm::Align alignment,
+ mlir::Location loc) {
+ mlir::Type ty = ops[1].getType();
+ mlir::Value ptr = ops[0];
+ mlir::Value maskVec =
+ getMaskVecValue(cgf, ops[2], cast<cir::VectorType>(ty).getSize(), loc);
+
+ return cgf.getBuilder().createMaskedLoad(loc, ty, ptr, alignment, maskVec,
+ ops[1]);
+}
+
// OG has unordered comparison as a form of optimization in addition to
// ordered comparison, while CIR doesn't.
//
@@ -327,6 +361,11 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
case X86::BI__builtin_ia32_movdqa64store512_mask:
case X86::BI__builtin_ia32_storeaps512_mask:
case X86::BI__builtin_ia32_storeapd512_mask:
+ cgm.errorNYI(expr->getSourceRange(),
+ std::string("unimplemented X86 builtin call: ") +
+ getContext().BuiltinInfo.getName(builtinID));
+ return {};
+
case X86::BI__builtin_ia32_loadups128_mask:
case X86::BI__builtin_ia32_loadups256_mask:
case X86::BI__builtin_ia32_loadups512_mask:
@@ -345,6 +384,9 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
case X86::BI__builtin_ia32_loaddqudi128_mask:
case X86::BI__builtin_ia32_loaddqudi256_mask:
case X86::BI__builtin_ia32_loaddqudi512_mask:
+ return emitX86MaskedLoad(*this, ops, llvm::Align(1),
+ getLoc(expr->getExprLoc()));
+
case X86::BI__builtin_ia32_loadsbf16128_mask:
case X86::BI__builtin_ia32_loadsh128_mask:
case X86::BI__builtin_ia32_loadss128_mask:
diff --git a/clang/test/CIR/CodeGen/X86/avx512vl-builtins.c b/clang/test/CIR/CodeGen/X86/avx512vl-builtins.c
new file mode 100644
index 0000000000000..2029e3d4b3734
--- /dev/null
+++ b/clang/test/CIR/CodeGen/X86/avx512vl-builtins.c
@@ -0,0 +1,18 @@
+// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-linux -target-feature +avx512f -target-feature +avx512vl -fclangir -emit-cir -o %t.cir -Wall -Werror -Wsign-conversion
+// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
+// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-linux -target-feature +avx512f -target-feature +avx512vl -fclangir -emit-llvm -o %t.ll -Wall -Werror -Wsign-conversion
+// RUN: FileCheck --check-prefixes=LLVM --input-file=%t.ll %s
+
+
+#include <immintrin.h>
+
+__m128 test_mm_mask_loadu_ps(__m128 __W, __mmask8 __U, void const *__P) {
+ // CIR-LABEL: _mm_mask_loadu_ps
+ // CIR: {{%.*}} = cir.call_llvm_intrinsic "masked.load" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (!cir.ptr<!cir.vector<4 x !cir.float>>, !u32i, !cir.vector<4 x !cir.int<s, 1>>, !cir.vector<4 x !cir.float>) -> !cir.vector<4 x !cir.float>
+
+ // LLVM-LABEL: @test_mm_mask_loadu_ps
+ // LLVM: @llvm.masked.load.v4f32.p0(ptr %{{.*}}, i32 1, <4 x i1> %{{.*}}, <4 x float> %{{.*}})
+ return _mm_mask_loadu_ps(__W, __U, __P);
+}
+
+
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 8f3cc54747074..355a2b85defd4 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -2524,9 +2524,10 @@ def int_vp_is_fpclass:
//
def int_masked_load:
DefaultAttrsIntrinsic<[llvm_anyvector_ty],
- [llvm_anyptr_ty,
+ [llvm_anyptr_ty, llvm_i32_ty,
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, LLVMMatchType<0>],
- [IntrReadMem, IntrArgMemOnly, NoCapture<ArgIndex<0>>]>;
+ [IntrReadMem, IntrArgMemOnly, ImmArg<ArgIndex<1>>,
+ NoCapture<ArgIndex<0>>]>;
def int_masked_store:
DefaultAttrsIntrinsic<[],
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 59a213b47825a..e2f484fada9ed 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -7151,6 +7151,7 @@ static Value *simplifyIntrinsic(CallBase *Call, Value *Callee,
switch (IID) {
case Intrinsic::masked_load:
case Intrinsic::masked_gather: {
+
Value *MaskArg = Args[1];
Value *PassthruArg = Args[2];
// If the mask is all zeros or undef, the "passthru" argument is the result.
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 7cc1980d24c33..cd757bcb82bda 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -6275,10 +6275,13 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
Check(Call.getType()->isVectorTy(), "masked_load: must return a vector",
Call);
- Value *Mask = Call.getArgOperand(1);
- Value *PassThru = Call.getArgOperand(2);
+ ConstantInt *Alignment = cast<ConstantInt>(Call.getArgOperand(1));
+ Value *Mask = Call.getArgOperand(2);
+ Value *PassThru = Call.getArgOperand(3);
Check(Mask->getType()->isVectorTy(), "masked_load: mask must be vector",
Call);
+ Check(Alignment->getValue().isPowerOf2(),
+ "masked_load: alignment must be a power of 2", Call);
Check(PassThru->getType() == Call.getType(),
"masked_load: pass through and return type must match", Call);
Check(cast<VectorType>(Mask->getType())->getElementCount() ==
``````````
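At the IR level, `getMaskVecValue` bitcasts the integer mask to a vector of i1 (bit i selects lane i) and, for vectors narrower than 8 lanes, shuffles it down to the right element count; masked-off lanes of the result come from the pass-through operand. A scalar sketch of that semantics for the 4 x float case (a hypothetical helper, not part of the patch):

```c
// Element i is loaded from src only when bit i of mask is set;
// otherwise it falls back to the pass-through value.
static void masked_load_v4f32(float dst[4], const float *src,
                              unsigned char mask, const float pass[4]) {
  for (int i = 0; i < 4; ++i)
    dst[i] = ((mask >> i) & 1) ? src[i] : pass[i];
}
```

Note that, per the Intrinsics.td and Verifier changes above, the alignment is passed as an explicit i32 immediate operand of `llvm.masked.load`, and the Verifier now checks that it is a power of two.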
https://github.com/llvm/llvm-project/pull/169464