[Mlir-commits] [mlir] 4e39335 - [mlir] Add an AccessGroup attribute to load/store LLVM dialect ops and generate the access_group LLVM metadata.

Alex Zinenko llvmlistbot at llvm.org
Thu Mar 4 09:17:30 PST 2021


Author: Arpith C. Jacob
Date: 2021-03-04T18:17:23+01:00
New Revision: 4e393350c547edb8144592168c3b176646747a98

URL: https://github.com/llvm/llvm-project/commit/4e393350c547edb8144592168c3b176646747a98
DIFF: https://github.com/llvm/llvm-project/commit/4e393350c547edb8144592168c3b176646747a98.diff

LOG: [mlir] Add an AccessGroup attribute to load/store LLVM dialect ops and generate the access_group LLVM metadata.

This also includes LLVM dialect ops created from intrinsics.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D97944

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Target/llvmir.mlir
    mlir/test/mlir-tblgen/llvm-intrinsics.td
    mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index f0b4c69b6ae6..8c83dbc0c9d1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -35,6 +35,7 @@ def LLVM_Dialect : Dialect {
     static StringRef getLoopAttrName() { return "llvm.loop"; }
     static StringRef getParallelAccessAttrName() { return "parallel_access"; }
     static StringRef getLoopOptionsAttrName() { return "options"; }
+    static StringRef getAccessGroupsAttrName() { return "access_groups"; }
 
     /// Verifies if the given string is a well-formed data layout descriptor.
     /// Uses `reportError` to report errors.
@@ -247,7 +248,8 @@ def LLVM_IntrPatterns {
 // `llvm::Intrinsic` enum; one usually wants these to be related.
 class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
                       list<int> overloadedResults, list<int> overloadedOperands,
-                      list<OpTrait> traits, int numResults>
+                      list<OpTrait> traits, int numResults,
+                      bit requiresAccessGroup = 0>
     : LLVM_OpBase<dialect, opName, traits>,
       Results<!if(!gt(numResults, 0), (outs LLVM_Type:$res), (outs))> {
   string resultPattern = !if(!gt(numResults, 1),
@@ -264,19 +266,21 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
                          overloadedOperands>.lst), ", ") # [{
         });
     auto operands = moduleTranslation.lookupValues(opInst.getOperands());
-    }] # !if(!gt(numResults, 0), "$res = ", "")
-       # [{builder.CreateCall(fn, operands);
-  }];
+    }] # [{auto *inst = builder.CreateCall(fn, operands);
+    }] # !if(!gt(requiresAccessGroup, 0),
+      "moduleTranslation.setAccessGroupsMetadata(op, inst);",
+      "(void) inst;")
+    # !if(!gt(numResults, 0), "$res = inst;", "");
 }
 
 // Base class for LLVM intrinsic operations, should not be used directly. Places
 // the intrinsic into the LLVM dialect and prefixes its name with "intr.".
 class LLVM_IntrOp<string mnem, list<int> overloadedResults,
                   list<int> overloadedOperands, list<OpTrait> traits,
-                  int numResults>
+                  int numResults, bit requiresAccessGroup = 0>
     : LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem),
                       overloadedResults, overloadedOperands, traits,
-                      numResults>;
+                      numResults, requiresAccessGroup>;
 
 // Base class for LLVM intrinsic operations returning no results. Places the
 // intrinsic into the LLVM dialect and prefixes its name with "intr.".

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 661f83c7c5f4..07583866621e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -287,6 +287,10 @@ class MemoryOpWithAlignmentAndAttributes : MemoryOpWithAlignmentBase {
       inst->setMetadata(module->getMDKindID("nontemporal"), metadata);
     }
   }];
+
+  code setAccessGroupsMetadataCode = [{
+    moduleTranslation.setAccessGroupsMetadata(op, inst);
+  }];
 }
 
 // Memory-related operations.
@@ -326,12 +330,13 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]>,
 
 def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
   let arguments = (ins LLVM_PointerTo<LLVM_LoadableType>:$addr,
+                   OptionalAttr<SymbolRefArrayAttr>:$access_groups,
                    OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
                    UnitAttr:$nontemporal);
   let results = (outs LLVM_Type:$res);
   string llvmBuilder = [{
     auto *inst = builder.CreateLoad($addr, $volatile_);
-  }] # setAlignmentCode # setNonTemporalMetadataCode # [{
+  }] # setAlignmentCode # setNonTemporalMetadataCode # setAccessGroupsMetadataCode # [{
     $res = inst;
   }];
   let builders = [
@@ -346,16 +351,18 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
       CArg<"bool", "false">:$isNonTemporal)>];
   let parser = [{ return parseLoadOp(parser, result); }];
   let printer = [{ printLoadOp(p, *this); }];
+  let verifier = [{ return ::verify(*this);  }];
 }
 
 def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpWithAlignmentAndAttributes {
   let arguments = (ins LLVM_LoadableType:$value,
                    LLVM_PointerTo<LLVM_LoadableType>:$addr,
+                   OptionalAttr<SymbolRefArrayAttr>:$access_groups,
                    OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
                    UnitAttr:$nontemporal);
   string llvmBuilder = [{
     auto *inst = builder.CreateStore($value, $addr, $volatile_);
-  }] # setAlignmentCode # setNonTemporalMetadataCode;
+  }] # setAlignmentCode # setNonTemporalMetadataCode # setAccessGroupsMetadataCode;
   let builders = [
     OpBuilder<(ins "Value":$value, "Value":$addr,
       CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile,
@@ -363,6 +370,7 @@ def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpWithAlignmentAndAttributes {
     ];
   let parser = [{ return parseStoreOp(parser, result); }];
   let printer = [{ printStoreOp(p, *this); }];
+  let verifier = [{ return ::verify(*this);  }];
 }
 
 // Casts.

diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 748268575f86..e046ada3b004 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -128,6 +128,9 @@ class ModuleTranslation {
            "attempting to map loop options that was already mapped");
   }
 
+  /// Attaches llvm.access.group metadata for `op`'s access_groups to `inst`.
+  void setAccessGroupsMetadata(Operation *op, llvm::Instruction *inst);
+
   /// Converts the type from MLIR LLVM dialect to LLVM.
   llvm::Type *convertType(Type type);
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 0538862b56e1..941792dc9c5c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -404,6 +404,34 @@ SwitchOp::getMutableSuccessorOperands(unsigned index) {
 // Builder, printer and parser for for LLVM::LoadOp.
 //===----------------------------------------------------------------------===//
 
+static LogicalResult verifyAccessGroups(Operation *op) {
+  if (Attribute attribute =
+          op->getAttr(LLVMDialect::getAccessGroupsAttrName())) {
+    // ODS already checked this is a symbol ref array. Each entry must name an
+    // access_group op (typed lookup) nested inside an llvm.metadata op.
+    for (SymbolRefAttr accessGroupRef :
+         attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) {
+      StringRef metadataName = accessGroupRef.getRootReference();
+      auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
+          op->getParentOp(), metadataName);
+      if (!metadataOp)
+        return op->emitOpError() << "expected '" << accessGroupRef
+                                 << "' to reference a metadata op";
+      StringRef accessGroupName = accessGroupRef.getLeafReference();
+      auto accessGroupOp = SymbolTable::lookupNearestSymbolFrom<
+          LLVM::AccessGroupMetadataOp>(metadataOp, accessGroupName);
+      if (!accessGroupOp)
+        return op->emitOpError() << "expected '" << accessGroupRef
+                                 << "' to reference an access_group op";
+    }
+  }
+  return success();
+}
+
+static LogicalResult verify(LoadOp loadOp) {
+  return verifyAccessGroups(loadOp);
+}
+
 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
                    Value addr, unsigned alignment, bool isVolatile,
                    bool isNonTemporal) {
@@ -462,6 +490,10 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
 // Builder, printer and parser for LLVM::StoreOp.
 //===----------------------------------------------------------------------===//
 
+static LogicalResult verify(StoreOp storeOp) {
+  return verifyAccessGroups(storeOp);
+}
+
 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Value addr, unsigned alignment, bool isVolatile,
                     bool isNonTemporal) {

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 3a03b278e264..891f30b95b66 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -656,6 +656,27 @@ LogicalResult ModuleTranslation::createAccessGroupMetadata() {
   return success();
 }
 
+void ModuleTranslation::setAccessGroupsMetadata(Operation *op,
+                                                llvm::Instruction *inst) {
+  auto groupRefs =
+      op->getAttrOfType<ArrayAttr>(LLVMDialect::getAccessGroupsAttrName());
+  if (!groupRefs || groupRefs.empty())
+    return;
+
+  // Resolve each referenced access group to its LLVM metadata node.
+  SmallVector<llvm::Metadata *> groupNodes;
+  for (SymbolRefAttr groupRef : groupRefs.getAsRange<SymbolRefAttr>())
+    groupNodes.push_back(getAccessGroup(*op, groupRef));
+
+  // One group is attached directly; several get a new MDNode listing them.
+  llvm::Module *module = inst->getModule();
+  llvm::MDNode *node =
+      groupNodes.size() == 1
+          ? llvm::cast<llvm::MDNode>(groupNodes.front())
+          : llvm::MDNode::get(module->getContext(), groupNodes);
+  inst->setMetadata(module->getMDKindID("llvm.access.group"), node);
+}
+
 llvm::Type *ModuleTranslation::convertType(Type type) {
   return typeTranslator.translateType(type);
 }

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 6a45b1f67e71..e83706b88a9e 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -796,3 +796,39 @@ module {
       llvm.return
   }
 }
+
+// -----
+
+module {
+  llvm.func @accessGroups(%arg0 : !llvm.ptr<i32>) {
+      // expected-error@below {{attribute 'access_groups' failed to satisfy constraint: symbol ref array attribute}}
+      %0 = llvm.load %arg0 { "access_groups" = "test" } : !llvm.ptr<i32>
+      llvm.return
+  }
+}
+
+// -----
+
+module {
+  llvm.func @accessGroups(%arg0 : !llvm.ptr<i32>) {
+      // expected-error@below {{expected '@func1' to reference a metadata op}}
+      %0 = llvm.load %arg0 { "access_groups" = [@func1] } : !llvm.ptr<i32>
+      llvm.return
+  }
+  llvm.func @func1() {
+    llvm.return
+  }
+}
+
+// -----
+
+module {
+  llvm.func @accessGroups(%arg0 : !llvm.ptr<i32>) {
+      // expected-error@below {{expected '@metadata' to reference an access_group op}}
+      %0 = llvm.load %arg0 { "access_groups" = [@metadata] } : !llvm.ptr<i32>
+      llvm.return
+  }
+  llvm.metadata @metadata {
+    llvm.return
+  }
+}

diff  --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir
index 1109345231f2..85d0d16737b4 100644
--- a/mlir/test/Target/llvmir.mlir
+++ b/mlir/test/Target/llvmir.mlir
@@ -1483,6 +1483,7 @@ module {
       llvm.cond_br %2, ^bb4, ^bb5 {llvm.loop = {parallel_access = [@metadata::@group1, @metadata::@group2], options = [#llvm.loopopt<disable_unroll = true>, #llvm.loopopt<disable_licm = true>, #llvm.loopopt<interleave_count = 1>]}}
     ^bb4:
       %3 = llvm.add %1, %arg2  : i32
+      // CHECK: = load i32, i32* %{{.*}} !llvm.access.group ![[ACCESS_GROUPS_NODE:[0-9]+]]
       %5 = llvm.load %4 { access_groups = [@metadata::@group1, @metadata::@group2] } : !llvm.ptr<i32>
       // CHECK: br label {{.*}} !llvm.loop ![[LOOP_NODE]]
       llvm.br ^bb3(%3 : i32) {llvm.loop = {parallel_access = [@metadata::@group1, @metadata::@group2], options = [#llvm.loopopt<disable_unroll = true>, #llvm.loopopt<disable_licm = true>, #llvm.loopopt<interleave_count = 1>]}}
@@ -1504,3 +1505,4 @@ module {
 // CHECK: ![[UNROLL_DISABLE_NODE]] = !{!"llvm.loop.unroll.disable", i1 true}
 // CHECK: ![[LICM_DISABLE_NODE]] = !{!"llvm.licm.disable", i1 true}
 // CHECK: ![[INTERLEAVE_NODE]] = !{!"llvm.loop.interleave.count", i32 1}
+// CHECK: ![[ACCESS_GROUPS_NODE]] = !{![[GROUP_NODE1]], ![[GROUP_NODE2]]}

diff  --git a/mlir/test/mlir-tblgen/llvm-intrinsics.td b/mlir/test/mlir-tblgen/llvm-intrinsics.td
index 511f062e8e4a..a6932b381284 100644
--- a/mlir/test/mlir-tblgen/llvm-intrinsics.td
+++ b/mlir/test/mlir-tblgen/llvm-intrinsics.td
@@ -23,11 +23,33 @@
 // It has no side effects.
 // CHECK: [NoSideEffect]
 // It has a result.
-// CHECK: 1>
+// CHECK: 1,
+// It does not require an access group.
+// CHECK: 0>
 // CHECK: Arguments<(ins LLVM_Type, LLVM_Type
 
 //---------------------------------------------------------------------------//
 
+// This checks that we can define an op that takes in an access group metadata.
+//
+// RUN: cat %S/../../../llvm/include/llvm/IR/Intrinsics.td \
+// RUN: | grep -v "llvm/IR/Intrinsics" \
+// RUN: | mlir-tblgen -gen-llvmir-intrinsics -I %S/../../../llvm/include/ --llvmir-intrinsics-filter=ptrmask --llvmir-intrinsics-access-group-regexp=ptrmask \
+// RUN: | FileCheck --check-prefix=GROUPS %s
+
+// GROUPS-LABEL: def LLVM_ptrmask
+// GROUPS: LLVM_IntrOp<"ptrmask
+// It has no side effects.
+// GROUPS: [NoSideEffect]
+// It has a result.
+// GROUPS: 1,
+// It requires generation of an access group LLVM metadata.
+// GROUPS: 1>
+// It has an access group attribute.
+// GROUPS: OptionalAttr<SymbolRefArrayAttr>:$access_groups
+
+//---------------------------------------------------------------------------//
+
 // This checks that the ODS we produce can be consumed by MLIR tablegen. We only
 // make sure the entire process does not fail and produces some C++. The shape
 // of this C++ code is tested by ODS tests.

diff  --git a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
index 72554a1f0c24..dc76962b4d28 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
@@ -17,6 +17,7 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/MachineValueType.h"
 #include "llvm/Support/PrettyStackTrace.h"
+#include "llvm/Support/Regex.h"
 #include "llvm/Support/Signals.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Main.h"
@@ -37,6 +38,12 @@ static llvm::cl::opt<std::string>
                                "are planning to emit"),
                 llvm::cl::init("LLVM_IntrOp"), llvm::cl::cat(IntrinsicGenCat));
 
+static llvm::cl::opt<std::string> accessGroupRegexp{
+    "llvmir-intrinsics-access-group-regexp",
+    llvm::cl::desc("Mark intrinsics that match the specified "
+                   "regexp as taking an access group metadata"),
+    llvm::cl::cat(IntrinsicGenCat)};
+
 // Used to represent the indices of overloadable operands/results.
 using IndicesTy = llvm::SmallBitVector;
 
@@ -185,6 +192,10 @@ void printBracketedRange(const Range &range, llvm::raw_ostream &os) {
 static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {
   LLVMIntrinsic intr(record);
 
+  llvm::Regex accessGroupMatcher(accessGroupRegexp);
+  bool requiresAccessGroup =
+      !accessGroupRegexp.empty() && accessGroupMatcher.match(record.getName());
+
   // Prepare strings for traits, if any.
   llvm::SmallVector<llvm::StringRef, 2> traits;
   if (intr.isCommutative())
@@ -195,6 +206,8 @@ static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {
   // Prepare strings for operands.
   llvm::SmallVector<llvm::StringRef, 8> operands(intr.getNumOperands(),
                                                  "LLVM_Type");
+  if (requiresAccessGroup)
+    operands.push_back("OptionalAttr<SymbolRefArrayAttr>:$access_groups");
 
   // Emit the definition.
   os << "def LLVM_" << intr.getProperRecordName() << " : " << opBaseClass
@@ -204,7 +217,8 @@ static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {
   printBracketedRange(intr.getOverloadableOperandsIdxs().set_bits(), os);
   os << ", ";
   printBracketedRange(traits, os);
-  os << ", " << intr.getNumResults() << ">, Arguments<(ins"
+  os << ", " << intr.getNumResults() << ", "
+     << (requiresAccessGroup ? "1" : "0") << ">, Arguments<(ins"
      << (operands.empty() ? "" : " ");
   llvm::interleaveComma(operands, os);
   os << ")>;\n\n";


        


More information about the Mlir-commits mailing list