[Mlir-commits] [mlir] [MLIR][XeGPU] Switch to 1D representation for SIMT code (PR #135116)

Adam Siemieniuk llvmlistbot at llvm.org
Thu Apr 17 10:26:24 PDT 2025


================
@@ -602,51 +633,20 @@ LogicalResult DpasOp::verify() {
   auto rhsShape = getRhsType().getShape();
   auto resShape = getResultType().getShape();
 
-  auto aLayout = getALayoutAttr();
-  auto bLayout = getBLayoutAttr();
-  auto cLayout = getCLayoutAttr();
-
-  // make sure the layout attribute is either set for every available
-  // operand or simply not set at all. C is special, since ACC is optional.
-  auto hasValidLayoutAttrs = [&]() {
-    bool result = (aLayout != nullptr) ^ (bLayout != nullptr);
-    if (hasAcc()) {
-      result |= (aLayout != nullptr) ^ (cLayout != nullptr);
-    }
-    return !result;
-  };
-
-  if (!hasValidLayoutAttrs())
-    return emitOpError(
-        "layout attributes should be either set for all operands (for SIMT "
-        "code) or not set at all (for SIMD code).");
-
-  // query the scope from aLayout (a valid setting).
-  if (aLayout) {
-    // In SIMT mode, All data fragments must be 2D
-    if (lhsRank != 2 || rhsRank != 2 || resRank != 2)
-      return emitOpError("expecting lhs, rhs, and result to be a 2D vector.");
-
-    auto laneLayoutA = aLayout.getLaneLayout();
-    auto laneLayoutB = bLayout.getLaneLayout();
-    auto laneLayoutC = cLayout.getLaneLayout();
-    // Obtain the expanded shapes of the operands and result using lane_layout.
-    // NOTE: For B, get rid of the packed dimension for the expanded shape.
-    SmallVector<int64_t> expandedShapeA = {lhsShape[0] * laneLayoutA[0],
-                                           lhsShape[1] * laneLayoutA[1]};
-    SmallVector<int64_t> expandedShapeB = {
-        rhsShape[0] * rhsShape[1] * laneLayoutB[0], 1 * laneLayoutB[1]};
-    SmallVector<int64_t> expandedShapeC = {resShape[0] * laneLayoutC[0],
-                                           resShape[1] * laneLayoutC[1]};
-    auto bK = expandedShapeB[0];
-    if (bK != expandedShapeA[1])
-      return emitOpError("K-dimension mismatch.");
-    if (expandedShapeA[0] != expandedShapeC[0])
-      return emitOpError("M-dimension mismatch.");
-    if (expandedShapeB[1] != expandedShapeC[1])
-      return emitOpError("N-dimension mismatch.");
-  } else { // For other scopes, operands' shape should match the mxkxn
-           // semantics.
+  if (getAcc() && getAcc().getType() != getResultType())
+    return emitOpError("Expecting the acc type to be the same as result.");
+
+  // SIMT code: the size of the B operand has to be a multiple of 32 bits.
+  // It skips the semantic check since lack of architecture information.
+  // Users need to ensure the correctness.
+  if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
+    auto numElems = getRhsType().getNumElements();
+    auto elemTy = getRhsType().getElementType();
+    auto factor = 32 / elemTy.getIntOrFloatBitWidth();
+    if (numElems % factor != 0)
+      return emitOpError("Expecting B operand to be a multiple of 32 bits.");
+    return success();
+  } else { // SIMD code
----------------
adam-smnk wrote:

nit: you can drop the `else` here, since the preceding `if` branch always returns.

https://github.com/llvm/llvm-project/pull/135116


More information about the Mlir-commits mailing list