[Mlir-commits] [mlir] [mlir][xegpu] Improve XeGPU op verification logic for SIMT flavor and update tests. (PR #127920)

Charitha Saumya llvmlistbot at llvm.org
Fri Feb 21 09:50:18 PST 2025


================
@@ -73,41 +77,37 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
          kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
 }
 
-// Validations for nd instruction arguments is successful if any of these are
-// true:
-// - tensor descriptor and the output vector shapes exactly match.
-// - tensor descriptor has a sg_map attribute and the distributed vector shape
-//   matches the tensor descriptor shape when scaled using sg_map factors on
-//   each dimension.
-static bool isArgShapesValid(ArrayRef<int64_t> descShape,
-                             ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
-  // Equal shapes with no distribution - no further verification needed.
-  if (descShape == valShape && !sgMap)
-    return true;
-
-  // Unknown distribution - cannot perform operation on partial shape.
+// Helper to validate value shape of LoadNd and StoreNd ops.
+static LogicalResult
+isArgShapesValid(TensorDescType tdescTy, VectorType valueTy,
+                 ArrayRef<int64_t> adjustedTdescShape,
+                 function_ref<InFlightDiagnostic()> emitError) {
+  auto sgMap = tdescTy.getSGMapAttr();
+  auto valueShape = valueTy.getShape();
+  // sg_map not present means IR is in SIMD mode. In this case value shape must
+  // match adjusted tensor descriptor shape.
   if (!sgMap)
-    return false;
-
-  // Invalid rank or mixed rank usage.
-  size_t descRank = descShape.size();
-  if (descRank > 2 || valShape.size() != descRank)
-    return false;
-
-  // For 1D, SG map is guaranteed to be unit size in the outer dimension.
-  // Only take the distribution over the innermost dimension for validation.
-  ArrayRef<uint32_t> wiLayout = sgMap.getWiLayout();
-  SmallVector<uint32_t> mapLayout(wiLayout.begin(), wiLayout.end());
-  if (descRank == 1)
-    mapLayout = {wiLayout.back()};
-
-  for (const auto &[factor, dim, expected] :
-       llvm::zip_equal(mapLayout, valShape, descShape)) {
-    if (factor * dim != expected)
-      return false;
-  }
-
-  return true;
+    return valueShape == adjustedTdescShape
+               ? success()
+               : emitError()
+                     << "Value shape " << makeString(valueShape)
+                     << " is not consistent with tensor descriptor " << tdescTy;
+
+  // sg_map present means IR is in SIMT mode. In this case sg_map determines the
+  // value shape.
+  auto expectedValueShapeOrFailure = tdescTy.getDistributedVectorType();
+  if (failed(expectedValueShapeOrFailure))
+    return emitError() << "Failed to compute distributed vector shape for "
----------------
charithaintc wrote:

With the current usage it is not possible to trigger this error: we always check whether sgMap is present before calling this helper. So I converted it to an assert.

https://github.com/llvm/llvm-project/pull/127920


More information about the Mlir-commits mailing list