[PATCH] D72688: [mlir] Fix translation of splat constants to LLVM IR

Alex Zinenko via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 14 02:33:08 PST 2020


ftynse created this revision.
Herald added subscribers: llvm-commits, liufengdb, aartbik, lucyrfox, mgester, arpith-jacob, nicolasvasilache, antiagainst, shauheen, burmako, jpienaar, rriddle, mehdi_amini.
Herald added a project: LLVM.
ftynse added reviewers: nicolasvasilache, aartbik.

When converting splat constants for nested sequential LLVM IR types wrapped in
MLIR, the constant conversion erroneously assumed that a constant of a
sequential type could always be constructed recursively from a single value: it
extracted the scalar splat value at the outermost level and passed it down
together with an element type that could itself still be sequential. Instead,
keep passing the splat attribute down until all sequential types have been
unpacked, and only then construct the scalar constant and wrap it into the
surrounding sequential types.
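
The recursion described above can be illustrated with a small standalone
sketch. It does not use the LLVM or MLIR APIs; ToyType, scalar, seq and
buildSplat are hypothetical names invented for this illustration, and the
"constant" is just a string. The point it tries to show is that the scalar is
only materialized once the innermost (non-sequential) type is reached, and each
enclosing level then replicates the child it got back.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>

// Hypothetical toy model of a nested sequential type (not an LLVM class).
// A type with no element is a scalar; otherwise it holds `count` elements.
struct ToyType {
  uint64_t count = 0;
  std::shared_ptr<ToyType> element;
  bool isSequential() const { return element != nullptr; }
};

// Helper: build a scalar toy type.
std::shared_ptr<ToyType> scalar() { return std::make_shared<ToyType>(); }

// Helper: build a sequence of `count` elements of type `element`.
std::shared_ptr<ToyType> seq(uint64_t count, std::shared_ptr<ToyType> element) {
  assert(element && "a sequence needs an element type");
  auto type = std::make_shared<ToyType>();
  type->count = count;
  type->element = std::move(element);
  return type;
}

// Render a splat "constant" as text. The scalar is produced only at the
// innermost level; every outer level repeats the child built one level below.
// The recursion terminates because each step removes one sequential layer.
std::string buildSplat(const ToyType &type, int splatValue) {
  if (!type.isSequential())
    return std::to_string(splatValue);
  std::string child = buildSplat(*type.element, splatValue);
  std::string result = "[";
  for (uint64_t i = 0; i < type.count; ++i) {
    if (i != 0)
      result += ", ";
    result += child;
  }
  return result + "]";
}
```

For example, buildSplat(*seq(2, seq(3, scalar())), 7) yields
"[[7, 7, 7], [7, 7, 7]]" in this toy model.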


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D72688

Files:
  mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
  mlir/test/Target/llvmir.mlir


Index: mlir/test/Target/llvmir.mlir
===================================================================
--- mlir/test/Target/llvmir.mlir
+++ mlir/test/Target/llvmir.mlir
@@ -804,6 +804,34 @@
   llvm.return %1 : !llvm<"<4 x float>">
 }
 
+// CHECK-LABEL: @vector_splat_1d
+llvm.func @vector_splat_1d() -> !llvm<"<4 x float>"> {
+  // CHECK: ret <4 x float> zeroinitializer
+  %0 = llvm.mlir.constant(dense<0.000000e+00> : vector<4xf32>) : !llvm<"<4 x float>">
+  llvm.return %0 : !llvm<"<4 x float>">
+}
+
+// CHECK-LABEL: @vector_splat_2d
+llvm.func @vector_splat_2d() -> !llvm<"[4 x <16 x float>]"> {
+  // CHECK: ret [4 x <16 x float>] zeroinitializer
+  %0 = llvm.mlir.constant(dense<0.000000e+00> : vector<4x16xf32>) : !llvm<"[4 x <16 x float>]">
+  llvm.return %0 : !llvm<"[4 x <16 x float>]">
+}
+
+// CHECK-LABEL: @vector_splat_3d
+llvm.func @vector_splat_3d() -> !llvm<"[4 x [16 x <4 x float>]]"> {
+  // CHECK: ret [4 x [16 x <4 x float>]] zeroinitializer
+  %0 = llvm.mlir.constant(dense<0.000000e+00> : vector<4x16x4xf32>) : !llvm<"[4 x [16 x <4 x float>]]">
+  llvm.return %0 : !llvm<"[4 x [16 x <4 x float>]]">
+}
+
+// CHECK-LABEL: @vector_splat_nonzero
+llvm.func @vector_splat_nonzero() -> !llvm<"<4 x float>"> {
+  // CHECK: ret <4 x float> <float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00>
+  %0 = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : !llvm<"<4 x float>">
+  llvm.return %0 : !llvm<"<4 x float>">
+}
+
 // CHECK-LABEL: @ops
 llvm.func @ops(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm.i32, %arg3: !llvm.i32) -> !llvm<"{ float, i32 }"> {
 // CHECK-NEXT: fsub float %0, %1
Index: mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
===================================================================
--- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -49,7 +49,14 @@
     auto *sequentialType = cast<llvm::SequentialType>(llvmType);
     auto elementType = sequentialType->getElementType();
     uint64_t numElements = sequentialType->getNumElements();
-    auto *child = getLLVMConstant(elementType, splatAttr.getSplatValue(), loc);
+    // Splat value is a scalar. Extract it only if the element type is not
+    // another sequence type. The recursion terminates because each step removes
+    // one outer sequential type.
+    llvm::Constant *child = getLLVMConstant(
+        elementType,
+        isa<llvm::SequentialType>(elementType) ? splatAttr
+                                               : splatAttr.getSplatValue(),
+        loc);
     if (llvmType->isVectorTy())
       return llvm::ConstantVector::getSplat(numElements, child);
     if (llvmType->isArrayTy()) {

