[llvm] [llvm] Adding scalarization of `llvm.vector.insert` (PR #71614)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 7 17:06:21 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-aarch64

Author: Rob Suderman (rsuderman)

<details>
<summary>Changes</summary>

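Adds handling for scalarizing the operands of a subvector insertion (`ISD::INSERT_SUBVECTOR`), and lets `WidenVecRes_SETCC` widen its operands repeatedly until they reach the widened element count.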

---
Full diff: https://github.com/llvm/llvm-project/pull/71614.diff


4 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+1) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+26-2) 
- (modified) llvm/test/CodeGen/AArch64/aarch64-neon-v1i1-setcc.ll (+9) 
- (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+72) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index f85c1296cdce856..5651c6e9b218447 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -809,6 +809,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   // Vector Operand Scalarization: <1 x ty> -> ty.
   bool ScalarizeVectorOperand(SDNode *N, unsigned OpNo);
   SDValue ScalarizeVecOp_BITCAST(SDNode *N);
+  SDValue ScalarizeVecOp_INSERT_SUBVECTOR(SDNode *N, unsigned OpNo);
   SDValue ScalarizeVecOp_UnaryOp(SDNode *N);
   SDValue ScalarizeVecOp_UnaryOp_StrictFP(SDNode *N);
   SDValue ScalarizeVecOp_CONCAT_VECTORS(SDNode *N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index a1a9f0f0615cbc7..9f59ae333403d2c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -675,6 +675,9 @@ bool DAGTypeLegalizer::ScalarizeVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::BITCAST:
     Res = ScalarizeVecOp_BITCAST(N);
     break;
+  case ISD::INSERT_SUBVECTOR:
+    Res = ScalarizeVecOp_INSERT_SUBVECTOR(N, OpNo);
+    break;
   case ISD::ANY_EXTEND:
   case ISD::ZERO_EXTEND:
   case ISD::SIGN_EXTEND:
@@ -766,6 +769,24 @@ SDValue DAGTypeLegalizer::ScalarizeVecOp_BITCAST(SDNode *N) {
                      N->getValueType(0), Elt);
 }
 
+/// Scalarize an operand of INSERT_SUBVECTOR. Any operand that needs to be
+/// scalarized must be <1 x ty>, so the inserted subvector is one element.
+SDValue DAGTypeLegalizer::ScalarizeVecOp_INSERT_SUBVECTOR(SDNode *N,
+                                                          unsigned OpNo) {
+  // Operand 1 is the <1 x ty> subvector; fetch its scalarized element.
+  SDValue Elt = GetScalarizedVector(N->getOperand(1));
+
+  // A <1 x ty> destination is fully overwritten by the inserted element.
+  // NOTE(review): Elt is scalar while N's result is <1 x ty>; confirm this
+  // path is reachable (the illegal result should be legalized first).
+  if (OpNo == 0)
+    return Elt;
+
+  // Legal destination: insert the element at the subvector's start index.
+  return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(N), N->getValueType(0),
+                     N->getOperand(0), Elt, N->getOperand(2));
+}
+
 /// If the input is a vector that needs to be scalarized, it must be <1 x ty>.
 /// Do the operation on the element instead.
 SDValue DAGTypeLegalizer::ScalarizeVecOp_UnaryOp(SDNode *N) {
@@ -5891,8 +5912,11 @@ SDValue DAGTypeLegalizer::WidenVecRes_SETCC(SDNode *N) {
     InOp1 = GetWidenedVector(InOp1);
     InOp2 = GetWidenedVector(InOp2);
   } else {
-    InOp1 = DAG.WidenVector(InOp1, SDLoc(N));
-    InOp2 = DAG.WidenVector(InOp2, SDLoc(N));
+    do {
+      InOp1 = DAG.WidenVector(InOp1, SDLoc(N));
+      InOp2 = DAG.WidenVector(InOp2, SDLoc(N));
+    } while (ElementCount::isKnownLT(
+        InOp1.getValueType().getVectorElementCount(), WidenEC));
   }
 
   // Assume that the input and output will be widen appropriately.  If not,
diff --git a/llvm/test/CodeGen/AArch64/aarch64-neon-v1i1-setcc.ll b/llvm/test/CodeGen/AArch64/aarch64-neon-v1i1-setcc.ll
index c932253049e239f..91762cb898897c4 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-neon-v1i1-setcc.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-neon-v1i1-setcc.ll
@@ -67,3 +67,12 @@ if.then:
 if.end:
   ret i32 1;
 }
+
+define dso_local <1 x half> @cmp_select(<1 x half> %i105, <1 x half> %in) {
+; CHECK-LABEL: cmp_select:
+; CHECK: fcmge
+newFuncRoot:
+  %i179 = fcmp uno <1 x half> %i105, zeroinitializer
+  %i180 = select <1 x i1> %i179, <1 x half> %in, <1 x half> %i105
+  ret <1 x half> %i180
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 2cadd4e0d2911a6..0fbdddeb12950f2 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3121,6 +3121,38 @@ gentbl_cc_library(
     ],
 )
 
+gentbl_cc_library(
+    name = "MeshShardingInterfaceIncGen",  # Op-interface decls/defs generated from ShardingInterface.td. NOTE(review): looks unrelated to this SelectionDAG PR; confirm it belongs here.
+    tbl_outs = [
+        (
+            ["-gen-op-interface-decls"],
+            "include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc",
+        ),
+        (
+            ["-gen-op-interface-defs"],
+            "include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td",
+    deps = [":OpBaseTdFiles"],
+)
+
+cc_library(
+    name = "MeshShardingInterface",  # ShardingInterface.cpp/.h, compiled against the MeshShardingInterfaceIncGen outputs.
+    srcs = ["lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp"],
+    hdrs = ["include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"],
+    includes = ["include"],
+    deps = [
+        ":DialectUtils",
+        ":IR",
+        ":MeshDialect",
+        ":MeshShardingInterfaceIncGen",
+        ":Support",
+        "//llvm:Support",
+    ],
+)
+
 cc_library(
     name = "MeshDialect",
     srcs = ["lib/Dialect/Mesh/IR/MeshOps.cpp"],
@@ -3136,6 +3168,40 @@ cc_library(
     ],
 )
 
+gentbl_cc_library(
+    name = "MeshTransformsPassIncGen",  # Passes.h.inc pass declarations generated from Mesh Transforms/Passes.td.
+    tbl_outs = [
+        (
+            [
+                "-gen-pass-decls",
+                "-name=Mesh",
+            ],
+            "include/mlir/Dialect/Mesh/Transforms/Passes.h.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Mesh/Transforms/Passes.td",
+    deps = [":PassBaseTdFiles"],
+)
+
+cc_library(
+    name = "MeshTransforms",  # Every .cpp/.h under lib/Dialect/Mesh/Transforms plus the public Transforms headers.
+    srcs = glob([
+        "lib/Dialect/Mesh/Transforms/*.cpp",
+        "lib/Dialect/Mesh/Transforms/*.h",
+    ]),
+    hdrs = glob(["include/mlir/Dialect/Mesh/Transforms/*.h"]),
+    includes = ["include"],
+    deps = [
+        ":FuncDialect",
+        ":MeshDialect",
+        ":MeshShardingInterface",
+        ":MeshTransformsPassIncGen",
+        ":Pass",
+        "//llvm:Support",
+    ],
+)
+
 ##---------------------------------------------------------------------------##
 # NVGPU dialect.
 ##---------------------------------------------------------------------------##
@@ -5182,6 +5248,7 @@ cc_library(
         ":ROCDLTarget",
         ":ROCDLToLLVMIRTranslation",
         ":SCFDialect",
+        ":SPIRVDialect",
         ":SerializeToCubin_stub",
         ":SideEffectInterfaces",
         ":Support",
@@ -5618,6 +5685,7 @@ cc_library(
     deps = [
         ":ArithToSPIRV",
         ":ConversionPassIncGen",
+        ":FuncDialect",
         ":FuncToSPIRV",
         ":GPUDialect",
         ":IR",
@@ -6437,6 +6505,7 @@ cc_library(
         ":CommonFolders",
         ":ControlFlowInterfaces",
         ":FunctionInterfaces",
+        ":GPUDialect",
         ":IR",
         ":InferTypeOpInterface",
         ":Parser",
@@ -8632,6 +8701,7 @@ cc_library(
         ":MemRefTransformOps",
         ":MemRefTransforms",
         ":MeshDialect",
+        ":MeshTransforms",
         ":NVGPUDialect",
         ":NVGPUPassIncGen",
         ":NVGPUToNVVM",
@@ -11051,6 +11121,8 @@ cc_library(
         ":IR",
         ":InferTypeOpInterface",
         ":LoopLikeInterface",
+        ":MeshDialect",
+        ":MeshShardingInterface",
         ":Pass",
         ":QuantOps",
         ":Support",

``````````

</details>


https://github.com/llvm/llvm-project/pull/71614


More information about the llvm-commits mailing list