[Mlir-commits] [mlir] 6b82fc7 - Fix Transpose Check in MMA.SYNC Path
Manish Gupta
llvmlistbot at llvm.org
Mon Apr 10 17:06:09 PDT 2023
Author: Manish Gupta
Date: 2023-04-11T00:02:20Z
New Revision: 6b82fc77c248c9e8de29e4221ee6d99be92148c3
URL: https://github.com/llvm/llvm-project/commit/6b82fc77c248c9e8de29e4221ee6d99be92148c3
DIFF: https://github.com/llvm/llvm-project/commit/6b82fc77c248c9e8de29e4221ee6d99be92148c3.diff
LOG: Fix Transpose Check in MMA.SYNC Path
Differential Revision: https://reviews.llvm.org/D147749
Added:
Modified:
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 19860a197d771..10a6ee43a8f98 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -650,6 +650,34 @@ convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
return success();
}
+/// Check if the loaded matrix operand requires transposing.
+/// The permutation map must have exactly two results, each a plain dimension
+/// expression, e.g. (d0, d1, d2, d3) -> (d3, d2). Generalized form:
+///   (d0, d1, ..., dn, ..., dm, ...) -> (dm, dn)
+/// Returns true if m > n (the two output dims are swapped), false otherwise.
+/// Maps of any other shape (wrong rank, constant results, or compound
+/// expressions such as `d0 * 2`) are diagnosed and reported as non-transposed
+/// instead of reading a missing result or querying a null expression.
+static bool isTransposed(vector::TransferReadOp op) {
+  mlir::AffineMap map = op.getPermutationMap();
+  if (map.getNumResults() != 2) {
+    op->emitError("Expected 2D transfer read");
+    return false;
+  }
+
+  // Output 2D matrix dimensions in the order of d0, d1.
+  auto exprM = map.getResult(0).dyn_cast<AffineDimExpr>();
+  auto exprN = map.getResult(1).dyn_cast<AffineDimExpr>();
+  if (!exprM || !exprN) {
+    // A null AffineDimExpr must not be queried for its position below.
+    op->emitError("Expected to find AffineDimExpr in vector::TransferReadOp");
+    return false;
+  }
+
+  // Transposed iff the first output dim comes later in the input dims.
+  return exprM.getPosition() > exprN.getPosition();
+}
+
static LogicalResult
creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
@@ -671,9 +699,10 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
return rewriter.notifyMatchFailure(op, "not mma sync reg info");
}
- FailureOr<nvgpu::LdMatrixParams> params = nvgpu::getLdMatrixParams(
- *warpMatrixInfo,
- /*transpose=*/!op.getPermutationMap().isMinorIdentity());
+ FailureOr<nvgpu::LdMatrixParams> params =
+ nvgpu::getLdMatrixParams(*warpMatrixInfo,
+ /*transpose=*/isTransposed(op));
+
if (failed(params)) {
LLVM_DEBUG(
DBGS()
@@ -700,7 +729,7 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
indices);
nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
loc, vectorType, op.getSource(), indices,
- !op.getPermutationMap().isMinorIdentity(), params->numTiles);
+ /*transpose=*/isTransposed(op), params->numTiles);
valueMapping[op] = newOp->getResult(0);
return success();
}
More information about the Mlir-commits
mailing list