[clang] [CIR] Upstream splat op for VectorType (PR #139827)
Amr Hesham via cfe-commits
cfe-commits at lists.llvm.org
Tue Jun 10 14:34:41 PDT 2025
https://github.com/AmrDeveloper updated https://github.com/llvm/llvm-project/pull/139827
>From 0d1a56f8020a1a676977df3f848eac896242e81b Mon Sep 17 00:00:00 2001
From: AmrDeveloper <amr96 at programmer.net>
Date: Tue, 13 May 2025 21:35:06 +0200
Subject: [PATCH 1/7] [CIR] Upstream splat op for VectorType
---
clang/include/clang/CIR/Dialect/IR/CIROps.td | 32 ++++++++++
clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp | 8 +++
.../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 56 +++++++++++++++++
.../CIR/Lowering/DirectToLLVM/LowerToLLVM.h | 10 +++
clang/test/CIR/CodeGen/vector-ext.cpp | 63 +++++++++++++++++++
clang/test/CIR/CodeGen/vector.cpp | 63 +++++++++++++++++++
clang/test/CIR/IR/vector.cir | 34 ++++++++++
7 files changed, 266 insertions(+)
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 8579f7066e4fe..87fa03291af89 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2275,6 +2275,38 @@ def VecTernaryOp : CIR_Op<"vec.ternary",
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// VecSplatOp
+//===----------------------------------------------------------------------===//
+
+def VecSplatOp : CIR_Op<"vec.splat", [Pure,
+ TypesMatchWith<"type of 'value' matches element type of 'result'", "result",
+ "value", "cast<VectorType>($_self).getElementType()">]> {
+
+ let summary = "Convert a scalar into a vector";
+ let description = [{
+ The `cir.vec.splat` operation creates a vector value from a scalar value.
+ All elements of the vector have the same value, that of the given scalar.
+
+ It's a separate operation from `cir.vec.create` because more
+ efficient LLVM IR can be generated for it, and because some optimization and
+ analysis passes can benefit from knowing that all elements of the vector
+ have the same value.
+
+ ```mlir
+ %value = cir.const #cir.int<3> : !s32i
+ %value_vec = cir.vec.splat %value : !s32i, !cir.vector<4 x !s32i>
+ ```
+ }];
+
+ let arguments = (ins CIR_AnyType:$value);
+ let results = (outs CIR_VectorType:$result);
+
+ let assemblyFormat = [{
+ $value `:` type($value) `,` qualified(type($result)) attr-dict
+ }];
+}
+
//===----------------------------------------------------------------------===//
// BaseClassAddrOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
index 481eb492d1875..30d231e2c61de 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
@@ -1780,6 +1780,14 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *ce) {
cgf.convertType(destTy));
}
+ case CK_VectorSplat: {
+ // Create a vector object and fill all elements with the same scalar value.
+ assert(destTy->isVectorType() && "CK_VectorSplat to non-vector type");
+ return builder.create<cir::VecSplatOp>(
+ cgf.getLoc(subExpr->getSourceRange()), cgf.convertType(destTy),
+ Visit(subExpr));
+ }
+
default:
cgf.getCIRGenModule().errorNYI(subExpr->getSourceRange(),
"CastExpr: ", ce->getCastKindName());
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 4fdf8f9ec2695..7b1ced537047e 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1803,6 +1803,7 @@ void ConvertCIRToLLVMPass::runOnOperation() {
CIRToLLVMVecExtractOpLowering,
CIRToLLVMVecInsertOpLowering,
CIRToLLVMVecCmpOpLowering,
+ CIRToLLVMVecSplatOpLowering,
CIRToLLVMVecShuffleOpLowering,
CIRToLLVMVecShuffleDynamicOpLowering,
CIRToLLVMVecTernaryOpLowering
@@ -1956,6 +1957,61 @@ mlir::LogicalResult CIRToLLVMVecCmpOpLowering::matchAndRewrite(
return mlir::success();
}
+mlir::LogicalResult CIRToLLVMVecSplatOpLowering::matchAndRewrite(
+ cir::VecSplatOp op, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const {
+ // Vector splat can be implemented with an `insertelement` and a
+ // `shufflevector`, which is better than an `insertelement` for each
+ // element in the vector. Start with an undef vector. Insert the value into
+ // the first element. Then use a `shufflevector` with a mask of all 0 to
+ // fill out the entire vector with that value.
+ const auto vecTy = mlir::cast<cir::VectorType>(op.getType());
+ const mlir::Type llvmTy = typeConverter->convertType(vecTy);
+ const mlir::Location loc = op.getLoc();
+ const mlir::Value poison = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
+
+ const mlir::Value elementValue = adaptor.getValue();
+ if (mlir::isa<mlir::LLVM::PoisonOp>(elementValue.getDefiningOp())) {
+ // If the splat value is poison, then we can just use poison value
+ // for the entire vector.
+ rewriter.replaceOp(op, poison);
+ return mlir::success();
+ }
+
+ if (auto constValue =
+ dyn_cast<mlir::LLVM::ConstantOp>(elementValue.getDefiningOp())) {
+ if (auto intAttr = dyn_cast<mlir::IntegerAttr>(constValue.getValue())) {
+ mlir::DenseIntElementsAttr denseVec = mlir::DenseIntElementsAttr::get(
+ mlir::cast<mlir::ShapedType>(llvmTy), intAttr.getValue());
+
+ const mlir::Value indexValue = rewriter.create<mlir::LLVM::ConstantOp>(
+ loc, denseVec.getType(), denseVec);
+ rewriter.replaceOp(op, indexValue);
+ return mlir::success();
+ }
+
+ if (auto fpAttr = dyn_cast<mlir::FloatAttr>(constValue.getValue())) {
+ mlir::DenseFPElementsAttr denseVec = mlir::DenseFPElementsAttr::get(
+ mlir::cast<mlir::ShapedType>(llvmTy), fpAttr.getValue());
+
+ const mlir::Value indexValue = rewriter.create<mlir::LLVM::ConstantOp>(
+ loc, denseVec.getType(), denseVec);
+ rewriter.replaceOp(op, indexValue);
+ return mlir::success();
+ }
+ }
+
+ const mlir::Value indexValue =
+ rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0);
+ const mlir::Value oneElement = rewriter.create<mlir::LLVM::InsertElementOp>(
+ loc, poison, elementValue, indexValue);
+ const SmallVector<int32_t> zeroValues(vecTy.getSize(), 0);
+ const mlir::Value shuffled = rewriter.create<mlir::LLVM::ShuffleVectorOp>(
+ loc, oneElement, poison, zeroValues);
+ rewriter.replaceOp(op, shuffled);
+ return mlir::success();
+}
+
mlir::LogicalResult CIRToLLVMVecShuffleOpLowering::matchAndRewrite(
cir::VecShuffleOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
index 22d8a1e7c22e0..2eda568c84bdb 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
@@ -367,6 +367,16 @@ class CIRToLLVMVecCmpOpLowering
mlir::ConversionPatternRewriter &) const override;
};
+class CIRToLLVMVecSplatOpLowering
+ : public mlir::OpConversionPattern<cir::VecSplatOp> {
+public:
+ using mlir::OpConversionPattern<cir::VecSplatOp>::OpConversionPattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(cir::VecSplatOp op, OpAdaptor,
+ mlir::ConversionPatternRewriter &) const override;
+};
+
class CIRToLLVMVecShuffleOpLowering
: public mlir::OpConversionPattern<cir::VecShuffleOp> {
public:
diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp
index e1814f216f6b9..c336c67bcf09b 100644
--- a/clang/test/CIR/CodeGen/vector-ext.cpp
+++ b/clang/test/CIR/CodeGen/vector-ext.cpp
@@ -1092,6 +1092,69 @@ void foo17() {
// OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
+void foo18() {
+ vi4 a = {1, 2, 3, 4};
+ vi4 shl = a << 3;
+
+ uvi4 b = {1u, 2u, 3u, 4u};
+ uvi4 shr = b >> 3u;
+}
+
+// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["b", init]
+// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["shr", init]
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store{{.*}} %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !s32i, !cir.vector<4 x !s32i>
+// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CIR: cir.store{{.*}} %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !u32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !u32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !u32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !u32i
+// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !u32i, !u32i, !u32i, !u32i) : !cir.vector<4 x !u32i>
+// CIR: cir.store{{.*}} %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
+// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !u32i>>, !cir.vector<4 x !u32i>
+// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !u32i
+// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !u32i, !cir.vector<4 x !u32i>
+// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_B]] : !cir.vector<4 x !u32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !u32i>) -> !cir.vector<4 x !u32i>
+// CIR: cir.store{{.*}} %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
+
+// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
+// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
+// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// LLVM: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
+// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
+
+// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
+// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
+// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// OGCG: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
+// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
+
void foo19() {
vi4 a;
vi4 b;
diff --git a/clang/test/CIR/CodeGen/vector.cpp b/clang/test/CIR/CodeGen/vector.cpp
index 4f116faa7a1ac..23e91724dc0f3 100644
--- a/clang/test/CIR/CodeGen/vector.cpp
+++ b/clang/test/CIR/CodeGen/vector.cpp
@@ -1071,6 +1071,69 @@ void foo17() {
// OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
+void foo18() {
+ vi4 a = {1, 2, 3, 4};
+ vi4 shl = a << 3;
+
+ uvi4 b = {1u, 2u, 3u, 4u};
+ uvi4 shr = b >> 3u;
+}
+
+// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["b", init]
+// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["shr", init]
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store{{.*}} %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !s32i, !cir.vector<4 x !s32i>
+// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CIR: cir.store{{.*}} %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !u32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !u32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !u32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !u32i
+// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !u32i, !u32i, !u32i, !u32i) : !cir.vector<4 x !u32i>
+// CIR: cir.store{{.*}} %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
+// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !u32i>>, !cir.vector<4 x !u32i>
+// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !u32i
+// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !u32i, !cir.vector<4 x !u32i>
+// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_B]] : !cir.vector<4 x !u32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !u32i>) -> !cir.vector<4 x !u32i>
+// CIR: cir.store{{.*}} %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
+
+// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
+// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
+// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// LLVM: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
+// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
+
+// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
+// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
+// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// OGCG: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
+// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
+
void foo19() {
vi4 a;
vi4 b;
diff --git a/clang/test/CIR/IR/vector.cir b/clang/test/CIR/IR/vector.cir
index a455acf92ab6f..74b5d48588f36 100644
--- a/clang/test/CIR/IR/vector.cir
+++ b/clang/test/CIR/IR/vector.cir
@@ -187,4 +187,38 @@ cir.func @vector_shuffle_dynamic_test() {
// CHECK: cir.return
// CHECK: }
+cir.func @vector_splat_test() {
+ %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+ %1 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+ %2 = cir.const #cir.int<1> : !s32i
+ %3 = cir.const #cir.int<2> : !s32i
+ %4 = cir.const #cir.int<3> : !s32i
+ %5 = cir.const #cir.int<4> : !s32i
+ %6 = cir.vec.create(%2, %3, %4, %5 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+ cir.store %6, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+ %7 = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+ %8 = cir.const #cir.int<3> : !s32i
+ %9 = cir.vec.splat %8 : !s32i, !cir.vector<4 x !s32i>
+ %10 = cir.shift(left, %7 : !cir.vector<4 x !s32i>, %9 : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+ cir.store %10, %1 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+ cir.return
+}
+
+// CHECK: cir.func @vector_splat_test() {
+// CHECK: %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CHECK: %1 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+// CHECK: %2 = cir.const #cir.int<1> : !s32i
+// CHECK: %3 = cir.const #cir.int<2> : !s32i
+// CHECK: %4 = cir.const #cir.int<3> : !s32i
+// CHECK: %5 = cir.const #cir.int<4> : !s32i
+// CHECK: %6 = cir.vec.create(%2, %3, %4, %5 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CHECK: cir.store %6, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CHECK: %7 = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CHECK: %8 = cir.const #cir.int<3> : !s32i
+// CHECK: %9 = cir.vec.splat %8 : !s32i, !cir.vector<4 x !s32i>
+// CHECK: %10 = cir.shift(left, %7 : !cir.vector<4 x !s32i>, %9 : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CHECK: cir.store %10, %1 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CHECK: cir.return
+// CHECK: }
+
}
>From b9c65be3c632abf40fb2722c1c061325606ae5c6 Mon Sep 17 00:00:00 2001
From: AmrDeveloper <amr96 at programmer.net>
Date: Thu, 29 May 2025 18:05:18 +0200
Subject: [PATCH 2/7] Update test files for load and store with align
---
clang/test/CIR/CodeGen/vector-ext.cpp | 1 +
1 file changed, 1 insertion(+)
diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp
index c336c67bcf09b..965c44c9461a8 100644
--- a/clang/test/CIR/CodeGen/vector-ext.cpp
+++ b/clang/test/CIR/CodeGen/vector-ext.cpp
@@ -990,6 +990,7 @@ void foo14() {
// OGCG: %[[TMP_B:.*]] = load <4 x float>, ptr %[[VEC_B]], align 16
// OGCG: %[[GE:.*]] = fcmp oge <4 x float> %[[TMP_A]], %[[TMP_B]]
// OGCG: %[[RES:.*]] = sext <4 x i1> %[[GE]] to <4 x i32>
+// OGCG: store <4 x i32> %[[RES]], ptr {{.*}}, align 16
void foo15() {
vi4 a;
>From 4c5ce0d7fb7be0862c2b657acd4c4cb5d3206ad1 Mon Sep 17 00:00:00 2001
From: AmrDeveloper <amr96 at programmer.net>
Date: Tue, 3 Jun 2025 19:22:19 +0200
Subject: [PATCH 3/7] Address code review comment
---
clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 7b1ced537047e..5a3fd620f2a2a 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -2006,9 +2006,8 @@ mlir::LogicalResult CIRToLLVMVecSplatOpLowering::matchAndRewrite(
const mlir::Value oneElement = rewriter.create<mlir::LLVM::InsertElementOp>(
loc, poison, elementValue, indexValue);
const SmallVector<int32_t> zeroValues(vecTy.getSize(), 0);
- const mlir::Value shuffled = rewriter.create<mlir::LLVM::ShuffleVectorOp>(
- loc, oneElement, poison, zeroValues);
- rewriter.replaceOp(op, shuffled);
+ rewriter.replaceOpWithNewOp<mlir::LLVM::ShuffleVectorOp>(op, oneElement,
+ poison, zeroValues);
return mlir::success();
}
>From e86d571ec4a55c1de7e24f6fde5cc45c003fdca4 Mon Sep 17 00:00:00 2001
From: AmrDeveloper <amr96 at programmer.net>
Date: Thu, 5 Jun 2025 12:20:37 +0200
Subject: [PATCH 4/7] Mirror VecSplatOp arguments type constraint from clangir
---
clang/include/clang/CIR/Dialect/IR/CIROps.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 87fa03291af89..13b0e38069de1 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2299,7 +2299,7 @@ def VecSplatOp : CIR_Op<"vec.splat", [Pure,
```
}];
- let arguments = (ins CIR_AnyType:$value);
+ let arguments = (ins CIR_VectorElementType:$value);
let results = (outs CIR_VectorType:$result);
let assemblyFormat = [{
>From 91bd7cf23932beecd12ec76b48a45f5eae29fa04 Mon Sep 17 00:00:00 2001
From: AmrDeveloper <amr96 at programmer.net>
Date: Thu, 5 Jun 2025 16:06:12 +0200
Subject: [PATCH 5/7] Use replaceOpWithNewOp and remove extra cast
---
.../lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 14 +++++---------
1 file changed, 5 insertions(+), 9 deletions(-)
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 5a3fd620f2a2a..88ba6826c9ddd 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1965,7 +1965,7 @@ mlir::LogicalResult CIRToLLVMVecSplatOpLowering::matchAndRewrite(
// element in the vector. Start with an undef vector. Insert the value into
// the first element. Then use a `shufflevector` with a mask of all 0 to
// fill out the entire vector with that value.
- const auto vecTy = mlir::cast<cir::VectorType>(op.getType());
+ const cir::VectorType vecTy = op.getType();
const mlir::Type llvmTy = typeConverter->convertType(vecTy);
const mlir::Location loc = op.getLoc();
const mlir::Value poison = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
@@ -1983,20 +1983,16 @@ mlir::LogicalResult CIRToLLVMVecSplatOpLowering::matchAndRewrite(
if (auto intAttr = dyn_cast<mlir::IntegerAttr>(constValue.getValue())) {
mlir::DenseIntElementsAttr denseVec = mlir::DenseIntElementsAttr::get(
mlir::cast<mlir::ShapedType>(llvmTy), intAttr.getValue());
-
- const mlir::Value indexValue = rewriter.create<mlir::LLVM::ConstantOp>(
- loc, denseVec.getType(), denseVec);
- rewriter.replaceOp(op, indexValue);
+ rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
+ op, denseVec.getType(), denseVec);
return mlir::success();
}
if (auto fpAttr = dyn_cast<mlir::FloatAttr>(constValue.getValue())) {
mlir::DenseFPElementsAttr denseVec = mlir::DenseFPElementsAttr::get(
mlir::cast<mlir::ShapedType>(llvmTy), fpAttr.getValue());
-
- const mlir::Value indexValue = rewriter.create<mlir::LLVM::ConstantOp>(
- loc, denseVec.getType(), denseVec);
- rewriter.replaceOp(op, indexValue);
+ rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
+ op, denseVec.getType(), denseVec);
return mlir::success();
}
}
>From 93e173749a79d39567f5247be73959bb1984fb62 Mon Sep 17 00:00:00 2001
From: AmrDeveloper <amr96 at programmer.net>
Date: Mon, 9 Jun 2025 21:14:28 +0200
Subject: [PATCH 6/7] Remove const qualifiers from mlir value
---
.../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 88ba6826c9ddd..1642d10d427b5 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1965,12 +1965,12 @@ mlir::LogicalResult CIRToLLVMVecSplatOpLowering::matchAndRewrite(
// element in the vector. Start with an undef vector. Insert the value into
// the first element. Then use a `shufflevector` with a mask of all 0 to
// fill out the entire vector with that value.
- const cir::VectorType vecTy = op.getType();
- const mlir::Type llvmTy = typeConverter->convertType(vecTy);
- const mlir::Location loc = op.getLoc();
- const mlir::Value poison = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
+ cir::VectorType vecTy = op.getType();
+ mlir::Type llvmTy = typeConverter->convertType(vecTy);
+ mlir::Location loc = op.getLoc();
+ mlir::Value poison = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
- const mlir::Value elementValue = adaptor.getValue();
+ mlir::Value elementValue = adaptor.getValue();
if (mlir::isa<mlir::LLVM::PoisonOp>(elementValue.getDefiningOp())) {
// If the splat value is poison, then we can just use poison value
// for the entire vector.
@@ -1997,11 +1997,11 @@ mlir::LogicalResult CIRToLLVMVecSplatOpLowering::matchAndRewrite(
}
}
- const mlir::Value indexValue =
+ mlir::Value indexValue =
rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0);
- const mlir::Value oneElement = rewriter.create<mlir::LLVM::InsertElementOp>(
+ mlir::Value oneElement = rewriter.create<mlir::LLVM::InsertElementOp>(
loc, poison, elementValue, indexValue);
- const SmallVector<int32_t> zeroValues(vecTy.getSize(), 0);
+ SmallVector<int32_t> zeroValues(vecTy.getSize(), 0);
rewriter.replaceOpWithNewOp<mlir::LLVM::ShuffleVectorOp>(op, oneElement,
poison, zeroValues);
return mlir::success();
>From 28f80c63ebe997e3653790873211db584391e9c6 Mon Sep 17 00:00:00 2001
From: AmrDeveloper <amr96 at programmer.net>
Date: Tue, 10 Jun 2025 23:34:05 +0200
Subject: [PATCH 7/7] Update vector_splat_test to use pattern matching
---
clang/test/CIR/IR/vector.cir | 29 ++++++++++++++---------------
1 file changed, 14 insertions(+), 15 deletions(-)
diff --git a/clang/test/CIR/IR/vector.cir b/clang/test/CIR/IR/vector.cir
index 74b5d48588f36..f23f5de9692de 100644
--- a/clang/test/CIR/IR/vector.cir
+++ b/clang/test/CIR/IR/vector.cir
@@ -205,20 +205,19 @@ cir.func @vector_splat_test() {
}
// CHECK: cir.func @vector_splat_test() {
-// CHECK: %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
-// CHECK: %1 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
-// CHECK: %2 = cir.const #cir.int<1> : !s32i
-// CHECK: %3 = cir.const #cir.int<2> : !s32i
-// CHECK: %4 = cir.const #cir.int<3> : !s32i
-// CHECK: %5 = cir.const #cir.int<4> : !s32i
-// CHECK: %6 = cir.vec.create(%2, %3, %4, %5 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
-// CHECK: cir.store %6, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
-// CHECK: %7 = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
-// CHECK: %8 = cir.const #cir.int<3> : !s32i
-// CHECK: %9 = cir.vec.splat %8 : !s32i, !cir.vector<4 x !s32i>
-// CHECK: %10 = cir.shift(left, %7 : !cir.vector<4 x !s32i>, %9 : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
-// CHECK: cir.store %10, %1 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
-// CHECK: cir.return
-// CHECK: }
+// CHECK-NEXT: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CHECK-NEXT: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+// CHECK-NEXT: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CHECK-NEXT: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CHECK-NEXT: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CHECK-NEXT: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CHECK-NEXT: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CHECK-NEXT: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CHECK-NEXT: %[[TMP:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CHECK-NEXT: %[[SPLAT_VAL:.*]] = cir.const #cir.int<3> : !s32i
+// CHECK-NEXT: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SPLAT_VAL]] : !s32i, !cir.vector<4 x !s32i>
+// CHECK-NEXT: %[[SHL:.*]] = cir.shift(left, %[[TMP]] : !cir.vector<4 x !s32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CHECK-NEXT: cir.store %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CHECK-NEXT: cir.return
}
More information about the cfe-commits
mailing list