[Mlir-commits] [mlir] [mlir][xegpu] Improve XeGPU op verification logic for SIMT flavor and update tests. (PR #127920)
Adam Siemieniuk
llvmlistbot at llvm.org
Fri Feb 21 04:06:09 PST 2025
================
@@ -73,41 +77,37 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
}
-// Validations for nd instruction arguments is successful if any of these are
-// true:
-// - tensor descriptor and the output vector shapes exactly match.
-// - tensor descriptor has a sg_map attribute and the distributed vector shape
-// matches the tensor descriptor shape when scaled using sg_map factors on
-// each dimension.
-static bool isArgShapesValid(ArrayRef<int64_t> descShape,
- ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
- // Equal shapes with no distribution - no further verification needed.
- if (descShape == valShape && !sgMap)
- return true;
-
- // Unknown distribution - cannot perform operation on partial shape.
+// Helper to validate value shape of LoadNd and StoreNd ops.
+static LogicalResult
+isArgShapesValid(TensorDescType tdescTy, VectorType valueTy,
+ ArrayRef<int64_t> adjustedTdescShape,
+ function_ref<InFlightDiagnostic()> emitError) {
+ auto sgMap = tdescTy.getSGMapAttr();
+ auto valueShape = valueTy.getShape();
+ // sg_map not present means IR is in SIMD mode. In this case value shape must
+ // match adjusted tensor descriptor shape.
if (!sgMap)
- return false;
-
- // Invalid rank or mixed rank usage.
- size_t descRank = descShape.size();
- if (descRank > 2 || valShape.size() != descRank)
- return false;
-
- // For 1D, SG map is guaranteed to be unit size in the outer dimension.
- // Only take the distribution over the innermost dimension for validation.
- ArrayRef<uint32_t> wiLayout = sgMap.getWiLayout();
- SmallVector<uint32_t> mapLayout(wiLayout.begin(), wiLayout.end());
- if (descRank == 1)
- mapLayout = {wiLayout.back()};
-
- for (const auto &[factor, dim, expected] :
- llvm::zip_equal(mapLayout, valShape, descShape)) {
- if (factor * dim != expected)
- return false;
- }
-
- return true;
+ return valueShape == adjustedTdescShape
+ ? success()
+ : emitError()
+ << "Value shape " << makeString(valueShape)
+ << " is not consistent with tensor descriptor " << tdescTy;
+
+ // sg_map present means IR is in SIMT mode. In this case sg_map determines the
+ // value shape.
+ auto expectedValueShapeOrFailure = tdescTy.getDistributedVectorType();
+ if (failed(expectedValueShapeOrFailure))
+ return emitError() << "Failed to compute distributed vector shape for "
----------------
adam-smnk wrote:
Just a reminder to add a test case for this too.
https://github.com/llvm/llvm-project/pull/127920
More information about the Mlir-commits mailing list