[Mlir-commits] [mlir] [mlir][amx] Simplify intrinsic generation (PR #140559)

Tobias Gysi llvmlistbot at llvm.org
Mon May 19 11:53:46 PDT 2025


================
@@ -60,24 +64,168 @@ static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
   return success();
 }
 
+/// Get pointer to a memref descriptor.
+/// Optionally, the base pointer can be offset using a linearized index
+/// computed from the given indices.
+static Value getBufferPtr(Location loc, MemRefType type, Value buffer,
+                          ValueRange indices,
+                          const LLVMTypeConverter &typeConverter,
+                          RewriterBase &rewriter) {
+  auto [strides, offset] = type.getStridesAndOffset();
+
+  MemRefDescriptor memRefDescriptor(buffer);
+  Value base = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
+
+  int numIndices = indices.size();
+  if (numIndices == 0)
+    return base;
+
+  assert(type.getRank() == numIndices &&
+         "expects number of indices equal to memref rank");
+  Value index;
+  Type indexType = typeConverter.getIndexType();
+  for (int i = 0; i < numIndices; ++i) {
+    Value increment = indices[i];
+    if (strides[i] != 1) { // Skip if stride is 1.
+      Value stride =
+          ShapedType::isDynamic(strides[i])
+              ? memRefDescriptor.stride(rewriter, loc, i)
+              : rewriter.create<LLVM::ConstantOp>(
+                    loc, indexType, rewriter.getIndexAttr(strides[i]));
+      increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
+    }
+    index =
+        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
+  }
+
+  Type elementPtrType = memRefDescriptor.getElementPtrType();
+  return rewriter.create<LLVM::GEPOp>(
+      loc, elementPtrType, typeConverter.convertType(type.getElementType()),
+      base, index);
+}
+
+/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
+/// dimension directly translates into the number of rows of the tiles.
+/// The second dimension needs to be scaled by the number of bytes.
+static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
+                                       RewriterBase &rewriter) {
+  Type llvmInt16Type = rewriter.getIntegerType(16);
+  unsigned width = tType.getElementType().getIntOrFloatBitWidth();
+  assert(llvm::isPowerOf2_64(width) && width >= 8);
+  unsigned bytes = width >> 3;
+  auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
+  auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
+  return SmallVector<Value>{
+      rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
+      rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr)};
+}
+
+/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
+/// shape may "envelop" the actual tile shape, and may be dynamically sized.
+/// Returns failure if proper stride couldn't be found.
----------------
gysit wrote:

It seems returning failure has been replaced by an assertion?

https://github.com/llvm/llvm-project/pull/140559


More information about the Mlir-commits mailing list