[Mlir-commits] [mlir] [MLIR][XeGPU] XeVM lowering support for load_matrix/store_matrix (PR #162780)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 9 22:16:09 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Jianhui Li (Jianhui-Li)

<details>
<summary>Changes</summary>

This PR adds lowering of xegpu.load_matrix/store_matrix to xevm.blockload/blockstore or llvm.load/store, depending on wi-level attributes.
It includes a few components: 
   1. adds wi-level attributes: subgroup_block_io, vec_length, and vec_direction.   
   2. expands the load_matrix/store_matrix op definitions to support scalar data (in addition to vector data).
   3. adds a member function to mem_desc to compute the linearized address for n-D offsets.
   4. adds lowering depending on wi-level attributes:
       a) if result is scalar, lower to regular llvm.load/store
       b) if result is a vector and the subgroup_block_io attribute is present, lower to xevm.blockload/blockstore
       c) if result is a vector and vec_length/vec_direction is present, lower to llvm.load/store with vector operand.
 
  

---

Patch is 53.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/162780.diff


13 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+22) 
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+18-8) 
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td (+49-1) 
- (modified) mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt (+1) 
- (modified) mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp (+198) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+146) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+68-25) 
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp (+2-2) 
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+1-1) 
- (modified) mlir/test/Conversion/XeGPUToXeVM/dpas.mlir (+2-6) 
- (added) mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir (+201) 
- (modified) mlir/test/Dialect/XeGPU/invalid.mlir (+30-1) 
- (modified) mlir/test/Dialect/XeGPU/ops.mlir (+49-8) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 5695d5d515d7f..601e966b49890 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -716,8 +716,30 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
       return getAttrs().getAs<ArrayAttr>("stride");
     }
 
+    ArrayAttr getBlockAttr() {
+      return getAttrs().getAs<ArrayAttr>("block");
+    }
+
   }];
 
 }
 
+def RowOriented : I32EnumAttrCase<"ROW", 0, "row">;
+def ColOriented : I32EnumAttrCase<"COL", 1, "col">;
+def MatrixAccessDirection : 
+    I32EnumAttr<"MatrixAccessDirection", 
+    "Matrix elements/vectors can have row or column direction", [
+    RowOriented, ColOriented
+]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::xegpu";
+}
+def MatrixAccessDirectionAttr : 
+    EnumAttr<XeGPU_Dialect, 
+    MatrixAccessDirection, 
+    "matrix_access_direction">{
+  let summary = [{Describe the direction of memory access for load_matrix and store_matrix.}];
+  let assemblyFormat = "`<` $value `>`";
+}
+
 #endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 73f9061f5debe..044a8ef22d891 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1298,14 +1298,16 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
 }
 
 def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
-                              AllElementTypesMatch<["mem_desc", "res"]>,
-                              AllRanksMatch<["mem_desc", "res"]>]>  {
+                              AllElementTypesMatch<["mem_desc", "res"]>]>  {
   let arguments = (ins XeGPU_MemDesc:$mem_desc,
     Variadic<Index>: $offsets,
     DenseI64ArrayAttr: $const_offsets,
+    OptionalAttr<I32Attr>:$vec_length,
+    OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
+    OptionalAttr<UnitAttr>:$subgroup_block_io,
     OptionalAttr<DistributeLayoutAttr>:$layout
   );
-  let results = (outs XeGPU_ValueType:$res);
+  let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$res);  
   let assemblyFormat = [{
     $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
     prop-dict attr-dict `` `:` type(operands) `->` type(results)
@@ -1336,7 +1338,10 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
     }
 
     ArrayRef<int64_t> getDataShape() {
-      return getRes().getType().getShape();
+      auto resTy = getRes().getType();
+      if (auto vecTy = llvm::dyn_cast<VectorType>(resTy))
+        return vecTy.getShape();
+      return {};
     }
   }];
 
@@ -1344,13 +1349,15 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
 }
 
 def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
-                              AllElementTypesMatch<["mem_desc", "data"]>,
-                              AllRanksMatch<["mem_desc", "data"]>]> {
+                              AllElementTypesMatch<["mem_desc", "data"]>]> {
   let arguments = (ins
-    XeGPU_ValueType:$data,
+    AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$data,
     XeGPU_MemDesc:$mem_desc,
     Variadic<Index>: $offsets,
     DenseI64ArrayAttr: $const_offsets,
+    OptionalAttr<I32Attr>:$vec_length,
+    OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
+    OptionalAttr<UnitAttr>:$subgroup_block_io,
     OptionalAttr<DistributeLayoutAttr>:$layout
   );
   let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
@@ -1378,7 +1385,10 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
     }
 
     ArrayRef<int64_t> getDataShape() {
-      return getData().getType().getShape();
+      auto DataTy = getData().getType();
+      if (auto vecTy = llvm::dyn_cast<VectorType>(DataTy))
+        return vecTy.getShape();
+      return {};
     }
 
   }];
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 84902b2039643..c261fbb576642 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -237,7 +237,7 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
       return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
     }
 
-    ArrayAttr getStrides() {
+    ArrayAttr getStridesAttr() {
       auto layout = getMemLayout();
       if (layout && layout.hasAttr("stride")) {
         return layout.getStrides();
@@ -250,6 +250,54 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
       Builder builder(getContext());
       return builder.getI64ArrayAttr(defaultStrides);
     }
+
+    /// Heuristic to determine if the MemDesc uses column-major layout,
+    /// based on the rank and the value of the first stride dimension.
+    bool isColMajor() {
+      auto dim0 = dyn_cast<IntegerAttr>(getStridesAttr()[0]);
+      return getRank() == 2 && dim0 && dim0.getInt() == 1;
+    }
+
+    // get the Blocking shape for a MemDescType, Which is represented
+    // as an attribute in MemDescType. By default it is the shape
+    // of the mdescTy
+    SmallVector<int64_t> getBlockSize() {
+      SmallVector<int64_t> size(getShape());
+      MemLayoutAttr layout = getMemLayout();
+      if (layout && layout.hasAttr("block")) {
+        ArrayAttr attr = layout.getBlockAttr();
+        size.clear();
+        llvm::for_each(attr, [&](Attribute elem) {
+          if (auto intElem = dyn_cast<IntegerAttr>(elem))
+            size.push_back(intElem.getInt());
+        });
+      }
+      return size;
+    }
+
+    // Get strides as vector of integer. 
+    // If it contains block attribute, the strides are blocked strides.
+    //
+    // The blocking is applied against the original matrix shape
+    // so that the linear offset is not impacted by the subview.
+    //
+    // It first computes the original matrix shape using the stride info,
+    // then computes the number of blocks in each dimension of original shape,
+    // then compute the outer block shape and stride,
+    // then combines the inner and outer block shape and stride
+    // e.g. for mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>
+    // its memory layout tuple is ([2,32,16,8],[128,256,1,16])
+    // for  mem_desc<256x32xf16, @block=[8, 16]> with default @stride[32, 1]
+    // its memory layout tuple is ([32,2,8,16],[256,128,16,1])
+    SmallVector<int64_t> getStrides();
+    
+    /// Generates instructions to compute the linearize offset 
+    //  if the memory descriptor is blocked, it returns linearize offset based on the blocked layout
+    //  the strides of memory descriptor is always considered regardless of blocked or not
+    Value getLinearOffsets(OpBuilder &builder,
+               Location loc, ArrayRef<OpFoldResult> offsets);
+
+
   }];
 
   let hasCustomAssemblyFormat = true;
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
index 84b25809f1ed0..dd9edc43a1657 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
@@ -21,6 +21,7 @@ add_mlir_conversion_library(MLIRXeGPUToXeVM
   MLIRIndexDialect
   MLIRSCFDialect
   MLIRXeGPUDialect
+  MLIRXeGPUUtils
   MLIRPass
   MLIRTransforms
   MLIRSCFTransforms
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 9ead1d89069d6..67e8246e5536a 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -60,6 +61,9 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
     return static_cast<int>(xevm::AddrSpace::GLOBAL);
   case xegpu::MemorySpace::SLM:
     return static_cast<int>(xevm::AddrSpace::SHARED);
+  default:
+    llvm_unreachable("Unknown XeGPU memory space");
+    return static_cast<int>(xevm::AddrSpace::GLOBAL);
   }
 }
 
@@ -503,6 +507,189 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
   }
 };
 
+// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions
+// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than
+// 32 bits will be converted to 32 bits.
+class CreateMemDescOpPattern final
+    : public OpConversionPattern<xegpu::CreateMemDescOp> {
+public:
+  using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    TypedValue<MemRefType> src = op.getSource();
+    auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
+
+    // Create the result MemRefType with the same shape, element type, and
+    // memory space
+    auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
+
+    Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
+    auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
+                                         Value(src), zero, ValueRange());
+    rewriter.replaceOp(op, viewOp);
+    return success();
+  }
+};
+
+class MemDescSubviewOpPattern final
+    : public OpConversionPattern<xegpu::MemDescSubviewOp> {
+public:
+  using OpConversionPattern<xegpu::MemDescSubviewOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::MemDescSubviewOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    return rewriter.notifyMatchFailure(
+        op, "MemDescSubviewOp are not supported on Xe2/Xe3 architecture.");
+  }
+};
+
+template <typename OpType,
+          typename = std::enable_if_t<llvm::is_one_of<
+              OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
+class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
+  using OpConversionPattern<OpType>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
+
+    auto loc = op.getLoc();
+    auto ctxt = rewriter.getContext();
+    Value basePtrStruct = adaptor.getMemDesc();
+    Value mdescVal = op.getMemDesc();
+    // Load result or Store value Type can be vector or scalar.
+    Value data;
+    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
+      data = op.getResult();
+    else
+      data = adaptor.getData();
+    VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
+    if (!valOrResVecTy)
+      valOrResVecTy = VectorType::get(1, data.getType());
+
+    int64_t elemBitWidth =
+        valOrResVecTy.getElementType().getIntOrFloatBitWidth();
+    // Element type must be multiple of 8 bits.
+    if (elemBitWidth % 8 != 0)
+      return rewriter.notifyMatchFailure(
+          op, "Expected element type bit width to be multiple of 8.");
+    int64_t elemByteSize = elemBitWidth / 8;
+
+    // Default memory space is SLM.
+    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
+
+    auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
+
+    Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
+        rewriter, loc, basePtrStruct);
+
+    // Convert base pointer (ptr) to i64
+    Value basePtrI64 = arith::IndexCastUIOp::create(
+        rewriter, loc, rewriter.getI64Type(), basePtrLLVM);
+
+    Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
+    linearOffset = arith::IndexCastUIOp::create(
+        rewriter, loc, rewriter.getI64Type(), linearOffset);
+    basePtrI64 =
+        addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize);
+
+    // convert base pointer (i64) to LLVM pointer type
+    basePtrLLVM =
+        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+
+    // if the size of valOrResVecTy is 1, it lowers to a scalar load/store
+    // operation. LLVM load/store does not support vector of size 1, so we need
+    // to handle this case separately.
+    if (valOrResVecTy.getNumElements() == 1) {
+      Type scalarTy = valOrResVecTy.getElementType();
+      if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+        Value loadOp =
+            LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
+        rewriter.replaceOp(op, loadOp);
+      } else {
+        auto storeOp = LLVM::StoreOp::create(rewriter, loc, adaptor.getData(),
+                                             basePtrLLVM);
+        rewriter.eraseOp(op);
+      }
+      return success();
+    } else {
+      // if the attribute 'subgroup_block_io' is set to true, it lowers to
+      // xevm.blockload
+      auto subgroupBlockIoAttr = op.getSubgroupBlockIoAttr();
+      bool subgroup_block_io = static_cast<bool>(subgroupBlockIoAttr);
+
+      // BlockLoadOp only supports integer types, so we need to bitcast
+      // Get integer type with matching bit width
+      Type elemTy = valOrResVecTy.getElementType();
+      int64_t bitWidth = elemTy.getIntOrFloatBitWidth();
+      Type intElemTy = rewriter.getIntegerType(bitWidth);
+      VectorType intVecTy =
+          VectorType::get(valOrResVecTy.getShape(), intElemTy);
+
+      if (subgroup_block_io) {
+        if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+          Value loadOp =
+              xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
+          if (intVecTy != valOrResVecTy) {
+            loadOp =
+                vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
+          }
+          rewriter.replaceOp(op, loadOp);
+        } else {
+          Value dataToStore = adaptor.getData();
+          if (valOrResVecTy != intVecTy) {
+            dataToStore =
+                vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
+          }
+          xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
+                                     nullptr);
+          rewriter.eraseOp(op);
+        }
+      } else {
+        // if the result is 1D vector, if the vector direction is Column, then
+        // the
+        //  memory descriptor should be treated as column major
+        auto chipOpt = xegpu::getChipStr(op);
+        if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
+          // the lowering only works for pvc and bmg
+          return rewriter.notifyMatchFailure(
+              op, "The lowering is specific to pvc or bmg.");
+        }
+        xegpu::MatrixAccessDirectionAttr vecDirection =
+            op.getVecDirectionAttr();
+        if (vecDirection &&
+            vecDirection.getValue() == xegpu::MatrixAccessDirection::COL &&
+            !mdescTy.isColMajor())
+          return rewriter.notifyMatchFailure(
+              op, "mem_desc should be column major when "
+                  "vec_direction is COLUMN for 1D result.");
+        if (vecDirection &&
+            vecDirection.getValue() == xegpu::MatrixAccessDirection::ROW &&
+            mdescTy.isColMajor())
+          return rewriter.notifyMatchFailure(
+              op, "mem_desc should be row major when "
+                  "vec_direction is ROW for 1D result.");
+
+        if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+          Value loadOp =
+              LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
+          rewriter.replaceOp(op, loadOp);
+        } else {
+          auto storeOp = LLVM::StoreOp::create(rewriter, loc, adaptor.getData(),
+                                               basePtrLLVM);
+          rewriter.eraseOp(op);
+        }
+      }
+    }
+    return success();
+  }
+};
+
 class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
@@ -785,6 +972,13 @@ struct ConvertXeGPUToXeVMPass
       auto i32Type = IntegerType::get(&getContext(), 32);
       return VectorType::get(8, i32Type);
     });
+    // Convert MemDescType into flattened MemRefType for SLM
+    typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
+      Type elemTy = type.getElementType();
+      int numElems = type.getNumElements();
+      return MemRefType::get(numElems, elemTy, AffineMap(), 3);
+    });
+
     typeConverter.addConversion([&](MemRefType type) -> Type {
       // Convert MemRefType to i64 type.
       return IntegerType::get(&getContext(), 64);
@@ -919,6 +1113,10 @@ void mlir::populateXeGPUToXeVMConversionPatterns(
                LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
                LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
       typeConverter, patterns.getContext());
+  patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
+               LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
+               CreateMemDescOpPattern, MemDescSubviewOpPattern>(
+      typeConverter, patterns.getContext());
   patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
                                                       patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 94c5509fd7c29..cccc8fab4adbc 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -726,6 +726,152 @@ void MemLayoutAttr::print(AsmPrinter &printer) const {
   }
   printer << ">";
 }
+// a helper utility to perform binary operation on OpFoldResult.
+// If both a and b are attributes, it will simply return the result.
+// Otherwise, the corresponding arith op will be generated, and an
+// contant op will be created if one of them is an attribute.
+template <typename ArithOp>
+OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
+                      OpBuilder &builder) {
+  auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
+  auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
+  return builder.create<ArithOp>(loc, aVal, bVal).getResult();
+}
+
+// a helper utility to perform division operation on OpFoldResult and int64_t.
+#define div(a, b)                                                              \
+  genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// a helper utility to perform reminder operation on OpFoldResult and int64_t.
+#define rem(a, b)                                                              \
+  genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// a helper utility to perform multiply operation on OpFoldResult and int64_t.
+#define mul(a, b)                                                              \
+  genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// a helper utility to perform addition operation on two OpFoldResult.
+#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
+
+// block the given offsets according to the block shape
+// say the original offset is [y, x], and the block shape is [By, Bx],
+// then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
+SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
+                                            ArrayRef<OpFoldResult> offsets,
+                                            ArrayRef<int64_t> blockShape) {
+
+  assert(offsets.size() == blockShape.size() &&
+         "offsets and blockShape must have the same size");
+  SmallVector<OpFoldResult> blockedOffsets;
+  SmallVector<OpFoldResult> divs, rems;
+
+  for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+    divs.push_back(div(offset, block));
+    rems.push_back(...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/162780


More information about the Mlir-commits mailing list