[Mlir-commits] [mlir] [MLIR] Create GPU utils library & move distribution utils (PR #119264)

Petr Kurapov llvmlistbot at llvm.org
Fri Dec 13 00:49:57 PST 2024


https://github.com/kurapov-peter updated https://github.com/llvm/llvm-project/pull/119264

>From 83a48cea9dcf8186218b35952bae304138d02640 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Mon, 9 Dec 2024 20:25:50 +0000
Subject: [PATCH 1/9] [MLIR] Create GPU utils library & move distribution utils

---
 .../mlir/Conversion/GPUCommon/GPUCommonPass.h |   2 +-
 .../mlir/Dialect/GPU/Transforms/Passes.h      |   2 +-
 .../Dialect/GPU/Utils/DistributionUtils.h     |  57 +++++++
 .../{Transforms/Utils.h => Utils/GPUUtils.h}  |   0
 mlir/lib/Dialect/GPU/CMakeLists.txt           |   3 +-
 .../GPU/Transforms/AsyncRegionRewriter.cpp    |   2 +-
 .../GPU/Transforms/KernelOutlining.cpp        |   2 +-
 .../GPU/Transforms/SubgroupReduceLowering.cpp |   2 +-
 mlir/lib/Dialect/GPU/Utils/CMakeLists.txt     |  14 ++
 .../Dialect/GPU/Utils/DistributionUtils.cpp   | 149 ++++++++++++++++++
 .../GPU/{Transforms => Utils}/Utils.cpp       |   2 +-
 .../Dialect/Vector/Transforms/CMakeLists.txt  |   1 +
 .../Vector/Transforms/VectorDistribute.cpp    | 139 +---------------
 13 files changed, 230 insertions(+), 145 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
 rename mlir/include/mlir/Dialect/GPU/{Transforms/Utils.h => Utils/GPUUtils.h} (100%)
 create mode 100644 mlir/lib/Dialect/GPU/Utils/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
 rename mlir/lib/Dialect/GPU/{Transforms => Utils}/Utils.cpp (96%)

diff --git a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
index 5f40315a849094..094360e75ab617 100644
--- a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
+++ b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
@@ -8,7 +8,7 @@
 #ifndef MLIR_CONVERSION_GPUCOMMON_GPUCOMMONPASS_H_
 #define MLIR_CONVERSION_GPUCOMMON_GPUCOMMONPASS_H_
 
-#include "mlir/Dialect/GPU/Transforms/Utils.h"
+#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Types.h"
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 8eb711962583da..eb51d477e23f86 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -13,7 +13,7 @@
 #ifndef MLIR_DIALECT_GPU_TRANSFORMS_PASSES_H_
 #define MLIR_DIALECT_GPU_TRANSFORMS_PASSES_H_
 
-#include "Utils.h"
+#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
diff --git a/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h b/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
new file mode 100644
index 00000000000000..6efd2326971982
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
@@ -0,0 +1,57 @@
+//===- DistributionUtils.h - Distribution Utilities -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_GPU_TRANSFORMS_DISTRIBUTIONUTILS_H_
+#define MLIR_DIALECT_GPU_TRANSFORMS_DISTRIBITIONUTILS_H_
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include <utility>
+
+namespace mlir {
+namespace gpu {
+/// Return a value yielded by `warpOp` which statifies the filter lamdba
+/// condition and is not dead.
+OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
+                         const std::function<bool(Operation *)> &fn);
+
+/// Helper to create a new WarpExecuteOnLane0Op with different signature.
+WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes);
+
+/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
+/// `indices` return the index of each new output.
+WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes,
+    llvm::SmallVector<size_t> &indices);
+
+/// Helper to know if an op can be hoisted out of the region.
+bool canBeHoisted(Operation *op, function_ref<bool(Value)> definedOutside);
+
+/// Return a value yielded by `warpOp` which statifies the filter lamdba
+/// condition and is not dead.
+OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
+                         const std::function<bool(Operation *)> &fn);
+
+/// Delinearize the given `laneId` into multiple dimensions, where each
+/// dimension's size is determined by `originalShape` and `distributedShape`
+/// together. This function expects the total numbers of threads needed for
+/// distribution is equal to `warpSize`. Returns true and updates
+/// `delinearizedIds` if so.
+bool delinearizeLaneId(OpBuilder &builder, Location loc,
+                       ArrayRef<int64_t> originalShape,
+                       ArrayRef<int64_t> distributedShape, int64_t warpSize,
+                       Value laneId, SmallVectorImpl<Value> &delinearizedIds);
+
+} // namespace gpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_GPU_TRANSFORMS_DISTRIBUTIONUTILS_H_
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Utils.h b/mlir/include/mlir/Dialect/GPU/Utils/GPUUtils.h
similarity index 100%
rename from mlir/include/mlir/Dialect/GPU/Transforms/Utils.h
rename to mlir/include/mlir/Dialect/GPU/Utils/GPUUtils.h
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index a59645480aba21..1026e9b509332a 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -40,7 +40,6 @@ add_mlir_dialect_library(MLIRGPUTransforms
   Transforms/ShuffleRewriter.cpp
   Transforms/SPIRVAttachTarget.cpp
   Transforms/SubgroupReduceLowering.cpp
-  Transforms/Utils.cpp
   
   OBJECT
 
@@ -59,6 +58,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
   MLIRDataLayoutInterfaces
   MLIRExecutionEngineUtils
   MLIRGPUDialect
+  MLIRGPUUtils
   MLIRIR
   MLIRIndexDialect
   MLIRLLVMDialect
@@ -76,3 +76,4 @@ add_mlir_dialect_library(MLIRGPUTransforms
 
 add_subdirectory(TransformOps)
 add_subdirectory(Pipelines)
+add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
index b2fa3a99c53fc3..41a5e39e55064e 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
@@ -16,7 +16,7 @@
 #include "mlir/Dialect/Async/IR/Async.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/GPU/Transforms/Utils.h"
+#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/PatternMatch.h"
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index ba0c80c50211e3..a6a36848b5635d 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -18,7 +18,7 @@
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/GPU/Transforms/Utils.h"
+#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index 185f824351a230..43eff3eddcc491 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -13,7 +13,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
-#include "mlir/Dialect/GPU/Transforms/Utils.h"
+#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
diff --git a/mlir/lib/Dialect/GPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/GPU/Utils/CMakeLists.txt
new file mode 100644
index 00000000000000..69094c518a159e
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Utils/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIRGPUUtils
+  Utils.cpp
+  DistributionUtils.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU/Utils
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRAffineDialect
+  MLIRGPUDialect
+  MLIRSupport
+  MLIRIR
+  )
diff --git a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
new file mode 100644
index 00000000000000..c6e8e03350bbce
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
@@ -0,0 +1,149 @@
+//===- DistributionUtils.cpp - Distribution tools for GPUOps --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements distribution utility methods.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/Value.h"
+
+#include <numeric>
+
+using namespace mlir;
+using namespace mlir::gpu;
+
+WarpExecuteOnLane0Op mlir::gpu::moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes) {
+  // Create a new op before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(warpOp);
+  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
+      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
+      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
+
+  Region &opBody = warpOp.getBodyRegion();
+  Region &newOpBody = newWarpOp.getBodyRegion();
+  Block &newOpFirstBlock = newOpBody.front();
+  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
+  rewriter.eraseBlock(&newOpFirstBlock);
+  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
+         "expected WarpOp with single block");
+
+  auto yield =
+      cast<gpu::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
+
+  rewriter.modifyOpInPlace(
+      yield, [&]() { yield.getValuesMutable().assign(newYieldedValues); });
+  return newWarpOp;
+}
+
+WarpExecuteOnLane0Op mlir::gpu::moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes,
+    llvm::SmallVector<size_t> &indices) {
+  SmallVector<Type> types(warpOp.getResultTypes().begin(),
+                          warpOp.getResultTypes().end());
+  auto yield = cast<gpu::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
+                                              yield.getOperands().end());
+  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
+    if (yieldValues.insert(std::get<0>(newRet))) {
+      types.push_back(std::get<1>(newRet));
+      indices.push_back(yieldValues.size() - 1);
+    } else {
+      // If the value already exit the region don't create a new output.
+      for (auto [idx, yieldOperand] :
+           llvm::enumerate(yieldValues.getArrayRef())) {
+        if (yieldOperand == std::get<0>(newRet)) {
+          indices.push_back(idx);
+          break;
+        }
+      }
+    }
+  }
+  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
+  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+      rewriter, warpOp, yieldValues.getArrayRef(), types);
+  rewriter.replaceOp(warpOp,
+                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
+  return newWarpOp;
+}
+
+bool mlir::gpu::canBeHoisted(Operation *op,
+                             function_ref<bool(Value)> definedOutside) {
+  return llvm::all_of(op->getOperands(), definedOutside) &&
+         isMemoryEffectFree(op) && op->getNumRegions() == 0;
+}
+
+OpOperand *
+mlir::gpu::getWarpResult(WarpExecuteOnLane0Op warpOp,
+                         const std::function<bool(Operation *)> &fn) {
+  auto yield = cast<gpu::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  for (OpOperand &yieldOperand : yield->getOpOperands()) {
+    Value yieldValues = yieldOperand.get();
+    Operation *definedOp = yieldValues.getDefiningOp();
+    if (definedOp && fn(definedOp)) {
+      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
+        return &yieldOperand;
+    }
+  }
+  return {};
+}
+
+bool mlir::gpu::delinearizeLaneId(OpBuilder &builder, Location loc,
+                                  ArrayRef<int64_t> originalShape,
+                                  ArrayRef<int64_t> distributedShape,
+                                  int64_t warpSize, Value laneId,
+                                  SmallVectorImpl<Value> &delinearizedIds) {
+  // If the original shape and the distributed shape is the same, we don't
+  // distribute at all--every thread is handling the whole. For such case, we
+  // should not rely on lane IDs later. So just return an empty lane ID vector.
+  if (originalShape == distributedShape) {
+    delinearizedIds.clear();
+    return true;
+  }
+
+  SmallVector<int64_t> sizes;
+  for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
+    if (large % small != 0)
+      return false;
+    sizes.push_back(large / small);
+  }
+  if (std::accumulate(sizes.begin(), sizes.end(), 1,
+                      std::multiplies<int64_t>()) != warpSize)
+    return false;
+
+  AffineExpr s0, s1;
+  bindSymbols(builder.getContext(), s0, s1);
+
+  int64_t usedThreads = 1;
+
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  delinearizedIds.assign(sizes.size(), zero);
+
+  for (int i = sizes.size() - 1; i >= 0; --i) {
+    usedThreads *= sizes[i];
+    if (usedThreads == warpSize) {
+      // We've used up all available threads. Don't need to perform modulo
+      // anymore. And we can stop the calculation for further dimensions.
+      delinearizedIds[i] = laneId;
+      break;
+    }
+    delinearizedIds[i] =
+        affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId});
+    laneId = affine::makeComposedAffineApply(
+        builder, loc, s0.floorDiv(usedThreads), {laneId});
+  }
+  return true;
+}
diff --git a/mlir/lib/Dialect/GPU/Transforms/Utils.cpp b/mlir/lib/Dialect/GPU/Utils/Utils.cpp
similarity index 96%
rename from mlir/lib/Dialect/GPU/Transforms/Utils.cpp
rename to mlir/lib/Dialect/GPU/Utils/Utils.cpp
index e91aa18128c7b9..1f09875b3e2732 100644
--- a/mlir/lib/Dialect/GPU/Transforms/Utils.cpp
+++ b/mlir/lib/Dialect/GPU/Utils/Utils.cpp
@@ -10,7 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/GPU/Transforms/Utils.h"
+#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
 #include "llvm/Support/ErrorHandling.h"
 
 namespace mlir::gpu {
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 9a3bd5d4593d63..8ca5cb6c6dfabc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   MLIRArithDialect
   MLIRDialectUtils
   MLIRGPUDialect
+  MLIRGPUUtils
   MLIRIR
   MLIRLinalgDialect
   MLIRMemRefDialect
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 3e142598369951..d080b0b0bd44bd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -18,7 +19,6 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/FormatVariadic.h"
-#include <numeric>
 #include <utility>
 
 using namespace mlir;
@@ -162,92 +162,6 @@ struct DistributedLoadStoreHelper {
 
 } // namespace
 
-/// Helper to create a new WarpExecuteOnLane0Op with different signature.
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
-    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-    ValueRange newYieldedValues, TypeRange newReturnTypes) {
-  // Create a new op before the existing one, with the extra operands.
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(warpOp);
-  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
-      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
-      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
-
-  Region &opBody = warpOp.getBodyRegion();
-  Region &newOpBody = newWarpOp.getBodyRegion();
-  Block &newOpFirstBlock = newOpBody.front();
-  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
-  rewriter.eraseBlock(&newOpFirstBlock);
-  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
-         "expected WarpOp with single block");
-
-  auto yield =
-      cast<gpu::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
-
-  rewriter.modifyOpInPlace(
-      yield, [&]() { yield.getValuesMutable().assign(newYieldedValues); });
-  return newWarpOp;
-}
-
-/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
-/// `indices` return the index of each new output.
-static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
-    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-    ValueRange newYieldedValues, TypeRange newReturnTypes,
-    llvm::SmallVector<size_t> &indices) {
-  SmallVector<Type> types(warpOp.getResultTypes().begin(),
-                          warpOp.getResultTypes().end());
-  auto yield = cast<gpu::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
-                                              yield.getOperands().end());
-  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
-    if (yieldValues.insert(std::get<0>(newRet))) {
-      types.push_back(std::get<1>(newRet));
-      indices.push_back(yieldValues.size() - 1);
-    } else {
-      // If the value already exit the region don't create a new output.
-      for (auto [idx, yieldOperand] :
-           llvm::enumerate(yieldValues.getArrayRef())) {
-        if (yieldOperand == std::get<0>(newRet)) {
-          indices.push_back(idx);
-          break;
-        }
-      }
-    }
-  }
-  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
-  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
-      rewriter, warpOp, yieldValues.getArrayRef(), types);
-  rewriter.replaceOp(warpOp,
-                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
-  return newWarpOp;
-}
-
-/// Helper to know if an op can be hoisted out of the region.
-static bool canBeHoisted(Operation *op,
-                         function_ref<bool(Value)> definedOutside) {
-  return llvm::all_of(op->getOperands(), definedOutside) &&
-         isMemoryEffectFree(op) && op->getNumRegions() == 0;
-}
-
-/// Return a value yielded by `warpOp` which statifies the filter lamdba
-/// condition and is not dead.
-static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
-                                const std::function<bool(Operation *)> &fn) {
-  auto yield = cast<gpu::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-  for (OpOperand &yieldOperand : yield->getOpOperands()) {
-    Value yieldValues = yieldOperand.get();
-    Operation *definedOp = yieldValues.getDefiningOp();
-    if (definedOp && fn(definedOp)) {
-      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
-        return &yieldOperand;
-    }
-  }
-  return {};
-}
-
 // Clones `op` into a new operation that takes `operands` and returns
 // `resultTypes`.
 static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
@@ -770,57 +684,6 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };
 
-/// Delinearize the given `laneId` into multiple dimensions, where each
-/// dimension's size is determined by `originalShape` and `distributedShape`
-/// together. This function expects the total numbers of threads needed for
-/// distribution is equal to `warpSize`. Returns true and updates
-/// `delinearizedIds` if so.
-bool delinearizeLaneId(OpBuilder &builder, Location loc,
-                       ArrayRef<int64_t> originalShape,
-                       ArrayRef<int64_t> distributedShape, int64_t warpSize,
-                       Value laneId, SmallVectorImpl<Value> &delinearizedIds) {
-  // If the original shape and the distributed shape is the same, we don't
-  // distribute at all--every thread is handling the whole. For such case, we
-  // should not rely on lane IDs later. So just return an empty lane ID vector.
-  if (originalShape == distributedShape) {
-    delinearizedIds.clear();
-    return true;
-  }
-
-  SmallVector<int64_t> sizes;
-  for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
-    if (large % small != 0)
-      return false;
-    sizes.push_back(large / small);
-  }
-  if (std::accumulate(sizes.begin(), sizes.end(), 1,
-                      std::multiplies<int64_t>()) != warpSize)
-    return false;
-
-  AffineExpr s0, s1;
-  bindSymbols(builder.getContext(), s0, s1);
-
-  int64_t usedThreads = 1;
-
-  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-  delinearizedIds.assign(sizes.size(), zero);
-
-  for (int i = sizes.size() - 1; i >= 0; --i) {
-    usedThreads *= sizes[i];
-    if (usedThreads == warpSize) {
-      // We've used up all available threads. Don't need to perform modulo
-      // anymore. And we can stop the calculation for further dimensions.
-      delinearizedIds[i] = laneId;
-      break;
-    }
-    delinearizedIds[i] =
-        affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId});
-    laneId = affine::makeComposedAffineApply(
-        builder, loc, s0.floorDiv(usedThreads), {laneId});
-  }
-  return true;
-}
-
 /// Sink out transfer_read op feeding into a warp op yield.
 /// ```
 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {

>From 3c7d9aacbddac8720591c8faa17c322e7a67ead5 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Tue, 10 Dec 2024 10:26:40 +0000
Subject: [PATCH 2/9] fix formatting

---
 mlir/include/mlir/Dialect/GPU/Transforms/Passes.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index eb51d477e23f86..aaef91f31ab9c0 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -13,8 +13,8 @@
 #ifndef MLIR_DIALECT_GPU_TRANSFORMS_PASSES_H_
 #define MLIR_DIALECT_GPU_TRANSFORMS_PASSES_H_
 
-#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include <optional>

>From e5b53de1e2e96c12c8f0d8457264b6c683c650ae Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Thu, 12 Dec 2024 10:32:34 +0000
Subject: [PATCH 3/9] hide helper methods behind a pattern interface

---
 .../Dialect/GPU/Utils/DistributionUtils.h     | 193 +++++++++++++++---
 .../Dialect/GPU/Utils/DistributionUtils.cpp   | 128 ------------
 .../Vector/Transforms/VectorDistribute.cpp    | 158 +++++++-------
 3 files changed, 248 insertions(+), 231 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h b/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
index 6efd2326971982..c97c9592b64fbf 100644
--- a/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
+++ b/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
@@ -9,47 +9,182 @@
 #ifndef MLIR_DIALECT_GPU_TRANSFORMS_DISTRIBUTIONUTILS_H_
 #define MLIR_DIALECT_GPU_TRANSFORMS_DISTRIBITIONUTILS_H_
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
 
+#include <numeric>
 #include <utility>
 
 namespace mlir {
 namespace gpu {
-/// Return a value yielded by `warpOp` which statifies the filter lamdba
-/// condition and is not dead.
-OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
-                         const std::function<bool(Operation *)> &fn);
+/// Move scalar operations with no dependency on the warp op outside of the
+/// region.
+void moveScalarUniformCode(gpu::WarpExecuteOnLane0Op op);
 
-/// Helper to create a new WarpExecuteOnLane0Op with different signature.
-WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
+template <typename T>
+struct WarpDistributionPattern : OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  virtual LogicalResult
+  matchAndRewrite(T op, PatternRewriter &rewriter) const override = 0;
+
+protected:
+  /// Return a value yielded by `warpOp` which satisfies the filter lambda
+  /// condition and is not dead.
+  static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
+                                  const std::function<bool(Operation *)> &fn);
+
+  /// Helper to create a new WarpExecuteOnLane0Op with different signature.
+  static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
+      RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+      ValueRange newYieldedValues, TypeRange newReturnTypes);
+
+  /// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
+  /// `indices` return the index of each new output.
+  static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
+      RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+      ValueRange newYieldedValues, TypeRange newReturnTypes,
+      llvm::SmallVector<size_t> &indices);
+
+  /// Delinearize the given `laneId` into multiple dimensions, where each
+  /// dimension's size is determined by `originalShape` and `distributedShape`
+  /// together. This function expects the total number of threads needed for
+  /// distribution to be equal to `warpSize`. Returns true and updates
+  /// `delinearizedIds` if so.
+  static bool delinearizeLaneId(OpBuilder &builder, Location loc,
+                                ArrayRef<int64_t> originalShape,
+                                ArrayRef<int64_t> distributedShape,
+                                int64_t warpSize, Value laneId,
+                                SmallVectorImpl<Value> &delinearizedIds);
+};
+
+template <typename T>
+WarpExecuteOnLane0Op
+WarpDistributionPattern<T>::moveRegionToNewWarpOpAndReplaceReturns(
     RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-    ValueRange newYieldedValues, TypeRange newReturnTypes);
+    ValueRange newYieldedValues, TypeRange newReturnTypes) {
+  // Create a new op before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(warpOp);
+  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
+      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
+      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
+
+  Region &opBody = warpOp.getBodyRegion();
+  Region &newOpBody = newWarpOp.getBodyRegion();
+  Block &newOpFirstBlock = newOpBody.front();
+  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
+  rewriter.eraseBlock(&newOpFirstBlock);
+  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
+         "expected WarpOp with single block");
+
+  auto yield =
+      cast<gpu::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
+
+  rewriter.modifyOpInPlace(
+      yield, [&]() { yield.getValuesMutable().assign(newYieldedValues); });
+  return newWarpOp;
+}
 
-/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
-/// `indices` return the index of each new output.
-WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
+template <typename T>
+WarpExecuteOnLane0Op
+WarpDistributionPattern<T>::moveRegionToNewWarpOpAndAppendReturns(
     RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
     ValueRange newYieldedValues, TypeRange newReturnTypes,
-    llvm::SmallVector<size_t> &indices);
-
-/// Helper to know if an op can be hoisted out of the region.
-bool canBeHoisted(Operation *op, function_ref<bool(Value)> definedOutside);
-
-/// Return a value yielded by `warpOp` which statifies the filter lamdba
-/// condition and is not dead.
-OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
-                         const std::function<bool(Operation *)> &fn);
-
-/// Delinearize the given `laneId` into multiple dimensions, where each
-/// dimension's size is determined by `originalShape` and `distributedShape`
-/// together. This function expects the total numbers of threads needed for
-/// distribution is equal to `warpSize`. Returns true and updates
-/// `delinearizedIds` if so.
-bool delinearizeLaneId(OpBuilder &builder, Location loc,
-                       ArrayRef<int64_t> originalShape,
-                       ArrayRef<int64_t> distributedShape, int64_t warpSize,
-                       Value laneId, SmallVectorImpl<Value> &delinearizedIds);
+    llvm::SmallVector<size_t> &indices) {
+  SmallVector<Type> types(warpOp.getResultTypes().begin(),
+                          warpOp.getResultTypes().end());
+  auto yield = cast<gpu::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
+                                              yield.getOperands().end());
+  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
+    if (yieldValues.insert(std::get<0>(newRet))) {
+      types.push_back(std::get<1>(newRet));
+      indices.push_back(yieldValues.size() - 1);
+    } else {
+      // If the value already exits the region, don't create a new output.
+      for (auto [idx, yieldOperand] :
+           llvm::enumerate(yieldValues.getArrayRef())) {
+        if (yieldOperand == std::get<0>(newRet)) {
+          indices.push_back(idx);
+          break;
+        }
+      }
+    }
+  }
+  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
+  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+      rewriter, warpOp, yieldValues.getArrayRef(), types);
+  rewriter.replaceOp(warpOp,
+                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
+  return newWarpOp;
+}
+
+template <typename T>
+OpOperand *WarpDistributionPattern<T>::getWarpResult(
+    WarpExecuteOnLane0Op warpOp, const std::function<bool(Operation *)> &fn) {
+  auto yield = cast<gpu::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  for (OpOperand &yieldOperand : yield->getOpOperands()) {
+    Value yieldValues = yieldOperand.get();
+    Operation *definedOp = yieldValues.getDefiningOp();
+    if (definedOp && fn(definedOp)) {
+      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
+        return &yieldOperand;
+    }
+  }
+  return {};
+}
+
+template <typename T>
+bool WarpDistributionPattern<T>::delinearizeLaneId(
+    OpBuilder &builder, Location loc, ArrayRef<int64_t> originalShape,
+    ArrayRef<int64_t> distributedShape, int64_t warpSize, Value laneId,
+    SmallVectorImpl<Value> &delinearizedIds) {
+  // If the original shape and the distributed shape are the same, we don't
+  // distribute at all--every thread is handling the whole. In that case, we
+  // should not rely on lane IDs later, so just return an empty lane ID vector.
+  if (originalShape == distributedShape) {
+    delinearizedIds.clear();
+    return true;
+  }
+
+  SmallVector<int64_t> sizes;
+  for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
+    if (large % small != 0)
+      return false;
+    sizes.push_back(large / small);
+  }
+  if (std::accumulate(sizes.begin(), sizes.end(), 1,
+                      std::multiplies<int64_t>()) != warpSize)
+    return false;
+
+  AffineExpr s0, s1;
+  bindSymbols(builder.getContext(), s0, s1);
+
+  int64_t usedThreads = 1;
+
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  delinearizedIds.assign(sizes.size(), zero);
+
+  for (int i = sizes.size() - 1; i >= 0; --i) {
+    usedThreads *= sizes[i];
+    if (usedThreads == warpSize) {
+      // We've used up all available threads. Don't need to perform modulo
+      // anymore. And we can stop the calculation for further dimensions.
+      delinearizedIds[i] = laneId;
+      break;
+    }
+    delinearizedIds[i] =
+        affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId});
+    laneId = affine::makeComposedAffineApply(
+        builder, loc, s0.floorDiv(usedThreads), {laneId});
+  }
+  return true;
+}
 
 } // namespace gpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
index c6e8e03350bbce..664918cc3c9131 100644
--- a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
+++ b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
@@ -19,131 +19,3 @@
 
 using namespace mlir;
 using namespace mlir::gpu;
-
-WarpExecuteOnLane0Op mlir::gpu::moveRegionToNewWarpOpAndReplaceReturns(
-    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-    ValueRange newYieldedValues, TypeRange newReturnTypes) {
-  // Create a new op before the existing one, with the extra operands.
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(warpOp);
-  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
-      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
-      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
-
-  Region &opBody = warpOp.getBodyRegion();
-  Region &newOpBody = newWarpOp.getBodyRegion();
-  Block &newOpFirstBlock = newOpBody.front();
-  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
-  rewriter.eraseBlock(&newOpFirstBlock);
-  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
-         "expected WarpOp with single block");
-
-  auto yield =
-      cast<gpu::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
-
-  rewriter.modifyOpInPlace(
-      yield, [&]() { yield.getValuesMutable().assign(newYieldedValues); });
-  return newWarpOp;
-}
-
-WarpExecuteOnLane0Op mlir::gpu::moveRegionToNewWarpOpAndAppendReturns(
-    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-    ValueRange newYieldedValues, TypeRange newReturnTypes,
-    llvm::SmallVector<size_t> &indices) {
-  SmallVector<Type> types(warpOp.getResultTypes().begin(),
-                          warpOp.getResultTypes().end());
-  auto yield = cast<gpu::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
-                                              yield.getOperands().end());
-  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
-    if (yieldValues.insert(std::get<0>(newRet))) {
-      types.push_back(std::get<1>(newRet));
-      indices.push_back(yieldValues.size() - 1);
-    } else {
-      // If the value already exit the region don't create a new output.
-      for (auto [idx, yieldOperand] :
-           llvm::enumerate(yieldValues.getArrayRef())) {
-        if (yieldOperand == std::get<0>(newRet)) {
-          indices.push_back(idx);
-          break;
-        }
-      }
-    }
-  }
-  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
-  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
-      rewriter, warpOp, yieldValues.getArrayRef(), types);
-  rewriter.replaceOp(warpOp,
-                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
-  return newWarpOp;
-}
-
-bool mlir::gpu::canBeHoisted(Operation *op,
-                             function_ref<bool(Value)> definedOutside) {
-  return llvm::all_of(op->getOperands(), definedOutside) &&
-         isMemoryEffectFree(op) && op->getNumRegions() == 0;
-}
-
-OpOperand *
-mlir::gpu::getWarpResult(WarpExecuteOnLane0Op warpOp,
-                         const std::function<bool(Operation *)> &fn) {
-  auto yield = cast<gpu::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-  for (OpOperand &yieldOperand : yield->getOpOperands()) {
-    Value yieldValues = yieldOperand.get();
-    Operation *definedOp = yieldValues.getDefiningOp();
-    if (definedOp && fn(definedOp)) {
-      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
-        return &yieldOperand;
-    }
-  }
-  return {};
-}
-
-bool mlir::gpu::delinearizeLaneId(OpBuilder &builder, Location loc,
-                                  ArrayRef<int64_t> originalShape,
-                                  ArrayRef<int64_t> distributedShape,
-                                  int64_t warpSize, Value laneId,
-                                  SmallVectorImpl<Value> &delinearizedIds) {
-  // If the original shape and the distributed shape is the same, we don't
-  // distribute at all--every thread is handling the whole. For such case, we
-  // should not rely on lane IDs later. So just return an empty lane ID vector.
-  if (originalShape == distributedShape) {
-    delinearizedIds.clear();
-    return true;
-  }
-
-  SmallVector<int64_t> sizes;
-  for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
-    if (large % small != 0)
-      return false;
-    sizes.push_back(large / small);
-  }
-  if (std::accumulate(sizes.begin(), sizes.end(), 1,
-                      std::multiplies<int64_t>()) != warpSize)
-    return false;
-
-  AffineExpr s0, s1;
-  bindSymbols(builder.getContext(), s0, s1);
-
-  int64_t usedThreads = 1;
-
-  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-  delinearizedIds.assign(sizes.size(), zero);
-
-  for (int i = sizes.size() - 1; i >= 0; --i) {
-    usedThreads *= sizes[i];
-    if (usedThreads == warpSize) {
-      // We've used up all available threads. Don't need to perform modulo
-      // anymore. And we can stop the calculation for further dimensions.
-      delinearizedIds[i] = laneId;
-      break;
-    }
-    delinearizedIds[i] =
-        affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId});
-    laneId = affine::makeComposedAffineApply(
-        builder, loc, s0.floorDiv(usedThreads), {laneId});
-  }
-  return true;
-}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index d080b0b0bd44bd..f30574315ee901 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -203,11 +203,12 @@ namespace {
 ///
 /// All this assumes the vector distribution occurs along the most minor
 /// distributed vector dimension.
-struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
+struct WarpOpToScfIfPattern
+    : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
   WarpOpToScfIfPattern(MLIRContext *context,
                        const WarpExecuteOnLane0LoweringOptions &options,
                        PatternBenefit benefit = 1)
-      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
+      : WarpDistributionPattern<WarpExecuteOnLane0Op>(context, benefit),
         options(options) {}
 
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
@@ -316,39 +317,6 @@ struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
   const WarpExecuteOnLane0LoweringOptions &options;
 };
 
-/// Clone `writeOp` assumed to be nested under `warpOp` into a new warp execute
-/// op with the proper return type.
-/// The new write op is updated to write the result of the new warp execute op.
-/// The old `writeOp` is deleted.
-static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
-                                            WarpExecuteOnLane0Op warpOp,
-                                            vector::TransferWriteOp writeOp,
-                                            VectorType targetType,
-                                            VectorType maybeMaskType) {
-  assert(writeOp->getParentOp() == warpOp &&
-         "write must be nested immediately under warp");
-  OpBuilder::InsertionGuard g(rewriter);
-  SmallVector<size_t> newRetIndices;
-  WarpExecuteOnLane0Op newWarpOp;
-  if (maybeMaskType) {
-    newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
-        TypeRange{targetType, maybeMaskType}, newRetIndices);
-  } else {
-    newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, ValueRange{{writeOp.getVector()}},
-        TypeRange{targetType}, newRetIndices);
-  }
-  rewriter.setInsertionPointAfter(newWarpOp);
-  auto newWriteOp =
-      cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
-  rewriter.eraseOp(writeOp);
-  newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
-  if (maybeMaskType)
-    newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
-  return newWriteOp;
-}
-
 /// Return the distributed vector type based on the original type and the
 /// distribution map. The map is expected to have a dimension equal to the
 /// original type rank and should be a projection where the results are the
@@ -401,10 +369,11 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
 ///   gpu.yield %v : vector<32xf32>
 /// }
 /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
-struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
+struct WarpOpTransferWrite
+    : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
   WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                       unsigned maxNumElementsToExtract, PatternBenefit b = 1)
-      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
+      : WarpDistributionPattern<WarpExecuteOnLane0Op>(ctx, b),
         distributionMapFn(std::move(fn)),
         maxNumElementsToExtract(maxNumElementsToExtract) {}
 
@@ -568,6 +537,38 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 
 private:
+  /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp
+  /// execute op with the proper return type. The new write op is updated to
+  /// write the result of the new warp execute op. The old `writeOp` is deleted.
+  static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
+                                              WarpExecuteOnLane0Op warpOp,
+                                              vector::TransferWriteOp writeOp,
+                                              VectorType targetType,
+                                              VectorType maybeMaskType) {
+    assert(writeOp->getParentOp() == warpOp &&
+           "write must be nested immediately under warp");
+    OpBuilder::InsertionGuard g(rewriter);
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp;
+    if (maybeMaskType) {
+      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
+          TypeRange{targetType, maybeMaskType}, newRetIndices);
+    } else {
+      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, ValueRange{{writeOp.getVector()}},
+          TypeRange{targetType}, newRetIndices);
+    }
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto newWriteOp =
+        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
+    rewriter.eraseOp(writeOp);
+    newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
+    if (maybeMaskType)
+      newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
+    return newWriteOp;
+  }
+
   DistributionMapFn distributionMapFn;
   unsigned maxNumElementsToExtract = 1;
 };
@@ -590,8 +591,9 @@ struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
 ///   vector<32xf32>
 /// }
 /// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
-struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
-  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+struct WarpOpElementwise
+    : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+  using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
@@ -656,8 +658,8 @@ struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
 ///   ...
 /// }
 /// %0 = arith.constant dense<2.0> : vector<1xf32>
-struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
-  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+struct WarpOpConstant : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+  using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *yieldOperand =
@@ -702,8 +704,9 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
 ///   vector<32xf32> gpu.yield %2 : vector<32xf32>
 /// }
 /// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
-struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
-  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+struct WarpOpTransferRead
+    : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+  using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     // Try to find a distributable yielded read. Note that this pattern can
@@ -814,8 +817,8 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
 /// Remove any result that has no use along with the matching yieldOp operand.
 // TODO: Move this in WarpExecuteOnLane0Op canonicalization.
-struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
-  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+struct WarpOpDeadResult : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+  using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     SmallVector<Type> newResultTypes;
@@ -875,8 +878,9 @@ struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
 // If an operand is directly yielded out of the region we can forward it
 // directly and it doesn't need to go through the region.
-struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
-  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+struct WarpOpForwardOperand
+    : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+  using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     SmallVector<Type> resultTypes;
@@ -919,8 +923,8 @@ struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };
 
-struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
-  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+struct WarpOpBroadcast : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+  using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
@@ -956,8 +960,8 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
 /// Pattern to move shape cast out of the warp op. shape cast is basically a
 /// no-op for warp distribution; we need to handle the shape though.
-struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
-  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+struct WarpOpShapeCast : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+  using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
@@ -1015,8 +1019,8 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
 /// %cmp = arith.cmpi ult, %laneid, %0
 /// %ub = arith.select %cmp, %c0, %c1
 /// %1 = vector.create_mask %ub : vector<1xi1>
-struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
-  using OpRewritePattern::OpRewritePattern;
+struct WarpOpCreateMask : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+  using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *yieldOperand =
@@ -1081,8 +1085,8 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
 /// Pattern to move out vector.extract of single element vector. Those don't
 /// need to be distributed and can just be propagated outside of the region.
-struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
-  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+struct WarpOpExtract : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+  using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
@@ -1161,10 +1165,11 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
 /// Pattern to move out vector.extract with a scalar result.
 /// Only supports 1-D and 0-D sources for now.
-struct WarpOpExtractScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
+struct WarpOpExtractScalar
+    : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
   WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
                       PatternBenefit b = 1)
-      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
+      : WarpDistributionPattern<WarpExecuteOnLane0Op>(ctx, b),
         warpShuffleFromIdxFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
@@ -1260,9 +1265,9 @@ struct WarpOpExtractScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
 };
 
 /// Pattern to convert vector.extractelement to vector.extract.
-struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
-  WarpOpExtractElement(MLIRContext *ctx, PatternBenefit b = 1)
-      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b) {}
+struct WarpOpExtractElement
+    : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+  using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
@@ -1283,9 +1288,9 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
 /// Pattern to move out vector.insert with a scalar input.
 /// Only supports 1-D and 0-D destinations for now.
-struct WarpOpInsertScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
-  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
-
+struct WarpOpInsertScalar
+    : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+  using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
@@ -1376,9 +1381,8 @@ struct WarpOpInsertScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };
 
-struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
-  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
-
+struct WarpOpInsert : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+  using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
@@ -1490,9 +1494,9 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };
 
-struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
-  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
-
+struct WarpOpInsertElement
+    : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+  using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
@@ -1543,12 +1547,11 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
 ///     scf.yield %iw : vector<4xf32>
 ///  }
 /// ```
-struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
+struct WarpOpScfForOp : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
 
   WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
-      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
+      : WarpDistributionPattern<WarpExecuteOnLane0Op>(ctx, b),
         distributionMapFn(std::move(fn)) {}
-  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     auto yield = cast<gpu::YieldOp>(
@@ -1687,11 +1690,11 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
 /// %a = vector.extract %0[0] : f32 from vector<1xf32>
 /// %r = ("warp.reduction %a")
 /// ```
-struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
+struct WarpOpReduction : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
   WarpOpReduction(MLIRContext *context,
                   DistributedReductionFn distributedReductionFn,
                   PatternBenefit benefit = 1)
-      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
+      : WarpDistributionPattern<WarpExecuteOnLane0Op>(context, benefit),
         distributedReductionFn(std::move(distributedReductionFn)) {}
 
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
@@ -1790,6 +1793,13 @@ void mlir::vector::populateDistributeReduction(
                                 benefit);
 }
 
+/// Helper to know if an op can be hoisted out of the region.
+static bool canBeHoisted(Operation *op,
+                         function_ref<bool(Value)> definedOutside) {
+  return llvm::all_of(op->getOperands(), definedOutside) &&
+         isMemoryEffectFree(op) && op->getNumRegions() == 0;
+}
+
 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
   Block *body = warpOp.getBody();
 

>From ec5015fdf68f225645b122ede5b3758d4451fa83 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Thu, 12 Dec 2024 10:40:22 +0000
Subject: [PATCH 4/9] get rid of template parameter

---
 .../Dialect/GPU/Utils/DistributionUtils.h     | 130 +-----------------
 .../Dialect/GPU/Utils/DistributionUtils.cpp   | 122 ++++++++++++++++
 .../Vector/Transforms/VectorDistribute.cpp    |  59 ++++----
 3 files changed, 147 insertions(+), 164 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h b/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
index c97c9592b64fbf..c62f01f6435dda 100644
--- a/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
+++ b/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
@@ -24,11 +24,11 @@ namespace gpu {
 /// region.
 void moveScalarUniformCode(gpu::WarpExecuteOnLane0Op op);
 
-template <typename T>
 struct WarpDistributionPattern : OpRewritePattern<WarpExecuteOnLane0Op> {
   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
   virtual LogicalResult
-  matchAndRewrite(T op, PatternRewriter &rewriter) const override = 0;
+  matchAndRewrite(WarpExecuteOnLane0Op op,
+                  PatternRewriter &rewriter) const override = 0;
 
 protected:
   /// Return a value yielded by `warpOp` which statifies the filter lamdba
@@ -60,132 +60,6 @@ struct WarpDistributionPattern : OpRewritePattern<WarpExecuteOnLane0Op> {
                                 SmallVectorImpl<Value> &delinearizedIds);
 };
 
-template <typename T>
-WarpExecuteOnLane0Op
-WarpDistributionPattern<T>::moveRegionToNewWarpOpAndReplaceReturns(
-    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-    ValueRange newYieldedValues, TypeRange newReturnTypes) {
-  // Create a new op before the existing one, with the extra operands.
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(warpOp);
-  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
-      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
-      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
-
-  Region &opBody = warpOp.getBodyRegion();
-  Region &newOpBody = newWarpOp.getBodyRegion();
-  Block &newOpFirstBlock = newOpBody.front();
-  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
-  rewriter.eraseBlock(&newOpFirstBlock);
-  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
-         "expected WarpOp with single block");
-
-  auto yield =
-      cast<gpu::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
-
-  rewriter.modifyOpInPlace(
-      yield, [&]() { yield.getValuesMutable().assign(newYieldedValues); });
-  return newWarpOp;
-}
-
-template <typename T>
-WarpExecuteOnLane0Op
-WarpDistributionPattern<T>::moveRegionToNewWarpOpAndAppendReturns(
-    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-    ValueRange newYieldedValues, TypeRange newReturnTypes,
-    llvm::SmallVector<size_t> &indices) {
-  SmallVector<Type> types(warpOp.getResultTypes().begin(),
-                          warpOp.getResultTypes().end());
-  auto yield = cast<gpu::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
-                                              yield.getOperands().end());
-  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
-    if (yieldValues.insert(std::get<0>(newRet))) {
-      types.push_back(std::get<1>(newRet));
-      indices.push_back(yieldValues.size() - 1);
-    } else {
-      // If the value already exit the region don't create a new output.
-      for (auto [idx, yieldOperand] :
-           llvm::enumerate(yieldValues.getArrayRef())) {
-        if (yieldOperand == std::get<0>(newRet)) {
-          indices.push_back(idx);
-          break;
-        }
-      }
-    }
-  }
-  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
-  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
-      rewriter, warpOp, yieldValues.getArrayRef(), types);
-  rewriter.replaceOp(warpOp,
-                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
-  return newWarpOp;
-}
-
-template <typename T>
-OpOperand *WarpDistributionPattern<T>::getWarpResult(
-    WarpExecuteOnLane0Op warpOp, const std::function<bool(Operation *)> &fn) {
-  auto yield = cast<gpu::YieldOp>(
-      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
-  for (OpOperand &yieldOperand : yield->getOpOperands()) {
-    Value yieldValues = yieldOperand.get();
-    Operation *definedOp = yieldValues.getDefiningOp();
-    if (definedOp && fn(definedOp)) {
-      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
-        return &yieldOperand;
-    }
-  }
-  return {};
-}
-
-template <typename T>
-bool WarpDistributionPattern<T>::delinearizeLaneId(
-    OpBuilder &builder, Location loc, ArrayRef<int64_t> originalShape,
-    ArrayRef<int64_t> distributedShape, int64_t warpSize, Value laneId,
-    SmallVectorImpl<Value> &delinearizedIds) {
-  // If the original shape and the distributed shape is the same, we don't
-  // distribute at all--every thread is handling the whole. For such case, we
-  // should not rely on lane IDs later. So just return an empty lane ID vector.
-  if (originalShape == distributedShape) {
-    delinearizedIds.clear();
-    return true;
-  }
-
-  SmallVector<int64_t> sizes;
-  for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
-    if (large % small != 0)
-      return false;
-    sizes.push_back(large / small);
-  }
-  if (std::accumulate(sizes.begin(), sizes.end(), 1,
-                      std::multiplies<int64_t>()) != warpSize)
-    return false;
-
-  AffineExpr s0, s1;
-  bindSymbols(builder.getContext(), s0, s1);
-
-  int64_t usedThreads = 1;
-
-  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-  delinearizedIds.assign(sizes.size(), zero);
-
-  for (int i = sizes.size() - 1; i >= 0; --i) {
-    usedThreads *= sizes[i];
-    if (usedThreads == warpSize) {
-      // We've used up all available threads. Don't need to perform modulo
-      // anymore. And we can stop the calculation for further dimensions.
-      delinearizedIds[i] = laneId;
-      break;
-    }
-    delinearizedIds[i] =
-        affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId});
-    laneId = affine::makeComposedAffineApply(
-        builder, loc, s0.floorDiv(usedThreads), {laneId});
-  }
-  return true;
-}
-
 } // namespace gpu
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
index 664918cc3c9131..857c1e84f303f8 100644
--- a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
+++ b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
@@ -19,3 +19,125 @@
 
 using namespace mlir;
 using namespace mlir::gpu;
+
+WarpExecuteOnLane0Op
+WarpDistributionPattern::moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes) {
+  // Create a new op before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(warpOp);
+  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
+      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
+      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
+
+  Region &opBody = warpOp.getBodyRegion();
+  Region &newOpBody = newWarpOp.getBodyRegion();
+  Block &newOpFirstBlock = newOpBody.front();
+  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
+  rewriter.eraseBlock(&newOpFirstBlock);
+  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
+         "expected WarpOp with single block");
+
+  auto yield =
+      cast<gpu::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
+
+  rewriter.modifyOpInPlace(
+      yield, [&]() { yield.getValuesMutable().assign(newYieldedValues); });
+  return newWarpOp;
+}
+
+WarpExecuteOnLane0Op
+WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes,
+    llvm::SmallVector<size_t> &indices) {
+  SmallVector<Type> types(warpOp.getResultTypes().begin(),
+                          warpOp.getResultTypes().end());
+  auto yield = cast<gpu::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
+                                              yield.getOperands().end());
+  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
+    if (yieldValues.insert(std::get<0>(newRet))) {
+      types.push_back(std::get<1>(newRet));
+      indices.push_back(yieldValues.size() - 1);
+    } else {
+      // If the value already exit the region don't create a new output.
+      for (auto [idx, yieldOperand] :
+           llvm::enumerate(yieldValues.getArrayRef())) {
+        if (yieldOperand == std::get<0>(newRet)) {
+          indices.push_back(idx);
+          break;
+        }
+      }
+    }
+  }
+  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
+  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+      rewriter, warpOp, yieldValues.getArrayRef(), types);
+  rewriter.replaceOp(warpOp,
+                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
+  return newWarpOp;
+}
+
+OpOperand *WarpDistributionPattern::getWarpResult(
+    WarpExecuteOnLane0Op warpOp, const std::function<bool(Operation *)> &fn) {
+  auto yield = cast<gpu::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  for (OpOperand &yieldOperand : yield->getOpOperands()) {
+    Value yieldValues = yieldOperand.get();
+    Operation *definedOp = yieldValues.getDefiningOp();
+    if (definedOp && fn(definedOp)) {
+      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
+        return &yieldOperand;
+    }
+  }
+  return {};
+}
+
+bool WarpDistributionPattern::delinearizeLaneId(
+    OpBuilder &builder, Location loc, ArrayRef<int64_t> originalShape,
+    ArrayRef<int64_t> distributedShape, int64_t warpSize, Value laneId,
+    SmallVectorImpl<Value> &delinearizedIds) {
+  // If the original shape and the distributed shape is the same, we don't
+  // distribute at all--every thread is handling the whole. For such case, we
+  // should not rely on lane IDs later. So just return an empty lane ID vector.
+  if (originalShape == distributedShape) {
+    delinearizedIds.clear();
+    return true;
+  }
+
+  SmallVector<int64_t> sizes;
+  for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
+    if (large % small != 0)
+      return false;
+    sizes.push_back(large / small);
+  }
+  // Seed with an int64_t so the product is not accumulated in (narrower) int.
+  if (std::accumulate(sizes.begin(), sizes.end(), int64_t{1},
+                      std::multiplies<int64_t>()) != warpSize)
+    return false;
+
+  AffineExpr s0;
+  bindSymbols(builder.getContext(), s0);
+  int64_t usedThreads = 1;
+
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  delinearizedIds.assign(sizes.size(), zero);
+  for (int i = sizes.size() - 1; i >= 0; --i) {
+    usedThreads *= sizes[i];
+    if (usedThreads == warpSize) {
+      // We've used up all available threads. Don't need to perform modulo
+      // anymore. And we can stop the calculation for further dimensions.
+      delinearizedIds[i] = laneId;
+      break;
+    }
+    delinearizedIds[i] =
+        affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId});
+    // `laneId` already excludes inner dims, so strip only this dimension.
+    laneId = affine::makeComposedAffineApply(
+        builder, loc, s0.floorDiv(sizes[i]), {laneId});
+  }
+  return true;
+}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index f30574315ee901..068a685c29c65c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -203,13 +203,11 @@ namespace {
 ///
 /// All this assumes the vector distribution occurs along the most minor
 /// distributed vector dimension.
-struct WarpOpToScfIfPattern
-    : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+struct WarpOpToScfIfPattern : public WarpDistributionPattern {
   WarpOpToScfIfPattern(MLIRContext *context,
                        const WarpExecuteOnLane0LoweringOptions &options,
                        PatternBenefit benefit = 1)
-      : WarpDistributionPattern<WarpExecuteOnLane0Op>(context, benefit),
-        options(options) {}
+      : WarpDistributionPattern(context, benefit), options(options) {}
 
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
@@ -369,12 +367,10 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
 ///   gpu.yield %v : vector<32xf32>
 /// }
 /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
-struct WarpOpTransferWrite
-    : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+struct WarpOpTransferWrite : public WarpDistributionPattern {
   WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                       unsigned maxNumElementsToExtract, PatternBenefit b = 1)
-      : WarpDistributionPattern<WarpExecuteOnLane0Op>(ctx, b),
-        distributionMapFn(std::move(fn)),
+      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
         maxNumElementsToExtract(maxNumElementsToExtract) {}
 
   /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
@@ -591,8 +587,7 @@ struct WarpOpTransferWrite
 ///   vector<32xf32>
 /// }
 /// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
-struct WarpOpElementwise
-    : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+struct WarpOpElementwise : public WarpDistributionPattern {
   using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
@@ -658,7 +653,7 @@ struct WarpOpElementwise
 ///   ...
 /// }
 /// %0 = arith.constant dense<2.0> : vector<1xf32>
-struct WarpOpConstant : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+struct WarpOpConstant : public WarpDistributionPattern {
   using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
@@ -704,8 +699,7 @@ struct WarpOpConstant : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
 ///   vector<32xf32> gpu.yield %2 : vector<32xf32>
 /// }
 /// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
-struct WarpOpTransferRead
-    : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+struct WarpOpTransferRead : public WarpDistributionPattern {
   using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
@@ -817,7 +811,7 @@ struct WarpOpTransferRead
 
 /// Remove any result that has no use along with the matching yieldOp operand.
 // TODO: Move this in WarpExecuteOnLane0Op canonicalization.
-struct WarpOpDeadResult : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+struct WarpOpDeadResult : public WarpDistributionPattern {
   using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
@@ -878,8 +872,7 @@ struct WarpOpDeadResult : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
 
 // If an operand is directly yielded out of the region we can forward it
 // directly and it doesn't need to go through the region.
-struct WarpOpForwardOperand
-    : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+struct WarpOpForwardOperand : public WarpDistributionPattern {
   using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
@@ -923,7 +916,7 @@ struct WarpOpForwardOperand
   }
 };
 
-struct WarpOpBroadcast : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+struct WarpOpBroadcast : public WarpDistributionPattern {
   using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
@@ -960,7 +953,7 @@ struct WarpOpBroadcast : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
 
 /// Pattern to move shape cast out of the warp op. shape cast is basically a
 /// no-op for warp distribution; we need to handle the shape though.
-struct WarpOpShapeCast : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+struct WarpOpShapeCast : public WarpDistributionPattern {
   using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
@@ -1019,7 +1012,7 @@ struct WarpOpShapeCast : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
 /// %cmp = arith.cmpi ult, %laneid, %0
 /// %ub = arith.select %cmp, %c0, %c1
 /// %1 = vector.create_mask %ub : vector<1xi1>
-struct WarpOpCreateMask : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+struct WarpOpCreateMask : public WarpDistributionPattern {
   using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
@@ -1085,7 +1078,7 @@ struct WarpOpCreateMask : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
 
 /// Pattern to move out vector.extract of single element vector. Those don't
 /// need to be distributed and can just be propagated outside of the region.
-struct WarpOpExtract : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+struct WarpOpExtract : public WarpDistributionPattern {
   using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
@@ -1165,12 +1158,10 @@ struct WarpOpExtract : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
 
 /// Pattern to move out vector.extract with a scalar result.
 /// Only supports 1-D and 0-D sources for now.
-struct WarpOpExtractScalar
-    : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+struct WarpOpExtractScalar : public WarpDistributionPattern {
   WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
                       PatternBenefit b = 1)
-      : WarpDistributionPattern<WarpExecuteOnLane0Op>(ctx, b),
-        warpShuffleFromIdxFn(std::move(fn)) {}
+      : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
@@ -1265,8 +1256,7 @@ struct WarpOpExtractScalar
 };
 
 /// Pattern to convert vector.extractelement to vector.extract.
-struct WarpOpExtractElement
-    : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+struct WarpOpExtractElement : public WarpDistributionPattern {
   using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
@@ -1288,8 +1278,7 @@ struct WarpOpExtractElement
 
 /// Pattern to move out vector.insert with a scalar input.
 /// Only supports 1-D and 0-D destinations for now.
-struct WarpOpInsertScalar
-    : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+struct WarpOpInsertScalar : public WarpDistributionPattern {
   using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
@@ -1381,7 +1370,7 @@ struct WarpOpInsertScalar
   }
 };
 
-struct WarpOpInsert : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+struct WarpOpInsert : public WarpDistributionPattern {
   using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
@@ -1494,8 +1483,7 @@ struct WarpOpInsert : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
   }
 };
 
-struct WarpOpInsertElement
-    : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+struct WarpOpInsertElement : public WarpDistributionPattern {
   using WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
@@ -1547,11 +1535,10 @@ struct WarpOpInsertElement
 ///     scf.yield %iw : vector<4xf32>
 ///  }
 /// ```
-struct WarpOpScfForOp : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+struct WarpOpScfForOp : public WarpDistributionPattern {
 
   WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
-      : WarpDistributionPattern<WarpExecuteOnLane0Op>(ctx, b),
-        distributionMapFn(std::move(fn)) {}
+      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     auto yield = cast<gpu::YieldOp>(
@@ -1690,11 +1677,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
 /// %a = vector.extract %0[0] : f32 from vector<1xf32>
 /// %r = ("warp.reduction %a")
 /// ```
-struct WarpOpReduction : public WarpDistributionPattern<WarpExecuteOnLane0Op> {
+struct WarpOpReduction : public WarpDistributionPattern {
   WarpOpReduction(MLIRContext *context,
                   DistributedReductionFn distributedReductionFn,
                   PatternBenefit benefit = 1)
-      : WarpDistributionPattern<WarpExecuteOnLane0Op>(context, benefit),
+      : WarpDistributionPattern(context, benefit),
         distributedReductionFn(std::move(distributedReductionFn)) {}
 
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,

>From 0e763780da6db57301adbd88efa7ec3bce6dea15 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Thu, 12 Dec 2024 12:39:43 +0000
Subject: [PATCH 5/9] cleanup

---
 mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h b/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
index c62f01f6435dda..2b925778c02d70 100644
--- a/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
+++ b/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
@@ -20,10 +20,6 @@
 
 namespace mlir {
 namespace gpu {
-/// Move scalar operations with no dependency on the warp op outside of the
-/// region.
-void moveScalarUniformCode(gpu::WarpExecuteOnLane0Op op);
-
 struct WarpDistributionPattern : OpRewritePattern<WarpExecuteOnLane0Op> {
   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
   virtual LogicalResult

>From e73360042dd56962378a4fb9a7bb3d9b425c2ea8 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Thu, 12 Dec 2024 17:41:31 +0000
Subject: [PATCH 6/9] address review comments

---
 .../Dialect/GPU/Utils/DistributionUtils.h     | 23 ++++++++++---------
 .../Dialect/GPU/Utils/DistributionUtils.cpp   |  9 ++++----
 .../Vector/Transforms/VectorDistribute.cpp    | 10 ++++----
 3 files changed, 22 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h b/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
index 2b925778c02d70..655ebeecb73d6e 100644
--- a/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
+++ b/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
@@ -22,6 +22,7 @@ namespace mlir {
 namespace gpu {
 struct WarpDistributionPattern : OpRewritePattern<WarpExecuteOnLane0Op> {
   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+
   virtual LogicalResult
   matchAndRewrite(WarpExecuteOnLane0Op op,
                   PatternRewriter &rewriter) const override = 0;
@@ -29,31 +30,31 @@ struct WarpDistributionPattern : OpRewritePattern<WarpExecuteOnLane0Op> {
 protected:
   /// Return a value yielded by `warpOp` which statifies the filter lamdba
   /// condition and is not dead.
-  static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
-                                  const std::function<bool(Operation *)> &fn);
+  OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
+                           const std::function<bool(Operation *)> &fn) const;
 
   /// Helper to create a new WarpExecuteOnLane0Op with different signature.
-  static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
+  WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
       RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-      ValueRange newYieldedValues, TypeRange newReturnTypes);
+      ValueRange newYieldedValues, TypeRange newReturnTypes) const;
 
   /// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
   /// `indices` return the index of each new output.
-  static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
+  WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
       RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
       ValueRange newYieldedValues, TypeRange newReturnTypes,
-      llvm::SmallVector<size_t> &indices);
+      llvm::SmallVector<size_t> &indices) const;
 
   /// Delinearize the given `laneId` into multiple dimensions, where each
   /// dimension's size is determined by `originalShape` and `distributedShape`
   /// together. This function expects the total numbers of threads needed for
   /// distribution is equal to `warpSize`. Returns true and updates
   /// `delinearizedIds` if so.
-  static bool delinearizeLaneId(OpBuilder &builder, Location loc,
-                                ArrayRef<int64_t> originalShape,
-                                ArrayRef<int64_t> distributedShape,
-                                int64_t warpSize, Value laneId,
-                                SmallVectorImpl<Value> &delinearizedIds);
+  bool delinearizeLaneId(OpBuilder &builder, Location loc,
+                         ArrayRef<int64_t> originalShape,
+                         ArrayRef<int64_t> distributedShape, int64_t warpSize,
+                         Value laneId,
+                         SmallVectorImpl<Value> &delinearizedIds) const;
 };
 
 } // namespace gpu
diff --git a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
index 857c1e84f303f8..4ff9a445f11702 100644
--- a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
+++ b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
@@ -23,7 +23,7 @@ using namespace mlir::gpu;
 WarpExecuteOnLane0Op
 WarpDistributionPattern::moveRegionToNewWarpOpAndReplaceReturns(
     RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-    ValueRange newYieldedValues, TypeRange newReturnTypes) {
+    ValueRange newYieldedValues, TypeRange newReturnTypes) const {
   // Create a new op before the existing one, with the extra operands.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(warpOp);
@@ -51,7 +51,7 @@ WarpExecuteOnLane0Op
 WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
     RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
     ValueRange newYieldedValues, TypeRange newReturnTypes,
-    llvm::SmallVector<size_t> &indices) {
+    llvm::SmallVector<size_t> &indices) const {
   SmallVector<Type> types(warpOp.getResultTypes().begin(),
                           warpOp.getResultTypes().end());
   auto yield = cast<gpu::YieldOp>(
@@ -82,7 +82,8 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
 }
 
 OpOperand *WarpDistributionPattern::getWarpResult(
-    WarpExecuteOnLane0Op warpOp, const std::function<bool(Operation *)> &fn) {
+    WarpExecuteOnLane0Op warpOp,
+    const std::function<bool(Operation *)> &fn) const {
   auto yield = cast<gpu::YieldOp>(
       warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
   for (OpOperand &yieldOperand : yield->getOpOperands()) {
@@ -99,7 +100,7 @@ OpOperand *WarpDistributionPattern::getWarpResult(
 bool WarpDistributionPattern::delinearizeLaneId(
     OpBuilder &builder, Location loc, ArrayRef<int64_t> originalShape,
     ArrayRef<int64_t> distributedShape, int64_t warpSize, Value laneId,
-    SmallVectorImpl<Value> &delinearizedIds) {
+    SmallVectorImpl<Value> &delinearizedIds) const {
   // If the original shape and the distributed shape is the same, we don't
   // distribute at all--every thread is handling the whole. For such case, we
   // should not rely on lane IDs later. So just return an empty lane ID vector.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 068a685c29c65c..20bd721bc34e73 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -536,11 +536,11 @@ struct WarpOpTransferWrite : public WarpDistributionPattern {
   /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp
   /// execute op with the proper return type. The new write op is updated to
   /// write the result of the new warp execute op. The old `writeOp` is deleted.
-  static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
-                                              WarpExecuteOnLane0Op warpOp,
-                                              vector::TransferWriteOp writeOp,
-                                              VectorType targetType,
-                                              VectorType maybeMaskType) {
+  vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
+                                       WarpExecuteOnLane0Op warpOp,
+                                       vector::TransferWriteOp writeOp,
+                                       VectorType targetType,
+                                       VectorType maybeMaskType) const {
     assert(writeOp->getParentOp() == warpOp &&
            "write must be nested immediately under warp");
     OpBuilder::InsertionGuard g(rewriter);

>From e57e8c358bb0bc2f446ce0e50fe3faf2dcf1afca Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Thu, 12 Dec 2024 18:02:45 +0000
Subject: [PATCH 7/9] address more comments

---
 .../Dialect/GPU/Utils/DistributionUtils.h     | 21 ++++++++-----------
 .../Dialect/GPU/Utils/DistributionUtils.cpp   | 14 ++++++-------
 2 files changed, 16 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h b/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
index 655ebeecb73d6e..0bed513ceb805c 100644
--- a/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
+++ b/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
@@ -1,4 +1,4 @@
-//===- VectorDistributionUtils.h - Distribution Utilities -------*- C++ -*-===//
+//===- DistributionUtils.h - Distribution Utilities -------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -15,13 +15,10 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Value.h"
 
-#include <numeric>
-#include <utility>
-
-namespace mlir {
-namespace gpu {
+namespace mlir::gpu {
 struct WarpDistributionPattern : OpRewritePattern<WarpExecuteOnLane0Op> {
-  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  using OpRewritePattern::OpRewritePattern;
+  using Base = WarpDistributionPattern;
 
   virtual LogicalResult
   matchAndRewrite(WarpExecuteOnLane0Op op,
@@ -30,8 +27,9 @@ struct WarpDistributionPattern : OpRewritePattern<WarpExecuteOnLane0Op> {
 protected:
   /// Return a value yielded by `warpOp` which statifies the filter lamdba
   /// condition and is not dead.
-  OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
-                           const std::function<bool(Operation *)> &fn) const;
+  OpOperand *
+  getWarpResult(WarpExecuteOnLane0Op warpOp,
+                const llvm::function_ref<bool(Operation *)> fn) const;
 
   /// Helper to create a new WarpExecuteOnLane0Op with different signature.
   WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
@@ -43,7 +41,7 @@ struct WarpDistributionPattern : OpRewritePattern<WarpExecuteOnLane0Op> {
   WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
       RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
       ValueRange newYieldedValues, TypeRange newReturnTypes,
-      llvm::SmallVector<size_t> &indices) const;
+      SmallVector<size_t> &indices) const;
 
   /// Delinearize the given `laneId` into multiple dimensions, where each
   /// dimension's size is determined by `originalShape` and `distributedShape`
@@ -57,7 +55,6 @@ struct WarpDistributionPattern : OpRewritePattern<WarpExecuteOnLane0Op> {
                          SmallVectorImpl<Value> &delinearizedIds) const;
 };
 
-} // namespace gpu
-} // namespace mlir
+} // namespace mlir::gpu
 
 #endif // MLIR_DIALECT_GPU_TRANSFORMS_DISTRIBUTIONUTILS_H_
diff --git a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
index 4ff9a445f11702..7fa2e7b28c8320 100644
--- a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
+++ b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
@@ -51,22 +51,22 @@ WarpExecuteOnLane0Op
 WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
     RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
     ValueRange newYieldedValues, TypeRange newReturnTypes,
-    llvm::SmallVector<size_t> &indices) const {
+    SmallVector<size_t> &indices) const {
   SmallVector<Type> types(warpOp.getResultTypes().begin(),
                           warpOp.getResultTypes().end());
   auto yield = cast<gpu::YieldOp>(
       warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
   llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
                                               yield.getOperands().end());
-  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
-    if (yieldValues.insert(std::get<0>(newRet))) {
-      types.push_back(std::get<1>(newRet));
+  for (auto [value, type] : llvm::zip_equal(newYieldedValues, newReturnTypes)) {
+    if (yieldValues.insert(value)) {
+      types.push_back(type);
       indices.push_back(yieldValues.size() - 1);
     } else {
       // If the value already exit the region don't create a new output.
       for (auto [idx, yieldOperand] :
            llvm::enumerate(yieldValues.getArrayRef())) {
-        if (yieldOperand == std::get<0>(newRet)) {
+        if (yieldOperand == value) {
           indices.push_back(idx);
           break;
         }
@@ -83,7 +83,7 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
 
 OpOperand *WarpDistributionPattern::getWarpResult(
     WarpExecuteOnLane0Op warpOp,
-    const std::function<bool(Operation *)> &fn) const {
+    const llvm::function_ref<bool(Operation *)> fn) const {
   auto yield = cast<gpu::YieldOp>(
       warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
   for (OpOperand &yieldOperand : yield->getOpOperands()) {
@@ -94,7 +94,7 @@ OpOperand *WarpDistributionPattern::getWarpResult(
         return &yieldOperand;
     }
   }
-  return {};
+  return nullptr;
 }
 
 bool WarpDistributionPattern::delinearizeLaneId(

>From f141a261f41c46acaa9cdedb9099ff9173d300c5 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Thu, 12 Dec 2024 19:02:17 +0000
Subject: [PATCH 8/9] remove const

---
 mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h | 5 ++---
 mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp        | 2 +-
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h b/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
index 0bed513ceb805c..ff8840a769779f 100644
--- a/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
+++ b/mlir/include/mlir/Dialect/GPU/Utils/DistributionUtils.h
@@ -27,9 +27,8 @@ struct WarpDistributionPattern : OpRewritePattern<WarpExecuteOnLane0Op> {
 protected:
   /// Return a value yielded by `warpOp` which statifies the filter lamdba
   /// condition and is not dead.
-  OpOperand *
-  getWarpResult(WarpExecuteOnLane0Op warpOp,
-                const llvm::function_ref<bool(Operation *)> fn) const;
+  OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
+                           llvm::function_ref<bool(Operation *)> fn) const;
 
   /// Helper to create a new WarpExecuteOnLane0Op with different signature.
   WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
diff --git a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
index 7fa2e7b28c8320..9d51ac3fc4bdce 100644
--- a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
+++ b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
@@ -83,7 +83,7 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
 
 OpOperand *WarpDistributionPattern::getWarpResult(
     WarpExecuteOnLane0Op warpOp,
-    const llvm::function_ref<bool(Operation *)> fn) const {
+    llvm::function_ref<bool(Operation *)> fn) const {
   auto yield = cast<gpu::YieldOp>(
       warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
   for (OpOperand &yieldOperand : yield->getOpOperands()) {

>From cc5e22e7ea8649695f515d7bb53fd9594a7f26f4 Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Thu, 12 Dec 2024 20:46:52 +0000
Subject: [PATCH 9/9] address review comments

---
 .../Vector/Transforms/VectorDistribute.cpp    | 26 +++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 20bd721bc34e73..e214257de2cdfd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -588,7 +588,7 @@ struct WarpOpTransferWrite : public WarpDistributionPattern {
 /// }
 /// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
 struct WarpOpElementwise : public WarpDistributionPattern {
-  using WarpDistributionPattern::WarpDistributionPattern;
+  using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
@@ -654,7 +654,7 @@ struct WarpOpElementwise : public WarpDistributionPattern {
 /// }
 /// %0 = arith.constant dense<2.0> : vector<1xf32>
 struct WarpOpConstant : public WarpDistributionPattern {
-  using WarpDistributionPattern::WarpDistributionPattern;
+  using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *yieldOperand =
@@ -700,7 +700,7 @@ struct WarpOpConstant : public WarpDistributionPattern {
 /// }
 /// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
 struct WarpOpTransferRead : public WarpDistributionPattern {
-  using WarpDistributionPattern::WarpDistributionPattern;
+  using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     // Try to find a distributable yielded read. Note that this pattern can
@@ -812,7 +812,7 @@ struct WarpOpTransferRead : public WarpDistributionPattern {
 /// Remove any result that has no use along with the matching yieldOp operand.
 // TODO: Move this in WarpExecuteOnLane0Op canonicalization.
 struct WarpOpDeadResult : public WarpDistributionPattern {
-  using WarpDistributionPattern::WarpDistributionPattern;
+  using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     SmallVector<Type> newResultTypes;
@@ -873,7 +873,7 @@ struct WarpOpDeadResult : public WarpDistributionPattern {
 // If an operand is directly yielded out of the region we can forward it
 // directly and it doesn't need to go through the region.
 struct WarpOpForwardOperand : public WarpDistributionPattern {
-  using WarpDistributionPattern::WarpDistributionPattern;
+  using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     SmallVector<Type> resultTypes;
@@ -917,7 +917,7 @@ struct WarpOpForwardOperand : public WarpDistributionPattern {
 };
 
 struct WarpOpBroadcast : public WarpDistributionPattern {
-  using WarpDistributionPattern::WarpDistributionPattern;
+  using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
@@ -954,7 +954,7 @@ struct WarpOpBroadcast : public WarpDistributionPattern {
 /// Pattern to move shape cast out of the warp op. shape cast is basically a
 /// no-op for warp distribution; we need to handle the shape though.
 struct WarpOpShapeCast : public WarpDistributionPattern {
-  using WarpDistributionPattern::WarpDistributionPattern;
+  using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
@@ -1013,7 +1013,7 @@ struct WarpOpShapeCast : public WarpDistributionPattern {
 /// %ub = arith.select %cmp, %c0, %c1
 /// %1 = vector.create_mask %ub : vector<1xi1>
 struct WarpOpCreateMask : public WarpDistributionPattern {
-  using WarpDistributionPattern::WarpDistributionPattern;
+  using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *yieldOperand =
@@ -1079,7 +1079,7 @@ struct WarpOpCreateMask : public WarpDistributionPattern {
 /// Pattern to move out vector.extract of single element vector. Those don't
 /// need to be distributed and can just be propagated outside of the region.
 struct WarpOpExtract : public WarpDistributionPattern {
-  using WarpDistributionPattern::WarpDistributionPattern;
+  using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
@@ -1257,7 +1257,7 @@ struct WarpOpExtractScalar : public WarpDistributionPattern {
 
 /// Pattern to convert vector.extractelement to vector.extract.
 struct WarpOpExtractElement : public WarpDistributionPattern {
-  using WarpDistributionPattern::WarpDistributionPattern;
+  using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =
@@ -1279,7 +1279,7 @@ struct WarpOpExtractElement : public WarpDistributionPattern {
 /// Pattern to move out vector.insert with a scalar input.
 /// Only supports 1-D and 0-D destinations for now.
 struct WarpOpInsertScalar : public WarpDistributionPattern {
-  using WarpDistributionPattern::WarpDistributionPattern;
+  using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
@@ -1371,7 +1371,7 @@ struct WarpOpInsertScalar : public WarpDistributionPattern {
 };
 
 struct WarpOpInsert : public WarpDistributionPattern {
-  using WarpDistributionPattern::WarpDistributionPattern;
+  using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
@@ -1484,7 +1484,7 @@ struct WarpOpInsert : public WarpDistributionPattern {
 };
 
 struct WarpOpInsertElement : public WarpDistributionPattern {
-  using WarpDistributionPattern::WarpDistributionPattern;
+  using Base::Base;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand =



More information about the Mlir-commits mailing list