[clang] [CIR] Upstream shift operators for VectorType (PR #139465)

via cfe-commits cfe-commits at lists.llvm.org
Sun May 11 12:55:40 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clangir

Author: Amr Hesham (AmrDeveloper)

<details>
<summary>Changes</summary>

This change adds support for shift operations on VectorType operands.

Issue https://github.com/llvm/llvm-project/issues/136487

---
Full diff: https://github.com/llvm/llvm-project/pull/139465.diff


6 Files Affected:

- (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+8-7) 
- (modified) clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td (+19) 
- (modified) clang/lib/CIR/Dialect/IR/CIRDialect.cpp (+2-3) 
- (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+9-6) 
- (modified) clang/test/CIR/CodeGen/vector-ext.cpp (+65) 
- (modified) clang/test/CIR/CodeGen/vector.cpp (+65) 


``````````diff
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 7aff5edb88167..b0e593b011109 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -1401,18 +1401,19 @@ def ShiftOp : CIR_Op<"shift", [Pure]> {
     The `cir.shift` operation performs a bitwise shift, either to the left or to
     the right, based on the first operand. The second operand specifies the
     value to be shifted, and the third operand determines the number of
-    positions by which the shift is applied. Both the second and third operands
-    are required to be integers.
+    positions by which the shift is applied. The latter two operands must
+    either both be vectors of integer type or both be scalar integers. If they
+    are vectors, each element of the shift target is shifted by the
+    corresponding element of the shift amount vector.
 
     ```mlir
-    %7 = cir.shift(left, %1 : !u64i, %4 : !s32i) -> !u64i
+    %res = cir.shift(left, %lhs : !u64i, %amount : !s32i) -> !u64i
+    %new_vec = cir.shift(left, %lhs : !cir.vector<2 x !s32i>, %rhs : !cir.vector<2 x !s32i>) -> !cir.vector<2 x !s32i>
     ```
   }];
 
-  // TODO(cir): Support vectors. CIR_IntType -> CIR_AnyIntOrVecOfInt. Also
-  // update the description above.
-  let results = (outs CIR_IntType:$result);
-  let arguments = (ins CIR_IntType:$value, CIR_IntType:$amount,
+  let results = (outs CIR_AnyIntOrVecOfInt:$result);
+  let arguments = (ins CIR_AnyIntOrVecOfInt:$value, CIR_AnyIntOrVecOfInt:$amount,
                        UnitAttr:$isShiftleft);
 
   let assemblyFormat = [{
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
index 00f67e2a03a25..902d6535ff717 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
@@ -174,4 +174,23 @@ def CIR_PtrToVoidPtrType
         "$_builder.getType<" # cppType # ">("
         "cir::VoidType::get($_builder.getContext())))">;
 
+//===----------------------------------------------------------------------===//
+// Vector Type predicates
+//===----------------------------------------------------------------------===//
+
+// Vector of integral type
+def IntegerVector : Type<
+    And<[
+      CPred<"::mlir::isa<::cir::VectorType>($_self)">,
+      CPred<"::mlir::isa<::cir::IntType>("
+            "::mlir::cast<::cir::VectorType>($_self).getElementType())">,
+      CPred<"::mlir::cast<::cir::IntType>("
+            "::mlir::cast<::cir::VectorType>($_self).getElementType())"
+            ".isFundamental()">
+    ]>, "!cir.vector of !cir.int"> {
+}
+
+// Any Integer or Vector of Integer Constraints
+def CIR_AnyIntOrVecOfInt: AnyTypeOf<[CIR_AnyIntType, IntegerVector]>;
+
 #endif // CLANG_CIR_DIALECT_IR_CIRTYPECONSTRAINTS_TD
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index b131edaf403ed..2f7e3496d55a5 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1297,9 +1297,8 @@ OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
 LogicalResult cir::ShiftOp::verify() {
   mlir::Operation *op = getOperation();
   mlir::Type resType = getResult().getType();
-  assert(!cir::MissingFeatures::vectorType());
-  bool isOp0Vec = false;
-  bool isOp1Vec = false;
+  const bool isOp0Vec = mlir::isa<cir::VectorType>(op->getOperand(0).getType());
+  const bool isOp1Vec = mlir::isa<cir::VectorType>(op->getOperand(1).getType());
   if (isOp0Vec != isOp1Vec)
     return emitOpError() << "input types cannot be one vector and one scalar";
   if (isOp1Vec && op->getOperand(1).getType() != resType) {
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 5986655ababe9..1951d2e3a3b79 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1376,16 +1376,17 @@ mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
   auto cirValTy = mlir::dyn_cast<cir::IntType>(op.getValue().getType());
 
   // Operands could also be vector type
-  assert(!cir::MissingFeatures::vectorType());
+  auto cirAmtVTy = mlir::dyn_cast<cir::VectorType>(op.getAmount().getType());
+  auto cirValVTy = mlir::dyn_cast<cir::VectorType>(op.getValue().getType());
   mlir::Type llvmTy = getTypeConverter()->convertType(op.getType());
   mlir::Value amt = adaptor.getAmount();
   mlir::Value val = adaptor.getValue();
 
-  // TODO(cir): Assert for vector types
-  assert((cirValTy && cirAmtTy) &&
+  assert(((cirValTy && cirAmtTy) || (cirAmtVTy && cirValVTy)) &&
          "shift input type must be integer or vector type, otherwise NYI");
 
-  assert((cirValTy == op.getType()) && "inconsistent operands' types NYI");
+  assert((cirValTy == op.getType() || cirValVTy == op.getType()) &&
+         "inconsistent operands' types NYI");
 
   // Ensure shift amount is the same type as the value. Some undefined
   // behavior might occur in the casts below as per [C99 6.5.7.3].
@@ -1399,8 +1400,10 @@ mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
   if (op.getIsShiftleft()) {
     rewriter.replaceOpWithNewOp<mlir::LLVM::ShlOp>(op, llvmTy, val, amt);
   } else {
-    assert(!cir::MissingFeatures::vectorType());
-    bool isUnsigned = !cirValTy.isSigned();
+    const bool isUnsigned =
+        cirValTy
+            ? !cirValTy.isSigned()
+            : !mlir::cast<cir::IntType>(cirValVTy.getElementType()).isSigned();
     if (isUnsigned)
       rewriter.replaceOpWithNewOp<mlir::LLVM::LShrOp>(op, llvmTy, val, amt);
     else
diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp
index 0756497bf6b96..4b3a271587f67 100644
--- a/clang/test/CIR/CodeGen/vector-ext.cpp
+++ b/clang/test/CIR/CodeGen/vector-ext.cpp
@@ -213,3 +213,68 @@ void foo4() {
 // OGCG: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
 // OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP1]], i32 %[[TMP2]]
 // OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4
+
+void foo9() {
+  vi4 a = {1, 2, 3, 4};
+  vi4 b = {5, 6, 7, 8};
+
+  vi4 shl = a << b;
+  vi4 shr = a >> b;
+}
+
+// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["b", init]
+// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shr", init]
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[CONST_5:.*]] = cir.const #cir.int<5> : !s32i
+// CIR: %[[CONST_6:.*]] = cir.const #cir.int<6> : !s32i
+// CIR: %[[CONST_7:.*]] = cir.const #cir.int<7> : !s32i
+// CIR: %[[CONST_8:.*]] = cir.const #cir.int<8> : !s32i
+// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_5]], %[[CONST_6]], %[[CONST_7]], %[[CONST_8]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CIR: cir.store %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CIR: cir.store %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+
+// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// LLVM: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// LLVM: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
+
+// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// OGCG: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// OGCG: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
diff --git a/clang/test/CIR/CodeGen/vector.cpp b/clang/test/CIR/CodeGen/vector.cpp
index 530018108c6d9..cf23a97ac61f2 100644
--- a/clang/test/CIR/CodeGen/vector.cpp
+++ b/clang/test/CIR/CodeGen/vector.cpp
@@ -201,3 +201,68 @@ void foo4() {
 // OGCG: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
 // OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP1]], i32 %[[TMP2]]
 // OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4
+
+void foo9() {
+  vi4 a = {1, 2, 3, 4};
+  vi4 b = {5, 6, 7, 8};
+
+  vi4 shl = a << b;
+  vi4 shr = a >> b;
+}
+
+// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
+// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["b", init]
+// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
+// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shr", init]
+// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
+// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
+// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
+// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
+// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[CONST_5:.*]] = cir.const #cir.int<5> : !s32i
+// CIR: %[[CONST_6:.*]] = cir.const #cir.int<6> : !s32i
+// CIR: %[[CONST_7:.*]] = cir.const #cir.int<7> : !s32i
+// CIR: %[[CONST_8:.*]] = cir.const #cir.int<8> : !s32i
+// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_5]], %[[CONST_6]], %[[CONST_7]], %[[CONST_8]] :
+// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
+// CIR: cir.store %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CIR: cir.store %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
+// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[TMP_B]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
+// CIR: cir.store %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
+
+// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
+// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// LLVM: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// LLVM: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
+
+// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
+// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
+// OGCG: store <4 x i32> <i32 5, i32 6, i32 7, i32 8>, ptr %[[VEC_B]], align 16
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
+// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
+// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
+// OGCG: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
+// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16

``````````

</details>


https://github.com/llvm/llvm-project/pull/139465


More information about the cfe-commits mailing list