[Mlir-commits] [mlir] [MLIR][XeGPU] Enhance multi-reduction layout propagation rules (PR #186308)
Jianhui Li
llvmlistbot at llvm.org
Thu Mar 19 08:49:45 PDT 2026
https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/186308
>From af03ed33a0ad6c5afe2a73a17a06751035650b96 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 13 Mar 2026 03:35:06 +0000
Subject: [PATCH 1/2] Improve multi-reduction layout support for high-dimensional
sources and consumers with nested slice attributes
---
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 135 ++++++++++++------
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 2 -
.../XeGPU/propagate-layout-subgroup.mlir | 39 ++++-
3 files changed, 124 insertions(+), 52 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index feefeb727a732..2f3044df0f1e7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -432,75 +432,118 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
return DenseI32ArrayAttr::get(context, vec32);
};
- // Extract original plain layout for workgroup/subgroup size recovery
- xegpu::SliceAttr consumerSliceLayout =
- dyn_cast<xegpu::SliceAttr>(consumerLayout);
- DistributeLayoutAttr plainLayout =
- consumerSliceLayout ? consumerSliceLayout.flatten().getParent()
- : consumerLayout;
+ // Helper lambda to check if the layout from consumer can be reused for the
+ // source shape
+ auto isLayoutCompatibleWithSrcShape =
+ [&](ArrayRef<int64_t> srcShape,
+ xegpu::DistributeLayoutAttr srcLayout) -> bool {
+ SmallVector<int64_t> sgLayout = srcLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> laneLayout = srcLayout.getEffectiveLaneLayoutAsInt();
+ if (!(srcLayout.getRank() == srcShape.size()))
+ return false;
+ for (int i = 0; i < srcShape.size(); i++) {
+ if (!sgLayout.empty() && srcShape[i] % sgLayout[i] != 0)
+ return false;
+ if (!laneLayout.empty() && srcShape[i] % laneLayout[i] != 0)
+ return false;
+ }
+ return true;
+ };
+ // Extract original plain layout for workgroup/subgroup size recovery
+ xegpu::DistributeLayoutAttr rootPlainLayout = consumerLayout;
+ while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(rootPlainLayout)) {
+ rootPlainLayout = sliceAttr.getParent();
+ }
+ auto sgLayoutVec = rootPlainLayout.getEffectiveSgLayoutAsInt();
+ const int workgroupSize = std::accumulate(
+ sgLayoutVec.begin(), sgLayoutVec.end(), 1, std::multiplies<int64_t>());
const int subgroupSize = uArch->getSubgroupSize();
int64_t maxReduceVectorSize = 1; // could extend to spirv vector Size
- xegpu::DistributeLayoutAttr srcLayout;
+ xegpu::SliceAttr consumerSliceLayout =
+ dyn_cast<xegpu::SliceAttr>(consumerLayout);
+ SmallVector<int64_t> consumerSgLayout =
+ consumerLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> consumerLaneLayout =
+ consumerLayout.getEffectiveLaneLayoutAsInt();
+ SmallVector<int64_t> consumerOrder = consumerLayout.getEffectiveOrderAsInt();
+ DenseI32ArrayAttr orderAttr = consumerLayout.getOrder();
+ xegpu::DistributeLayoutAttr srcLayout;
if (layoutKind == xegpu::LayoutKind::Subgroup) {
- auto sgLayoutVec = plainLayout.getEffectiveSgLayoutAsInt();
- const int workgroupSize = std::accumulate(
- sgLayoutVec.begin(), sgLayoutVec.end(), 1, std::multiplies<int64_t>());
- SmallVector<int64_t> sgLayout(srcRank), sgData(srcRank);
- SmallVector<int64_t> consumerSgLayout =
- consumerLayout.getEffectiveSgLayoutAsInt();
- int remainingSgCount = workgroupSize;
- int consumerIdx = consumerSgLayout.size() - 1;
-
- // First pass: Match consumer's layout on non-reduction dimensions
- for (int i = srcRank - 1; i >= 0; i--) {
- if (!llvm::is_contained(reductionDims, i) && consumerIdx >= 0) {
- sgLayout[i] = consumerSgLayout[consumerIdx];
- assert((srcShape[i] % sgLayout[i] == 0) &&
- "source shape not divisible by consumer sg_layout");
- sgData[i] = srcShape[i] / sgLayout[i];
- remainingSgCount /= sgLayout[i];
- consumerIdx--;
+ if (consumerSliceLayout &&
+ consumerSliceLayout.getDims().asArrayRef().equals(reductionDims) &&
+ isLayoutCompatibleWithSrcShape(srcShape,
+ consumerSliceLayout.getParent())) {
+ int64_t sgDataValue = -1;
+ srcLayout = consumerSliceLayout.getParent();
+ SmallVector<int64_t> sgLayoutFromConsumer =
+ srcLayout.getEffectiveSgLayoutAsInt();
+ for (int dim = 0; dim < srcRank; dim++) {
+ sgDataValue = srcShape[dim] / sgLayoutFromConsumer[dim];
+ srcLayout = srcLayout.setDimData(dim, sgDataValue, -1, -1);
}
- }
+ } else {
- // Second pass: Distribute remaining subgroups across reduction dimensions
- for (int i = srcRank - 1; i >= 0; i--) {
- if (llvm::is_contained(reductionDims, i)) {
- sgLayout[i] =
- std::min(srcShape[i], static_cast<int64_t>(remainingSgCount));
- assert((srcShape[i] % sgLayout[i] == 0) &&
- "source shape not divisible by sg_layout");
- sgData[i] = srcShape[i] / sgLayout[i];
- remainingSgCount /= sgLayout[i];
+ SmallVector<int64_t> sgLayout(srcRank), sgData(srcRank), order(srcRank);
+ int remainingSgCount = workgroupSize;
+ int consumerIdx = consumerSgLayout.size() - 1;
+
+ // First pass: Match consumer's layout on non-reduction dimensions
+ for (int i = srcRank - 1; i >= 0; i--) {
+ if (!llvm::is_contained(reductionDims, i) && consumerIdx >= 0) {
+ sgLayout[i] = consumerSgLayout[consumerIdx];
+ assert((srcShape[i] % sgLayout[i] == 0) &&
+ "source shape not divisible by consumer sg_layout");
+ sgData[i] = srcShape[i] / sgLayout[i];
+ remainingSgCount /= sgLayout[i];
+ order[i] = consumerOrder[consumerIdx];
+ consumerIdx--;
+ }
}
- }
- assert(remainingSgCount == 1 && "not all subgroups distributed");
- srcLayout = xegpu::LayoutAttr::get(
- context, toInt32Attr(sgLayout), toInt32Attr(sgData),
- /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
- /*lane_data =*/nullptr, /*order =*/nullptr);
+ // Second pass: Distribute remaining subgroups across reduction dimensions
+ int64_t remainOrder = consumerSgLayout.size();
+ for (int i = srcRank - 1; i >= 0; i--) {
+ if (llvm::is_contained(reductionDims, i)) {
+ sgLayout[i] =
+ std::min(srcShape[i], static_cast<int64_t>(remainingSgCount));
+ assert((srcShape[i] % sgLayout[i] == 0) &&
+ "source shape not divisible by sg_layout");
+ sgData[i] = srcShape[i] / sgLayout[i];
+ remainingSgCount /= sgLayout[i];
+ order[i] = remainOrder++;
+ }
+ }
+ assert(remainingSgCount == 1 && "not all subgroups distributed");
+ srcLayout = xegpu::LayoutAttr::get(
+ context, toInt32Attr(sgLayout), toInt32Attr(sgData),
+ /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+ /*lane_data =*/nullptr, /*order =*/
+ (!orderAttr || orderAttr.empty()) ? nullptr : toInt32Attr(order));
+ }
} else if (layoutKind == xegpu::LayoutKind::InstData) {
SmallVector<int64_t> instData(srcRank, 1);
instData[srcRank - 2] =
std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
- instData[srcRank - 1] = subgroupSize;
+ instData[srcRank - 1] =
+ std::min(static_cast<int64_t>(subgroupSize), srcShape[srcRank - 1]);
srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(instData));
} else if (layoutKind == xegpu::LayoutKind::Lane) {
SmallVector<int64_t> laneLayout(srcRank, 1), laneData(srcRank, 1);
- laneLayout[srcRank - 1] = subgroupSize;
+ laneLayout[srcRank - 1] =
+ std::min(static_cast<int64_t>(subgroupSize), srcShape[srcRank - 1]);
laneData[srcRank - 2] =
std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
- srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(laneLayout),
- toInt32Attr(laneData),
- consumerLayout.getOrder());
+ srcLayout = xegpu::LayoutAttr::get(
+ context, toInt32Attr(laneLayout), toInt32Attr(laneData),
+ (!orderAttr || orderAttr.empty()) ? nullptr
+ : toInt32Attr(consumerOrder));
}
return xegpu::SliceAttr::get(context, srcLayout,
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 8bf0f2aca60c5..00008ae344d73 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1136,7 +1136,6 @@ void LayoutInfoPropagation::visitLoadMatrixOp(
if (!hasParamsOfLayoutKind(anchorLayout)) {
VectorType resVecTy =
llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
- assert(resVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
const uArch *uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
if (!uArch)
return;
@@ -1157,7 +1156,6 @@ void LayoutInfoPropagation::visitStoreMatrixOp(
} else {
VectorType srcVecTy =
llvm::cast<VectorType>(storeMatrix.getData().getType());
- assert(srcVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
const uArch *uArch = getUArch(getChipStr(storeMatrix).value_or(""));
if (!uArch)
return;
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index 39fd815b1b380..d730d04c819fa 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -128,8 +128,7 @@ gpu.module @test {
gpu.module @test {
// CHECK-LABEL: vector_row_reduction
// CHECK: %[[REDUCE:.*]] = vector.multi_reduction <add>, %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [1, 64]>, dims = [1]>}
- gpu.func @vector_row_reduction(%src: memref<32x64xf32>, %dst: memref<32xf32>) kernel attributes
- {known_block_size = array<i32: 1, 32, 1>} {
+ gpu.func @vector_row_reduction(%src: memref<32x64xf32>, %dst: memref<32xf32>) {
%cst = arith.constant dense<0.000000e+00> : vector<32xf32>
%tdesc_src = xegpu.create_nd_tdesc %src : memref<32x64xf32> -> !xegpu.tensor_desc<32x64xf32>
%load = xegpu.load_nd %tdesc_src : !xegpu.tensor_desc<32x64xf32> -> vector<32x64xf32>
@@ -144,8 +143,7 @@ gpu.module @test {
// -----
gpu.module @test {
// CHECK-LABEL: vector_nest_reduction
- gpu.func @vector_nest_reduction(%src: memref<32x128xf32>, %dst: memref<32xf32>) kernel attributes
- {known_block_size = array<i32: 1, 32, 1>} {
+ gpu.func @vector_nest_reduction(%src: memref<32x128xf32>, %dst: memref<32xf32>) {
%cst = arith.constant dense<0.000000e+00> : vector<32xf32>
%cst1 = arith.constant dense<0.000000e+00> : vector<32x128xf32>
%tdesc_src = xegpu.create_nd_tdesc %src : memref<32x128xf32> -> !xegpu.tensor_desc<32x128xf32>
@@ -165,6 +163,39 @@ gpu.module @test {
}
}
+// -----
+gpu.module @test {
+// CHECK-LABEL: vector_nest_reduction_with_nest_slice_layout
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 16]>, dims = [0]>, dims = [1]>} dense<0.000000e+00> : vector<32xf32>
+// CHECK: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 16]>, dims = [0]>} dense<0.000000e+00> : vector<32x128xf32>
+// CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<32x128xf32> -> !xegpu.tensor_desc<32x128xf32, #xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 16]>, dims = [0]>>
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] <{layout = #xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 16]>, dims = [0]>}>
+// CHECK-SAME: -> vector<32x128xf32>
+// CHECK: %[[BCAST1:.*]] = vector.broadcast %[[LOAD]] {layout_result_0 = #xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 16]>} : vector<32x128xf32> to vector<4x32x128xf32>
+// CHECK: %[[REDUCE1:.*]] = vector.multi_reduction <add>, %[[BCAST1]], %[[CST0]]
+// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 16]>, dims = [0]>} [0] : vector<4x32x128xf32> to vector<32x128xf32>
+// CHECK: %[[REDUCE2:.*]] = vector.multi_reduction <add>, %[[REDUCE1]], %[[CST]]
+// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 16]>, dims = [0]>, dims = [1]>} [1] : vector<32x128xf32> to vector<32xf32>
+// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 32]>, dims = [0]>, dims = [1]>} dense<true> : vector<32xi1>
+// CHECK: %[[OFFSET:.*]] = vector.step {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 32]>, dims = [0]>, dims = [1]>} : vector<32xindex>
+// CHECK: xegpu.store %[[REDUCE2]], %{{.*}}[%[[OFFSET]]], %[[MASK]]
+// CHECK-SAME: <{layout = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 32]>, dims = [0]>, dims = [1]>}>
+// CHECK-SAME: : vector<32xf32>, memref<32xf32>, vector<32xindex>, vector<32xi1>
+ gpu.func @vector_nest_reduction_with_nest_slice_layout(%src: memref<32x128xf32>, %dst: memref<32xf32>) {
+ %cst = arith.constant dense<0.000000e+00> : vector<32xf32>
+ %cst1 = arith.constant dense<0.000000e+00> : vector<32x128xf32>
+ %tdesc_src = xegpu.create_nd_tdesc %src : memref<32x128xf32> -> !xegpu.tensor_desc<32x128xf32>
+ %load = xegpu.load_nd %tdesc_src : !xegpu.tensor_desc<32x128xf32> -> vector<32x128xf32>
+ %bcast1 = vector.broadcast %load: vector<32x128xf32> to vector<4x32x128xf32>
+ %bcast = vector.multi_reduction <add>, %bcast1, %cst1 [0]: vector<4x32x128xf32> to vector<32x128xf32>
+ %reduce = vector.multi_reduction <add>, %bcast, %cst [1] : vector<32x128xf32> to vector<32xf32>
+ %mask = arith.constant dense<1>: vector<32xi1>
+ %offset = vector.step : vector<32xindex>
+ xegpu.store %reduce, %dst[%offset], %mask {layout = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 32]>, dims = [0]>, dims = [1]>} : vector<32xf32>, memref<32xf32>, vector<32xindex>, vector<32xi1>
+ gpu.return
+ }
+}
+
// -----
gpu.module @test {
// CHECK-LABEL: for_loop_dpas
>From 385eee510e7f5a3419c4db935e57f0d2da006c23 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 19 Mar 2026 15:49:24 +0000
Subject: [PATCH 2/2] Use computeShapeRatio and enhance layout.getNumSubgroups
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 20 ++++++---
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 45 ++++---------------
2 files changed, 23 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index ce0cce65373e5..d16787297b302 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -193,12 +193,7 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
"getRank">,
InterfaceMethod<"Get the num of effective subgroups",
"int64_t",
- "getNumSubgroups", (ins), [{
- std::optional<SmallVector<int64_t>> sgLayout = llvm::cast<ConcreteAttr>(tablegen_opaque_val).getEffectiveSgLayoutAsInt();
- if (sgLayout.has_value())
- return computeProduct(*sgLayout);
- return 0;
- }], [{}]>,
+ "getNumSubgroups">,
InterfaceMethod<"Get the order of the layout attribute",
"DenseI32ArrayAttr",
"getOrder">,
@@ -464,6 +459,13 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
return 0;
}
+ int64_t getNumSubgroups() const { // Product of the effective sg_layout entries; 0 if sg_layout is absent.
+ auto sgLayout = getEffectiveSgLayoutAsInt();
+ if (!sgLayout.empty())
+ return computeProduct(sgLayout);
+ return 0;
+ }
+
LayoutAttr dropSgLayoutAndData() const{
// avoid every field of the attribute is nullptr, which may lead to segment fault
if (!getInstData() && !getLaneLayout())
@@ -613,6 +615,12 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
return parent.getRank() - attr.getDims().size();
}
+ int64_t getNumSubgroups() const {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ return parent ? parent.getNumSubgroups() : 0; // Slicing drops dims, not subgroups: use the root layout's count (was getRank()).
+ }
+
DenseI32ArrayAttr getOrder() const {
SliceAttr attr = flatten();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 2812efdaee27a..9236e0cea3a02 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -437,37 +437,10 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
return DenseI32ArrayAttr::get(context, vec32);
};
- // Helper lambda to check if the layout from consumer can be reused for the
- // source shape
- auto isLayoutCompatibleWithSrcShape =
- [&](ArrayRef<int64_t> srcShape,
- xegpu::DistributeLayoutAttr srcLayout) -> bool {
- SmallVector<int64_t> sgLayout = srcLayout.getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> laneLayout = srcLayout.getEffectiveLaneLayoutAsInt();
- if (static_cast<size_t>(srcLayout.getRank()) != srcShape.size())
- return false;
- for (size_t i = 0; i < srcShape.size(); i++) {
- if (!sgLayout.empty() && srcShape[i] % sgLayout[i] != 0)
- return false;
- if (!laneLayout.empty() && srcShape[i] % laneLayout[i] != 0)
- return false;
- }
- return true;
- };
-
- // Extract original plain layout for workgroup/subgroup size recovery
- xegpu::DistributeLayoutAttr rootPlainLayout = consumerLayout;
- while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(rootPlainLayout)) {
- rootPlainLayout = sliceAttr.getParent();
- }
- auto sgLayoutVec = rootPlainLayout.getEffectiveSgLayoutAsInt();
- const int workgroupSize = std::accumulate(
- sgLayoutVec.begin(), sgLayoutVec.end(), 1, std::multiplies<int64_t>());
+ const int workgroupSize = consumerLayout.getNumSubgroups();
const int subgroupSize = uArch->getSubgroupSize();
int64_t maxReduceVectorSize = 1; // could extend to spirv vector Size
- xegpu::SliceAttr consumerSliceLayout =
- dyn_cast<xegpu::SliceAttr>(consumerLayout);
SmallVector<int64_t> consumerSgLayout =
consumerLayout.getEffectiveSgLayoutAsInt();
SmallVector<int64_t> consumerLaneLayout =
@@ -477,18 +450,18 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
xegpu::DistributeLayoutAttr srcLayout;
if (layoutKind == xegpu::LayoutKind::Subgroup) {
+ xegpu::SliceAttr consumerSliceLayout =
+ dyn_cast<xegpu::SliceAttr>(consumerLayout);
if (consumerSliceLayout &&
- consumerSliceLayout.getDims().asArrayRef().equals(reductionDims) &&
- isLayoutCompatibleWithSrcShape(srcShape,
- consumerSliceLayout.getParent())) {
- int64_t sgDataValue = -1;
+ consumerSliceLayout.getDims().asArrayRef().equals(reductionDims)) {
srcLayout = consumerSliceLayout.getParent();
SmallVector<int64_t> sgLayoutFromConsumer =
srcLayout.getEffectiveSgLayoutAsInt();
- for (int dim = 0; dim < srcRank; dim++) {
- sgDataValue = srcShape[dim] / sgLayoutFromConsumer[dim];
- srcLayout = srcLayout.setDimData(dim, sgDataValue, -1, -1);
- }
+ auto srcSgData = computeShapeRatio(srcShape, sgLayoutFromConsumer);
+ if (srcSgData)
+ for (int dim = 0; dim < srcRank; dim++) {
+ srcLayout = srcLayout.setDimData(dim, srcSgData.value()[dim], -1, -1);
+ }
} else {
SmallVector<int64_t> sgLayout(srcRank), sgData(srcRank), order(srcRank);
More information about the Mlir-commits
mailing list