[Mlir-commits] [mlir] 7c56458 - [mlir] Fix scalable type translation in splat element attr

Javier Setoain llvmlistbot at llvm.org
Thu Jan 13 03:22:04 PST 2022


Author: Javier Setoain
Date: 2022-01-13T11:14:41Z
New Revision: 7c56458616602b91a73665b259e840ca767eae15

URL: https://github.com/llvm/llvm-project/commit/7c56458616602b91a73665b259e840ca767eae15
DIFF: https://github.com/llvm/llvm-project/commit/7c56458616602b91a73665b259e840ca767eae15.diff

LOG: [mlir] Fix scalable type translation in splat element attr

The LLVM dialect constant op translation assumed that a splat vector
attribute always has a fixed length, which produced an invalid
translation for constant scalable vector initializations.

Differential Revision: https://reviews.llvm.org/D117125

Added: 
    

Modified: 
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Target/LLVMIR/llvmir-types.mlir
    mlir/test/Target/LLVMIR/llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ebd35e1aae66e..57733a68fc24c 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -239,6 +239,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
   if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
     llvm::Type *elementType;
     uint64_t numElements;
+    bool isScalable = false;
     if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
       elementType = arrayTy->getElementType();
       numElements = arrayTy->getNumElements();
@@ -248,6 +249,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
     } else if (auto *sVectorTy = dyn_cast<llvm::ScalableVectorType>(llvmType)) {
       elementType = sVectorTy->getElementType();
       numElements = sVectorTy->getMinNumElements();
+      isScalable = true;
     } else {
       llvm_unreachable("unrecognized constant vector type");
     }
@@ -265,7 +267,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
       return nullptr;
     if (llvmType->isVectorTy())
       return llvm::ConstantVector::getSplat(
-          llvm::ElementCount::get(numElements, /*Scalable=*/false), child);
+          llvm::ElementCount::get(numElements, /*Scalable=*/isScalable), child);
     if (llvmType->isArrayTy()) {
       auto *arrayType = llvm::ArrayType::get(elementType, numElements);
       SmallVector<llvm::Constant *, 8> constants(numElements, child);

diff  --git a/mlir/test/Target/LLVMIR/llvmir-types.mlir b/mlir/test/Target/LLVMIR/llvmir-types.mlir
index 2bbe5f34c1bfc..da3b395c3ab10 100644
--- a/mlir/test/Target/LLVMIR/llvmir-types.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-types.mlir
@@ -90,6 +90,8 @@ llvm.func @return_ppi8_42_9() -> !llvm.ptr<ptr<i8, 42>, 9>
 llvm.func @return_v4_i32() -> vector<4xi32>
 // CHECK: declare <4 x float> @return_v4_float()
 llvm.func @return_v4_float() -> vector<4xf32>
+// CHECK: declare <vscale x 4 x float> @return_vs_4_float()
+llvm.func @return_vs_4_float() -> vector<[4]xf32>
 // CHECK: declare <vscale x 4 x i32> @return_vs_4_i32()
 llvm.func @return_vs_4_i32() -> !llvm.vec<?x4 x i32>
 // CHECK: declare <vscale x 8 x half> @return_vs_8_half()

diff  --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index c4a5434144180..69e881228d14b 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -907,6 +907,13 @@ llvm.func @vector_splat_1d() -> vector<4xf32> {
   llvm.return %0 : vector<4xf32>
 }
 
+// CHECK-LABEL: @vector_splat_1d_scalable
+llvm.func @vector_splat_1d_scalable() -> vector<[4]xf32> {
+  // CHECK: ret <vscale x 4 x float> zeroinitializer
+  %0 = llvm.mlir.constant(dense<0.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+  llvm.return %0 : vector<[4]xf32>
+}
+
 // CHECK-LABEL: @vector_splat_2d
 llvm.func @vector_splat_2d() -> !llvm.array<4 x vector<16 x f32>> {
   // CHECK: ret [4 x <16 x float>] zeroinitializer
@@ -928,6 +935,13 @@ llvm.func @vector_splat_nonzero() -> vector<4xf32> {
   llvm.return %0 : vector<4xf32>
 }
 
+// CHECK-LABEL: @vector_splat_nonzero_scalable
+llvm.func @vector_splat_nonzero_scalable() -> vector<[4]xf32> {
+  // CHECK: ret <vscale x 4 x float> shufflevector (<vscale x 4 x float> insertelement (<vscale x 4 x float> poison, float 1.000000e+00, i32 0), <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer)
+  %0 = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+  llvm.return %0 : vector<[4]xf32>
+}
+
 // CHECK-LABEL: @ops
 llvm.func @ops(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32) -> !llvm.struct<(f32, i32)> {
 // CHECK-NEXT: fsub float %0, %1


        


More information about the Mlir-commits mailing list