[Mlir-commits] [mlir] [mlir][llvm] Use zeroinitializer for TargetExtType (PR #66510)

Lukas Sommer llvmlistbot at llvm.org
Fri Sep 15 06:26:35 PDT 2023


https://github.com/sommerlukas created https://github.com/llvm/llvm-project/pull/66510

Use the recently introduced llvm.mlir.zero operation for values with LLVM target extension type. This replaces the previous workaround, which used an llvm.mlir.constant operation with a single zero-valued integer attribute.

From e8fb4e1efe2ba14576f1c43b746643bc93c1c52a Mon Sep 17 00:00:00 2001
From: Lukas Sommer <lukas.sommer at codeplay.com>
Date: Fri, 15 Sep 2023 09:42:19 +0100
Subject: [PATCH] [mlir][llvm] Use zeroinitializer for TargetExtType

Use the recently introduced llvm.mlir.zero operation for values with
LLVM target extension type. This replaces the previous workaround, which
used an llvm.mlir.constant operation with a single zero-valued integer
attribute.

Signed-off-by: Lukas Sommer <lukas.sommer at codeplay.com>
---
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |  1 +
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 35 ++++++++++---------
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       | 13 ++-----
 mlir/test/Dialect/LLVMIR/global.mlir          | 14 +++++---
 mlir/test/Dialect/LLVMIR/invalid.mlir         | 10 +++---
 .../Target/LLVMIR/Import/target-ext-type.ll   |  4 +--
 mlir/test/Target/LLVMIR/target-ext-type.mlir  |  4 +--
 7 files changed, 42 insertions(+), 39 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 57dce72e102e7a4..726349597aa2556 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1530,6 +1530,7 @@ def LLVM_ZeroOp
   let results = (outs LLVM_Type:$res);
   let builders = [LLVM_OneResultOpBuilder];
   let assemblyFormat = "attr-dict `:` type($res)";
+  let hasVerifier = 1;
 }
 
 def LLVM_ConstantOp
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index c3575d299b3888a..9b7d17ef77ea272 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2051,13 +2051,8 @@ LogicalResult GlobalOp::verify() {
              << "this target extension type cannot be used in a global";
 
     if (Attribute value = getValueOrNull()) {
-      // Only a single, zero integer attribute (=zeroinitializer) is allowed for
-      // a global value with TargetExtType.
-      // TODO: Replace with 'zeroinitializer' once there is a dedicated
-      // zeroinitializer operation in the LLVM dialect.
-      if (!isa<IntegerAttr>(value) || !isZeroAttribute(value))
-        return emitOpError()
-               << "expected zero value for global with target extension type";
+      return emitOpError() << "global with target extension type can only be "
+                              "initialized with zero-initializer";
     }
   }
 
@@ -2510,6 +2505,21 @@ Region *LLVMFuncOp::getCallableRegion() {
   return &getBody();
 }
 
+//===----------------------------------------------------------------------===//
+// ZeroOp.
+//===----------------------------------------------------------------------===//
+
+LogicalResult LLVM::ZeroOp::verify() {
+  // Target extension types must opt in to zero-initialization through the
+  // HasZeroInit property; this verifier accepts every other result type.
+  if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType()))
+    if (!targetExtType.hasProperty(LLVM::LLVMTargetExtType::HasZeroInit))
+      return emitOpError()
+             << "target extension type does not support zero-initializer";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp.
 //===----------------------------------------------------------------------===//
@@ -2554,16 +2564,7 @@ LogicalResult LLVM::ConstantOp::verify() {
     return success();
   }
   if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
-    if (!targetExtType.hasProperty(LLVM::LLVMTargetExtType::HasZeroInit))
-      return emitOpError()
-             << "target extension type does not support zero-initializer";
-    // Only a single, zero integer attribute (=zeroinitializer) is allowed for a
-    // global value with TargetExtType.
-    if (!isa<IntegerAttr>(getValue()) || !isZeroAttribute(getValue()))
-      return emitOpError()
-             << "only zero-initializer allowed for target extension types";
-
-    return success();
+    return emitOpError() << "target extension type does not support constants.";
   }
   if (!llvm::isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>(getValue()))
     return emitOpError()
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 35b2fcd3d3abe4e..0096a1dfee519e9 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1077,16 +1077,9 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) {
         cast<LLVMTargetExtType>(convertType(constTargetNone->getType()));
     assert(targetExtType.hasProperty(LLVMTargetExtType::HasZeroInit) &&
            "target extension type does not support zero-initialization");
-    // As the number of values needed for initialization is target-specific and
-    // opaque to the compiler, use a single i64 zero-valued attribute to
-    // represent the 'zeroinitializer', which is the only constant value allowed
-    // for target extension types (besides poison and undef).
-    // TODO: Replace with 'zeroinitializer' once there is a dedicated
-    // zeroinitializer operation in the LLVM dialect.
-    return builder
-        .create<LLVM::ConstantOp>(loc, targetExtType,
-                                  builder.getI64IntegerAttr(0))
-        .getRes();
+    // Zero-initialization is the only constant a target extension type allows
+    // besides poison and undef; model it with an llvm.mlir.zero operation.
+    return builder.create<LLVM::ZeroOp>(loc, targetExtType).getRes();
   }
 
   StringRef error = "";
diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir
index e653ec48d5679bf..daa53228995fcb8 100644
--- a/mlir/test/Dialect/LLVMIR/global.mlir
+++ b/mlir/test/Dialect/LLVMIR/global.mlir
@@ -246,10 +246,16 @@ llvm.mlir.global_dtors { dtors = [@dtor], priorities = [0 : i32]}
 // CHECK: llvm.mlir.global external @target_ext() {addr_space = 0 : i32} : !llvm.target<"spirv.Image", i32, 0>
 llvm.mlir.global @target_ext() : !llvm.target<"spirv.Image", i32, 0>
 
-// CHECK: llvm.mlir.global external @target_ext_init(0 : i64) {addr_space = 0 : i32} : !llvm.target<"spirv.Image", i32, 0>
-llvm.mlir.global @target_ext_init(0 : i64) : !llvm.target<"spirv.Image", i32, 0>
+// CHECK:       llvm.mlir.global external @target_ext_init() {addr_space = 0 : i32} : !llvm.target<"spirv.Image", i32, 0>
+// CHECK-NEXT:    %0 = llvm.mlir.zero : !llvm.target<"spirv.Image", i32, 0>
+// CHECK-NEXT:    llvm.return %0 : !llvm.target<"spirv.Image", i32, 0>
+// CHECK-NEXT:  }
+llvm.mlir.global @target_ext_init() : !llvm.target<"spirv.Image", i32, 0> {
+  %0 = llvm.mlir.zero : !llvm.target<"spirv.Image", i32, 0>
+  llvm.return %0 : !llvm.target<"spirv.Image", i32, 0>
+}
 
 // -----
 
-// expected-error @+1 {{expected zero value for global with target extension type}}
-llvm.mlir.global @target_fail(1 : i64) : !llvm.target<"spirv.Image", i32, 0>
+// expected-error @+1 {{global with target extension type can only be initialized with zero-initializer}}
+llvm.mlir.global @target_fail(0 : i64) : !llvm.target<"spirv.Image", i32, 0>
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index cf4697b17aa468a..02bbb019542632c 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1336,16 +1336,18 @@ func.func @invalid_target_ext_atomic(%arg0 : !llvm.ptr) {
 
 // -----
 
-func.func @invalid_target_ext_constant() {
+func.func @invalid_target_ext_constant_unsupported() {
   // expected-error@+1 {{target extension type does not support zero-initializer}}
-  %0 = llvm.mlir.constant(0 : index) : !llvm.target<"invalid_constant">
+  %0 = llvm.mlir.zero : !llvm.target<"invalid_constant">
+  llvm.return
 }
 
 // -----
 
 func.func @invalid_target_ext_constant() {
-  // expected-error@+1 {{only zero-initializer allowed for target extension types}}
-  %0 = llvm.mlir.constant(42 : index) : !llvm.target<"spirv.Event">
+  // expected-error@+1 {{target extension type does not support constants.}}
+  %0 = llvm.mlir.constant(0 : index) : !llvm.target<"spirv.Event">
+  llvm.return
 }
 
 // -----
diff --git a/mlir/test/Target/LLVMIR/Import/target-ext-type.ll b/mlir/test/Target/LLVMIR/Import/target-ext-type.ll
index 62194cad9152c75..3c575b71038bf0d 100644
--- a/mlir/test/Target/LLVMIR/Import/target-ext-type.ll
+++ b/mlir/test/Target/LLVMIR/Import/target-ext-type.ll
@@ -2,7 +2,7 @@
 
 ; CHECK-LABEL: llvm.mlir.global external @global() {addr_space = 0 : i32}
 ; CHECK-SAME:    !llvm.target<"spirv.DeviceEvent">
-; CHECK-NEXT:      %0 = llvm.mlir.constant(0 : i64) : !llvm.target<"spirv.DeviceEvent">
+; CHECK-NEXT:      %0 = llvm.mlir.zero : !llvm.target<"spirv.DeviceEvent">
 ; CHECK-NEXT:      llvm.return %0 : !llvm.target<"spirv.DeviceEvent">
 @global = global target("spirv.DeviceEvent") zeroinitializer
 
@@ -45,7 +45,7 @@ define target("spirv.Event") @func2() {
 
 ; CHECK-LABEL: llvm.func @func3()
 define void @func3() {
-  ; CHECK-NEXT:    %0 = llvm.mlir.constant(0 : i64) : !llvm.target<"spirv.DeviceEvent">
+  ; CHECK-NEXT:    %0 = llvm.mlir.zero : !llvm.target<"spirv.DeviceEvent">
   ; CHECK-NEXT:    %1 = llvm.freeze %0 : !llvm.target<"spirv.DeviceEvent">
   %val = freeze target("spirv.DeviceEvent") zeroinitializer
   ; CHECK-NEXT:    llvm.return
diff --git a/mlir/test/Target/LLVMIR/target-ext-type.mlir b/mlir/test/Target/LLVMIR/target-ext-type.mlir
index e7004b2699dc6b1..6b2d2ea3d4c2318 100644
--- a/mlir/test/Target/LLVMIR/target-ext-type.mlir
+++ b/mlir/test/Target/LLVMIR/target-ext-type.mlir
@@ -2,7 +2,7 @@
 
 // CHECK: @global = global target("spirv.DeviceEvent") zeroinitializer
 llvm.mlir.global external @global() {addr_space = 0 : i32} : !llvm.target<"spirv.DeviceEvent"> {
-  %0 = llvm.mlir.constant(0 : i64) : !llvm.target<"spirv.DeviceEvent">
+  %0 = llvm.mlir.zero : !llvm.target<"spirv.DeviceEvent">
   llvm.return %0 : !llvm.target<"spirv.DeviceEvent">
 }
 
@@ -22,7 +22,7 @@ llvm.func @func2() -> !llvm.target<"spirv.Event"> {
 // CHECK-NEXT:    %1 = freeze target("spirv.DeviceEvent") zeroinitializer
 // CHECK-NEXT:    ret void
 llvm.func @func3() {
-  %0 = llvm.mlir.constant(0 : i64) : !llvm.target<"spirv.DeviceEvent">
+  %0 = llvm.mlir.zero : !llvm.target<"spirv.DeviceEvent">
   %1 = llvm.freeze %0 : !llvm.target<"spirv.DeviceEvent">
   llvm.return
 }



More information about the Mlir-commits mailing list