[Mlir-commits] [mlir] beaffb0 - [mlir][transform] Decouple GPUDeviceMapping attribute from the GPU transform dialect code generator

Guray Ozen llvmlistbot at llvm.org
Tue Nov 15 09:16:39 PST 2022


Author: Guray Ozen
Date: 2022-11-15T18:16:32+01:00
New Revision: beaffb041c689deb30d9b06fb3a68a1a4bae48a4

URL: https://github.com/llvm/llvm-project/commit/beaffb041c689deb30d9b06fb3a68a1a4bae48a4
DIFF: https://github.com/llvm/llvm-project/commit/beaffb041c689deb30d9b06fb3a68a1a4bae48a4.diff

LOG: [mlir][transform] Decouple GPUDeviceMapping attribute from the GPU transform dialect code generator

`DeviceMappingAttrInterface` is implemented as a unifying mechanism for thread mapping. A code generator can use any attribute that implements this interface to lower `scf.foreach_thread` to device-specific code, and each attribute is free to choose its own mapping and interpretation.
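
As an illustration, a downstream attribute could opt in by implementing the interface method, e.g. (a minimal C++ sketch; `MyDeviceMappingAttr` and its `getDim()` accessor are illustrative, not part of this commit):

    // Hypothetical attribute implementing DeviceMappingAttrInterface. The
    // returned id is what the code generator uses to pick a hardware dimension.
    int64_t MyDeviceMappingAttr::getMappingId() const {
      return static_cast<int64_t>(getDim());
    }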

Currently, the GPU transform dialect supports only `GPUThreadMapping` and `GPUBlockMapping`; however, other mappings should be supported as well. This change addresses that issue: it decouples the GPU transform dialect from `GPUThreadMapping` and `GPUBlockMapping`, so the transform ops now work with any other mapping.
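
For example, a client could now hand its own set of mapping attributes to the shared implementation (a sketch against the new signature introduced below; `MyDeviceMappingAttr` and `MyDims` are illustrative):

    SmallVector<DeviceMappingAttrInterface> mappingAttributes = {
        MyDeviceMappingAttr::get(ctx, MyDims::DimX),
        MyDeviceMappingAttr::get(ctx, MyDims::DimY),
        MyDeviceMappingAttr::get(ctx, MyDims::DimZ)};
    diag = mlir::transform::gpu::mapNestedForeachToThreadsImpl(
        rewriter, target, blockDim, syncAfterDistribute, transformOp,
        mappingAttributes);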

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D138020

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
    mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
    mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
    mlir/test/Dialect/GPU/transform-gpu-failing.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
index a93353dfafddf..bf67b3b6cd5c8 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
@@ -27,7 +27,8 @@ def ThreadsEnum : I64EnumAttr<"Threads", "threads for loop mapping", [
 }
 
 def GPUThreadMappingAttr 
-    : GPU_Attr<"GPUThreadMapping", "thread", [ DeviceMappingAttrInterface ]> {
+    : GPU_Attr<"GPUThreadMapping", "thread", [ 
+      DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ]> {
   let parameters = (ins
     EnumParameter<ThreadsEnum>:$thread
   );
@@ -47,7 +48,8 @@ def BlocksEnum : I64EnumAttr<"Blocks", "threads for loop mapping", [
   let cppNamespace = "::mlir::gpu";
 }
 
-def GPUBlockMappingAttr : GPU_Attr<"GPUBlockMapping", "block", [ DeviceMappingAttrInterface ] >  {
+def GPUBlockMappingAttr : GPU_Attr<"GPUBlockMapping", "block", [ 
+  DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] >  {
   let parameters = (ins
     EnumParameter<BlocksEnum>:$block
   );

diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
index 94595153700b9..aaa0129f621c8 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
@@ -40,11 +40,11 @@ namespace gpu {
 /// which case, the union of the number of threads is computed and may result in
 /// predication. Dynamic, `scf.foreach_thread` trip counts are currently not
 /// supported. Dynamic block dim sizes are currently not supported.
-DiagnosedSilenceableFailure
-mapNestedForeachToThreadsImpl(RewriterBase &rewriter, Operation *target,
-                              const SmallVectorImpl<int64_t> &blockDim,
-                              bool syncAfterDistribute,
-                              llvm::Optional<TransformOpInterface> transformOp);
+DiagnosedSilenceableFailure mapNestedForeachToThreadsImpl(
+    RewriterBase &rewriter, Operation *target,
+    const SmallVectorImpl<int64_t> &blockDim, bool syncAfterDistribute,
+    llvm::Optional<TransformOpInterface> transformOp,
+    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes);
 
 /// Maps the top level `scf.foreach_thread` op to GPU Thread Blocks. Mapping is
 /// one-to-one and the induction variables of `scf.foreach_thread` are rewritten
@@ -56,7 +56,8 @@ DiagnosedSilenceableFailure mapForeachToBlocksImpl(
     function_ref<void(RewriterBase &, scf::ForeachThreadOp,
                       SmallVectorImpl<Value> &)>
         blockIdGenerator,
-    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp);
+    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp,
+    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes);
 
 /// Finds the top level scf::ForeachThreadOp of given target.
 DiagnosedSilenceableFailure

diff --git a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
index 2d2cafcfe45f4..6cdc8cb7fdb73 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
@@ -34,6 +34,14 @@ def DeviceMappingAttrInterface : AttrInterface<"DeviceMappingAttrInterface"> {
     of the loops it contains to the GPU's parallelism units such as threads and 
     thread blocks.
   }];
+
+ let methods = [
+    InterfaceMethod<[{
+        Returns mapping as an integer from the attribute.
+      }],
+      "int64_t", "getMappingId", (ins)
+    >
+  ];
 }
 
 def DeviceMappingArrayAttr : 

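For reference, mlir-tblgen turns the `InterfaceMethod` above into roughly the following C++ API on the interface class (a sketch of the generated code, not part of the patch):

    class DeviceMappingAttrInterface : public ::mlir::Attribute /* traits elided */ {
    public:
      /// Returns mapping as an integer from the attribute.
      int64_t getMappingId() const;
    };
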
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 0f25e51480737..099c7511376a0 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -33,6 +33,18 @@ using namespace mlir::gpu;
 
 #include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// GPU Device Mapping Attributes
+//===----------------------------------------------------------------------===//
+
+int64_t GPUBlockMappingAttr::getMappingId() const {
+  return static_cast<int64_t>(getBlock());
+}
+
+int64_t GPUThreadMappingAttr::getMappingId() const {
+  return static_cast<int64_t>(getThread());
+}
+
 //===----------------------------------------------------------------------===//
 // MMAMatrixType
 //===----------------------------------------------------------------------===//

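With these implementations, any consumer can query the id generically through the interface (a minimal usage sketch, assuming `attr` holds a gpu.thread or gpu.block mapping attribute; the values follow the enum order):

    auto mapping = attr.cast<DeviceMappingAttrInterface>();
    int64_t id = mapping.getMappingId(); // DimX -> 0, DimY -> 1, DimZ -> 2
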
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 5fb90ac45ce13..bba49da342ac4 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
@@ -33,6 +34,24 @@ class SimpleRewriter : public PatternRewriter {
 };
 } // namespace
 
+/// Check if given mapping attributes are one of the desired attributes
+static DiagnosedSilenceableFailure checkAttributeType(
+    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes,
+    const Optional<ArrayAttr> &foreachMapping,
+    llvm::Optional<TransformOpInterface> transformOp) {
+  if (!foreachMapping.has_value())
+    return transformOp->emitSilenceableError() << "mapping must be present";
+
+  if (llvm::any_of(foreachMapping->getValue(),
+                   [&](DeviceMappingAttrInterface map) {
+                     return llvm::find(threadMappingAttributes, map) ==
+                            threadMappingAttributes.end();
+                   }))
+    return transformOp->emitDefiniteFailure()
+           << "mapping must be one of " << threadMappingAttributes;
+  return DiagnosedSilenceableFailure::success();
+}
+
 /// Determines if the size of the kernel configuration is supported by the GPU
 /// architecture being used. It presently makes use of CUDA limitations, however
 /// that aspect may be enhanced for other GPUs.
@@ -157,15 +176,13 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
     function_ref<void(RewriterBase &, scf::ForeachThreadOp,
                       SmallVectorImpl<Value> &)>
         blockIdGenerator,
-    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp) {
+    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp,
+    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes) {
   // Step 0. Target-specific verifications. There is no good place to anchor
   // those right now: the ForeachThreadOp is target-independent and the
   // transform op does not apply to individual ForeachThreadOp.
-  MLIRContext *ctx = foreachThreadOp->getContext();
   Location loc = foreachThreadOp->getLoc();
-  Attribute bX = GPUBlockMappingAttr::get(ctx, Blocks::DimX);
-  Attribute bY = GPUBlockMappingAttr::get(ctx, Blocks::DimY);
-  Attribute bZ = GPUBlockMappingAttr::get(ctx, Blocks::DimZ);
+
   if (foreachThreadOp.getNumResults() > 0)
     return transformOp.emitSilenceableError()
            << "only bufferized scf.foreach_thread lowers to "
@@ -180,23 +197,15 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
     return transformOp.emitSilenceableError()
            << "unsupported dynamic griddim size";
   }
-  if (!foreachThreadOp.getMapping().has_value())
-    return transformOp.emitSilenceableError() << "mapping must be present";
   SmallVector<Attribute> blockMapping =
       llvm::to_vector(foreachThreadOp.getMapping()->getValue());
-  if (llvm::any_of(blockMapping, [](DeviceMappingAttrInterface map) {
-        return !map.isa<GPUBlockMappingAttr>();
-      })) {
-    return transformOp.emitSilenceableError()
-           << "mapping must be #gpu.block<x/y/z/>";
-  }
 
   // Step 1. Complete the blockMapping to a full mapping (with 1s) if necessary.
   SmallVector<Value> numBlocks =
       llvm::to_vector(foreachThreadOp.getNumThreads());
   // Ensure we have 3 block sizes, one for each id.
   Value one;
-  for (auto attr : {bX, bY, bZ}) {
+  for (auto attr : mappingAttributes) {
     if (std::find(blockMapping.begin(), blockMapping.end(), attr) ==
         blockMapping.end()) {
       blockMapping.push_back(attr);
@@ -205,10 +214,10 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
     }
   }
 
-  // Step 2. sort the values by the corresponding GPUBlockMappingAttr.
-  auto comparator = [](Attribute a, Attribute b) -> bool {
-    return static_cast<int64_t>(a.cast<GPUBlockMappingAttr>().getBlock()) <
-           static_cast<int64_t>(b.cast<GPUBlockMappingAttr>().getBlock());
+  // Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
+  auto comparator = [&](DeviceMappingAttrInterface a,
+                        DeviceMappingAttrInterface b) -> bool {
+    return a.getMappingId() < b.getMappingId();
   };
   SmallVector<Value> gridDimValues = scf::ForeachThreadOp::getValuesSortedByKey(
       blockMapping, numBlocks, comparator);
@@ -222,8 +231,9 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
   BlockAndValueMapping bvm;
   for (auto [blockIdx, blockDim] :
        llvm::zip(foreachThreadOp.getThreadIndices(), blockMapping)) {
-    bvm.map(blockIdx, blockOps[static_cast<int64_t>(
-                          blockDim.cast<GPUBlockMappingAttr>().getBlock())]);
+    bvm.map(blockIdx,
+            blockOps[static_cast<int64_t>(
+                blockDim.cast<DeviceMappingAttrInterface>().getMappingId())]);
   }
 
   // Step 4. Move the body of foreachThreadOp.
@@ -331,9 +341,17 @@ transform::MapForeachToBlocks::applyToOne(Operation *target,
   }
 
   SmallVector<int64_t> gridDim = extractFromI64ArrayAttr(getGridDim());
-  diag = mlir::transform::gpu::mapForeachToBlocksImpl(
-      rewriter, topLevelForeachThreadOp, generateGpuBlockIds, gridDim,
-      transformOp);
+  SmallVector<DeviceMappingAttrInterface> blockMappingAttributes = {
+      GPUBlockMappingAttr::get(getContext(), Blocks::DimX),
+      GPUBlockMappingAttr::get(getContext(), Blocks::DimY),
+      GPUBlockMappingAttr::get(getContext(), Blocks::DimZ)};
+
+  diag = checkAttributeType(blockMappingAttributes,
+                            topLevelForeachThreadOp.getMapping(), transformOp);
+  if (diag.succeeded())
+    diag = mlir::transform::gpu::mapForeachToBlocksImpl(
+        rewriter, topLevelForeachThreadOp, generateGpuBlockIds, gridDim,
+        transformOp, blockMappingAttributes);
   if (diag.succeeded()) {
     diag = alterGpuLaunch(rewriter, gpuLaunch,
                           cast<TransformOpInterface>(getOperation()),
@@ -358,7 +376,8 @@ transform::MapForeachToBlocks::applyToOne(Operation *target,
 static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
     const SmallVectorImpl<int64_t> &globalBlockDims, bool syncAfterDistribute,
-    llvm::Optional<TransformOpInterface> transformOp) {
+    llvm::Optional<TransformOpInterface> transformOp,
+    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
   // Step 0. Target-specific verifications. There is no good place to anchor
   // those right now: the ForeachThreadOp is target-independent and the
   // transform op does not apply to individual ForeachThreadOp.
@@ -369,11 +388,7 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     }
     return emitDefiniteFailure(foreachThreadOp, message);
   };
-  MLIRContext *ctx = foreachThreadOp->getContext();
   Location loc = foreachThreadOp->getLoc();
-  Attribute tX = GPUThreadMappingAttr::get(ctx, Threads::DimX);
-  Attribute tY = GPUThreadMappingAttr::get(ctx, Threads::DimY);
-  Attribute tZ = GPUThreadMappingAttr::get(ctx, Threads::DimZ);
   if (foreachThreadOp.getNumResults() > 0)
     return failureHelper(
         "only bufferized scf.foreach_thread lowers to gpu.thread_id");
@@ -389,12 +404,6 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     return failureHelper("mapping must be present");
   SmallVector<Attribute> threadMapping =
       llvm::to_vector(foreachThreadOp.getMapping()->getValue());
-  if (llvm::any_of(threadMapping, [](DeviceMappingAttrInterface map) {
-        return !map.isa<GPUThreadMappingAttr>();
-      })) {
-    return transformOp->emitSilenceableError()
-           << "mapping must be #gpu.thread<x/y/z/>";
-  }
 
   // Step 1. Complete the threadMapping to a full mapping (with 1s) if
   // necessary.
@@ -402,7 +411,7 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
       llvm::to_vector(foreachThreadOp.getNumThreads());
   // Ensure we have 3 block sizes, one for each id.
   Value one;
-  for (auto attr : {tX, tY, tZ}) {
+  for (auto attr : threadMappingAttributes) {
     if (std::find(threadMapping.begin(), threadMapping.end(), attr) ==
         threadMapping.end()) {
       threadMapping.push_back(attr);
@@ -411,10 +420,10 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     }
   }
 
-  // Step 2. sort the values by the corresponding GPUThreadMappingAttr.
-  auto comparator = [](Attribute a, Attribute b) -> bool {
-    return static_cast<int64_t>(a.cast<GPUThreadMappingAttr>().getThread()) <
-           static_cast<int64_t>(b.cast<GPUThreadMappingAttr>().getThread());
+  // Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
+  auto comparator = [&](DeviceMappingAttrInterface a,
+                        DeviceMappingAttrInterface b) -> bool {
+    return a.getMappingId() < b.getMappingId();
   };
   SmallVector<Value> blockDimValues =
       scf::ForeachThreadOp::getValuesSortedByKey(threadMapping, numThreads,
@@ -434,8 +443,9 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
   BlockAndValueMapping bvm;
   for (auto [blockIdx, blockDim] :
        llvm::zip(foreachThreadOp.getThreadIndices(), threadMapping)) {
-    bvm.map(blockIdx, threadOps[static_cast<int64_t>(
-                          blockDim.cast<GPUThreadMappingAttr>().getThread())]);
+    bvm.map(
+        blockIdx,
+        threadOps[blockDim.cast<DeviceMappingAttrInterface>().getMappingId()]);
   }
 
   // Step 4. Maybe create conditionals to predicate the region.
@@ -501,12 +511,18 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
 DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
     RewriterBase &rewriter, Operation *target,
     const SmallVectorImpl<int64_t> &blockDim, bool syncAfterDistribute,
-    llvm::Optional<TransformOpInterface> transformOp) {
+    llvm::Optional<TransformOpInterface> transformOp,
+    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
   DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
   target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
-    rewriter.setInsertionPoint(foreachThreadOp);
-    diag = rewriteOneForeachThreadToGpuThreads(
-        rewriter, foreachThreadOp, blockDim, syncAfterDistribute, transformOp);
+    diag = checkAttributeType(threadMappingAttributes,
+                              foreachThreadOp.getMapping(), transformOp);
+    if (diag.succeeded()) {
+      rewriter.setInsertionPoint(foreachThreadOp);
+      diag = rewriteOneForeachThreadToGpuThreads(
+          rewriter, foreachThreadOp, blockDim, syncAfterDistribute, transformOp,
+          threadMappingAttributes);
+    }
     return diag.succeeded() ? WalkResult::advance() : WalkResult::interrupt();
   });
   return diag;
@@ -536,11 +552,19 @@ DiagnosedSilenceableFailure transform::MapNestedForeachToThreads::applyToOne(
     return diag;
   }
 
-  SimpleRewriter rewriter(getContext());
+  MLIRContext *ctx = getContext();
+  SimpleRewriter rewriter(ctx);
   rewriter.setInsertionPoint(target);
 
+  SmallVector<DeviceMappingAttrInterface> threadMappingAttributes = {
+      GPUThreadMappingAttr::get(ctx, Threads::DimX),
+      GPUThreadMappingAttr::get(ctx, Threads::DimY),
+      GPUThreadMappingAttr::get(ctx, Threads::DimZ)};
+
   diag = mlir::transform::gpu::mapNestedForeachToThreadsImpl(
-      rewriter, target, blockDim, getSyncAfterDistribute(), transformOp);
+      rewriter, target, blockDim, getSyncAfterDistribute(), transformOp,
+      threadMappingAttributes);
+
   if (diag.succeeded()) {
     diag =
         alterGpuLaunch(rewriter, gpuLaunch, transformOp, llvm::None, llvm::None,

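The net effect at the call sites is a two-step pattern: validate the user-provided mapping against the set of allowed attributes, then run the rewrite. Schematically (a sketch mirroring the hunks above; `checkAttributeType` is file-local to GPUTransformOps.cpp):

    diag = checkAttributeType(threadMappingAttributes,
                              foreachThreadOp.getMapping(), transformOp);
    if (diag.succeeded())
      diag = rewriteOneForeachThreadToGpuThreads(
          rewriter, foreachThreadOp, blockDim, syncAfterDistribute, transformOp,
          threadMappingAttributes);
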
diff --git a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
index c45d9c058dd06..128ce7348b95d 100644
--- a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
+++ b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
@@ -160,7 +160,7 @@ func.func @map_nested_foreach_to_threads_not_buffer(%x: tensor<32x32xf32>, %y: t
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0
-  %foreach, %tiled = transform.structured.tile_to_foreach_thread_op %matmul num_threads [10, 20, 30]
+  %foreach, %tiled = transform.structured.tile_to_foreach_thread_op %matmul num_threads [10, 20, 30] (mapping = [ #gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z> ] )
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0
   // expected-error @below {{only bufferized scf.foreach_thread lowers to gpu.thread_id}}    
   transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [128, 4, 1] }

