[Mlir-commits] [mlir] [MLIR][XeGPU] XeVM lowering support for load_matrix/store_matrix (PR #162780)
llvmlistbot at llvm.org
Thu Oct 9 22:16:09 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-gpu
Author: Jianhui Li (Jianhui-Li)
Changes:
This PR adds lowering of xegpu.load_matrix/store_matrix to xevm.blockload/blockstore or llvm.load/store, depending on work-item (WI) level attributes.
It includes a few components:
1. Adds the WI-level attributes subgroup_block_io, vec_length, and vec_direction.
2. Expands the load_matrix/store_matrix op definitions to support scalar data (besides vector data).
3. Adds a member function to mem_desc that computes the linearized address from n-D offsets.
4. Adds lowering that depends on the WI-level attributes (see the illustrative sketch after this list):
a) if the result is a scalar, lower to a regular llvm.load/store;
b) if the result is a vector and the subgroup_block_io attribute is present, lower to xevm.blockload/blockstore;
c) if the result is a vector and vec_length/vec_direction are present, lower to llvm.load/store with a vector operand.
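A rough sketch of the three input forms (hypothetical IR: the shapes, offsets, and SSA names are made up for illustration; the exact assembly is defined by the op definitions and the loadstore_matrix.mlir/ops.mlir tests in this patch):

```mlir
// (a) Scalar result: lowers to a plain llvm.load/llvm.store.
%s = xegpu.load_matrix %md[%row, %col] : !xegpu.mem_desc<32x256xf16>, index, index -> f16

// (b) Vector result with subgroup_block_io: lowers to xevm.blockload/xevm.blockstore.
%v = xegpu.load_matrix %md[0, 0] <{subgroup_block_io}> : !xegpu.mem_desc<32x256xf16> -> vector<8xf16>

// (c) Vector result with vec_length/vec_direction: lowers to llvm.load/llvm.store on a vector.
%w = xegpu.load_matrix %md[0, 0] <{vec_length = 8 : i32,
                                   vec_direction = #xegpu.matrix_access_direction<row>}>
     : !xegpu.mem_desc<32x256xf16> -> vector<8xf16>
```

xegpu.store_matrix mirrors the same three forms, with the data to store as the first operand.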
---
Patch is 53.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/162780.diff
13 Files Affected:
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+22)
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+18-8)
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td (+49-1)
- (modified) mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt (+1)
- (modified) mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp (+198)
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+146)
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+68-25)
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp (+2-2)
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+1-1)
- (modified) mlir/test/Conversion/XeGPUToXeVM/dpas.mlir (+2-6)
- (added) mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir (+201)
- (modified) mlir/test/Dialect/XeGPU/invalid.mlir (+30-1)
- (modified) mlir/test/Dialect/XeGPU/ops.mlir (+49-8)
``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 5695d5d515d7f..601e966b49890 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -716,8 +716,30 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
return getAttrs().getAs<ArrayAttr>("stride");
}
+ ArrayAttr getBlockAttr() {
+ return getAttrs().getAs<ArrayAttr>("block");
+ }
+
}];
}
+def RowOriented : I32EnumAttrCase<"ROW", 0, "row">;
+def ColOriented : I32EnumAttrCase<"COL", 1, "col">;
+def MatrixAccessDirection :
+ I32EnumAttr<"MatrixAccessDirection",
+ "Matrix elements/vectors can have row or column direction", [
+ RowOriented, ColOriented
+]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::xegpu";
+}
+def MatrixAccessDirectionAttr :
+ EnumAttr<XeGPU_Dialect,
+ MatrixAccessDirection,
+ "matrix_access_direction">{
+ let summary = [{Describe the direction of memory access for load_matrix and store_matrix.}];
+ let assemblyFormat = "`<` $value `>`";
+}
+
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 73f9061f5debe..044a8ef22d891 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1298,14 +1298,16 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
}
def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
- AllElementTypesMatch<["mem_desc", "res"]>,
- AllRanksMatch<["mem_desc", "res"]>]> {
+ AllElementTypesMatch<["mem_desc", "res"]>]> {
let arguments = (ins XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
+ OptionalAttr<I32Attr>:$vec_length,
+ OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
+ OptionalAttr<UnitAttr>:$subgroup_block_io,
OptionalAttr<DistributeLayoutAttr>:$layout
);
- let results = (outs XeGPU_ValueType:$res);
+ let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$res);
let assemblyFormat = [{
$mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `` `:` type(operands) `->` type(results)
@@ -1336,7 +1338,10 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
}
ArrayRef<int64_t> getDataShape() {
- return getRes().getType().getShape();
+ auto resTy = getRes().getType();
+ if (auto vecTy = llvm::dyn_cast<VectorType>(resTy))
+ return vecTy.getShape();
+ return {};
}
}];
@@ -1344,13 +1349,15 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
}
def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
- AllElementTypesMatch<["mem_desc", "data"]>,
- AllRanksMatch<["mem_desc", "data"]>]> {
+ AllElementTypesMatch<["mem_desc", "data"]>]> {
let arguments = (ins
- XeGPU_ValueType:$data,
+ AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$data,
XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
+ OptionalAttr<I32Attr>:$vec_length,
+ OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
+ OptionalAttr<UnitAttr>:$subgroup_block_io,
OptionalAttr<DistributeLayoutAttr>:$layout
);
let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
@@ -1378,7 +1385,10 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
}
ArrayRef<int64_t> getDataShape() {
- return getData().getType().getShape();
+ auto DataTy = getData().getType();
+ if (auto vecTy = llvm::dyn_cast<VectorType>(DataTy))
+ return vecTy.getShape();
+ return {};
}
}];
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 84902b2039643..c261fbb576642 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -237,7 +237,7 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
}
- ArrayAttr getStrides() {
+ ArrayAttr getStridesAttr() {
auto layout = getMemLayout();
if (layout && layout.hasAttr("stride")) {
return layout.getStrides();
@@ -250,6 +250,54 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
Builder builder(getContext());
return builder.getI64ArrayAttr(defaultStrides);
}
+
+ /// Heuristic to determine if the MemDesc uses column-major layout,
+ /// based on the rank and the value of the first stride dimension.
+ bool isColMajor() {
+ auto dim0 = dyn_cast<IntegerAttr>(getStridesAttr()[0]);
+ return getRank() == 2 && dim0 && dim0.getInt() == 1;
+ }
+
+  // Get the blocking shape for a MemDescType, which is represented
+  // as an attribute in MemDescType. By default it is the shape
+  // of the mdescTy.
+ SmallVector<int64_t> getBlockSize() {
+ SmallVector<int64_t> size(getShape());
+ MemLayoutAttr layout = getMemLayout();
+ if (layout && layout.hasAttr("block")) {
+ ArrayAttr attr = layout.getBlockAttr();
+ size.clear();
+ llvm::for_each(attr, [&](Attribute elem) {
+ if (auto intElem = dyn_cast<IntegerAttr>(elem))
+ size.push_back(intElem.getInt());
+ });
+ }
+ return size;
+ }
+
+ // Get strides as vector of integer.
+ // If it contains block attribute, the strides are blocked strides.
+ //
+ // The blocking is applied against the original matrix shape
+ // so that the linear offset is not impacted by the subview.
+ //
+  // It first computes the original matrix shape using the stride info,
+  // then computes the number of blocks in each dimension of the original shape,
+  // then computes the outer block shape and stride,
+  // then combines the inner and outer block shape and stride.
+ // e.g. for mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>
+ // its memory layout tuple is ([2,32,16,8],[128,256,1,16])
+ // for mem_desc<256x32xf16, @block=[8, 16]> with default @stride[32, 1]
+ // its memory layout tuple is ([32,2,8,16],[256,128,16,1])
+ SmallVector<int64_t> getStrides();
+
+  /// Generates instructions to compute the linearized offset.
+  // If the memory descriptor is blocked, it returns the linearized offset based on the blocked layout.
+  // The strides of the memory descriptor are always applied, blocked or not.
+ Value getLinearOffsets(OpBuilder &builder,
+ Location loc, ArrayRef<OpFoldResult> offsets);
+
+
}];
let hasCustomAssemblyFormat = true;
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
index 84b25809f1ed0..dd9edc43a1657 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
@@ -21,6 +21,7 @@ add_mlir_conversion_library(MLIRXeGPUToXeVM
MLIRIndexDialect
MLIRSCFDialect
MLIRXeGPUDialect
+ MLIRXeGPUUtils
MLIRPass
MLIRTransforms
MLIRSCFTransforms
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 9ead1d89069d6..67e8246e5536a 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -21,6 +21,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/FormatVariadic.h"
@@ -60,6 +61,9 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
return static_cast<int>(xevm::AddrSpace::GLOBAL);
case xegpu::MemorySpace::SLM:
return static_cast<int>(xevm::AddrSpace::SHARED);
+ default:
+ llvm_unreachable("Unknown XeGPU memory space");
+ return static_cast<int>(xevm::AddrSpace::GLOBAL);
}
}
@@ -503,6 +507,189 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
}
};
+// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions
+// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than
+// 32 bits will be converted to 32 bits.
+class CreateMemDescOpPattern final
+ : public OpConversionPattern<xegpu::CreateMemDescOp> {
+public:
+ using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ TypedValue<MemRefType> src = op.getSource();
+ auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
+
+ // Create the result MemRefType with the same shape, element type, and
+ // memory space
+ auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
+
+ Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
+ auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
+ Value(src), zero, ValueRange());
+ rewriter.replaceOp(op, viewOp);
+ return success();
+ }
+};
+
+class MemDescSubviewOpPattern final
+ : public OpConversionPattern<xegpu::MemDescSubviewOp> {
+public:
+ using OpConversionPattern<xegpu::MemDescSubviewOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::MemDescSubviewOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ return rewriter.notifyMatchFailure(
+ op, "MemDescSubviewOp are not supported on Xe2/Xe3 architecture.");
+ }
+};
+
+template <typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
+class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
+
+ auto loc = op.getLoc();
+ auto ctxt = rewriter.getContext();
+ Value basePtrStruct = adaptor.getMemDesc();
+ Value mdescVal = op.getMemDesc();
+ // Load result or Store value Type can be vector or scalar.
+ Value data;
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
+ data = op.getResult();
+ else
+ data = adaptor.getData();
+ VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
+ if (!valOrResVecTy)
+ valOrResVecTy = VectorType::get(1, data.getType());
+
+ int64_t elemBitWidth =
+ valOrResVecTy.getElementType().getIntOrFloatBitWidth();
+ // Element type must be multiple of 8 bits.
+ if (elemBitWidth % 8 != 0)
+ return rewriter.notifyMatchFailure(
+ op, "Expected element type bit width to be multiple of 8.");
+ int64_t elemByteSize = elemBitWidth / 8;
+
+ // Default memory space is SLM.
+ LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
+
+ auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
+
+ Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
+ rewriter, loc, basePtrStruct);
+
+ // Convert base pointer (ptr) to i64
+ Value basePtrI64 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI64Type(), basePtrLLVM);
+
+ Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
+ linearOffset = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI64Type(), linearOffset);
+ basePtrI64 =
+ addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize);
+
+ // convert base pointer (i64) to LLVM pointer type
+ basePtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+
+ // if the size of valOrResVecTy is 1, it lowers to a scalar load/store
+ // operation. LLVM load/store does not support vector of size 1, so we need
+ // to handle this case separately.
+ if (valOrResVecTy.getNumElements() == 1) {
+ Type scalarTy = valOrResVecTy.getElementType();
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+ Value loadOp =
+ LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
+ rewriter.replaceOp(op, loadOp);
+ } else {
+ auto storeOp = LLVM::StoreOp::create(rewriter, loc, adaptor.getData(),
+ basePtrLLVM);
+ rewriter.eraseOp(op);
+ }
+ return success();
+ } else {
+ // if the attribute 'subgroup_block_io' is set to true, it lowers to
+ // xevm.blockload
+ auto subgroupBlockIoAttr = op.getSubgroupBlockIoAttr();
+ bool subgroup_block_io = static_cast<bool>(subgroupBlockIoAttr);
+
+ // BlockLoadOp only supports integer types, so we need to bitcast
+ // Get integer type with matching bit width
+ Type elemTy = valOrResVecTy.getElementType();
+ int64_t bitWidth = elemTy.getIntOrFloatBitWidth();
+ Type intElemTy = rewriter.getIntegerType(bitWidth);
+ VectorType intVecTy =
+ VectorType::get(valOrResVecTy.getShape(), intElemTy);
+
+ if (subgroup_block_io) {
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+ Value loadOp =
+ xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
+ if (intVecTy != valOrResVecTy) {
+ loadOp =
+ vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
+ }
+ rewriter.replaceOp(op, loadOp);
+ } else {
+ Value dataToStore = adaptor.getData();
+ if (valOrResVecTy != intVecTy) {
+ dataToStore =
+ vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
+ }
+ xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
+ nullptr);
+ rewriter.eraseOp(op);
+ }
+ } else {
+        // If the result is a 1D vector and the vector direction is COL, then
+        // the memory descriptor should be treated as column major.
+ auto chipOpt = xegpu::getChipStr(op);
+ if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
+ // the lowering only works for pvc and bmg
+ return rewriter.notifyMatchFailure(
+ op, "The lowering is specific to pvc or bmg.");
+ }
+ xegpu::MatrixAccessDirectionAttr vecDirection =
+ op.getVecDirectionAttr();
+ if (vecDirection &&
+ vecDirection.getValue() == xegpu::MatrixAccessDirection::COL &&
+ !mdescTy.isColMajor())
+ return rewriter.notifyMatchFailure(
+ op, "mem_desc should be column major when "
+ "vec_direction is COLUMN for 1D result.");
+ if (vecDirection &&
+ vecDirection.getValue() == xegpu::MatrixAccessDirection::ROW &&
+ mdescTy.isColMajor())
+ return rewriter.notifyMatchFailure(
+ op, "mem_desc should be row major when "
+ "vec_direction is ROW for 1D result.");
+
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+ Value loadOp =
+ LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
+ rewriter.replaceOp(op, loadOp);
+ } else {
+ auto storeOp = LLVM::StoreOp::create(rewriter, loc, adaptor.getData(),
+ basePtrLLVM);
+ rewriter.eraseOp(op);
+ }
+ }
+ }
+ return success();
+ }
+};
+
class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -785,6 +972,13 @@ struct ConvertXeGPUToXeVMPass
auto i32Type = IntegerType::get(&getContext(), 32);
return VectorType::get(8, i32Type);
});
+ // Convert MemDescType into flattened MemRefType for SLM
+ typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
+ Type elemTy = type.getElementType();
+ int numElems = type.getNumElements();
+ return MemRefType::get(numElems, elemTy, AffineMap(), 3);
+ });
+
typeConverter.addConversion([&](MemRefType type) -> Type {
// Convert MemRefType to i64 type.
return IntegerType::get(&getContext(), 64);
@@ -919,6 +1113,10 @@ void mlir::populateXeGPUToXeVMConversionPatterns(
LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
typeConverter, patterns.getContext());
+ patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
+ LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
+ CreateMemDescOpPattern, MemDescSubviewOpPattern>(
+ typeConverter, patterns.getContext());
patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
patterns.getContext());
}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 94c5509fd7c29..cccc8fab4adbc 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -726,6 +726,152 @@ void MemLayoutAttr::print(AsmPrinter &printer) const {
}
printer << ">";
}
+// a helper utility to perform binary operation on OpFoldResult.
+// If both a and b are attributes, it will simply return the result.
+// Otherwise, the corresponding arith op will be generated, and a
+// constant op will be created if one of them is an attribute.
+template <typename ArithOp>
+OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
+ OpBuilder &builder) {
+ auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
+ auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
+ return builder.create<ArithOp>(loc, aVal, bVal).getResult();
+}
+
+// a helper utility to perform division operation on OpFoldResult and int64_t.
+#define div(a, b) \
+ genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// a helper utility to perform remainder operation on OpFoldResult and int64_t.
+#define rem(a, b) \
+ genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// a helper utility to perform multiply operation on OpFoldResult and int64_t.
+#define mul(a, b) \
+ genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// a helper utility to perform addition operation on two OpFoldResult.
+#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
+
+// block the given offsets according to the block shape
+// say the original offset is [y, x], and the block shape is [By, Bx],
+// then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
+SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<int64_t> blockShape) {
+
+ assert(offsets.size() == blockShape.size() &&
+ "offsets and blockShape must have the same size");
+ SmallVector<OpFoldResult> blockedOffsets;
+ SmallVector<OpFoldResult> divs, rems;
+
+ for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+ divs.push_back(div(offset, block));
+ rems.push_back(...
[truncated]
``````````
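As a worked example of the blocked addressing sketched in the getStrides/getLinearOffsets comments above: for mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]> with memory layout tuple ([2,32,16,8], [128,256,1,16]), an element at (r, c) is re-indexed as [r/16, c/8, r%16, c%8] and dotted with the strides. This is just the arithmetic implied by the comment's example, not code from the patch:

$$\text{offset}(r, c) = \lfloor r/16 \rfloor \cdot 128 + \lfloor c/8 \rfloor \cdot 256 + (r \bmod 16) \cdot 1 + (c \bmod 8) \cdot 16$$

For instance, offset(17, 9) = 128 + 256 + 1 + 16 = 401 elements from the base of the mem_desc.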
https://github.com/llvm/llvm-project/pull/162780
More information about the Mlir-commits mailing list