[Mlir-commits] [mlir] 6b82fc7 - Fix Transpose Check in MMA.SYNC Path

Manish Gupta llvmlistbot at llvm.org
Mon Apr 10 17:06:09 PDT 2023


Author: Manish Gupta
Date: 2023-04-11T00:02:20Z
New Revision: 6b82fc77c248c9e8de29e4221ee6d99be92148c3

URL: https://github.com/llvm/llvm-project/commit/6b82fc77c248c9e8de29e4221ee6d99be92148c3
DIFF: https://github.com/llvm/llvm-project/commit/6b82fc77c248c9e8de29e4221ee6d99be92148c3.diff

LOG: Fix Transpose Check in MMA.SYNC Path

Differential Revision: https://reviews.llvm.org/D147749

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 19860a197d771..10a6ee43a8f98 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -650,6 +650,34 @@ convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
   return success();
 }
 
+/// Returns true if the 2D matrix operand loaded by `op` must be transposed.
+///
+/// Transposed map examples:
+///   Example 1 : (..., d0, d1) -> (d1, d0)
+///   Example 2 : (d0, d1, d2, d3) -> (d3, d2)
+///
+/// Generally   : (..., dn, ..., dm, ...) -> (dm, dn) is transposed iff m > n.
+///
+/// The permutation map must have exactly two results and both must be plain
+/// dimension expressions; otherwise an error is emitted on `op` and false is
+/// returned so that a null expression is never dereferenced.
+static bool isTransposed(vector::TransferReadOp op) {
+  mlir::AffineMap map = op.getPermutationMap();
+  if (map.getNumResults() != 2) {
+    op->emitError("Expected 2D transfer read");
+    return false;
+  }
+
+  // Positions of the two result dimensions among the map's inputs.
+  auto exprM = map.getResult(0).dyn_cast<AffineDimExpr>();
+  auto exprN = map.getResult(1).dyn_cast<AffineDimExpr>();
+  if (!exprM || !exprN) {
+    op->emitError("Expected to find AffineDimExpr in vector::TransferReadOp");
+    return false;
+  }
+  return exprM.getPosition() > exprN.getPosition();
+}
+
 static LogicalResult
 creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
@@ -671,9 +699,10 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
     return rewriter.notifyMatchFailure(op, "not mma sync reg info");
   }
 
-  FailureOr<nvgpu::LdMatrixParams> params = nvgpu::getLdMatrixParams(
-      *warpMatrixInfo,
-      /*transpose=*/!op.getPermutationMap().isMinorIdentity());
+  FailureOr<nvgpu::LdMatrixParams> params =
+      nvgpu::getLdMatrixParams(*warpMatrixInfo,
+                               /*transpose=*/isTransposed(op));
+
   if (failed(params)) {
     LLVM_DEBUG(
         DBGS()
@@ -700,7 +729,7 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                                          indices);
   nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
       loc, vectorType, op.getSource(), indices,
-      !op.getPermutationMap().isMinorIdentity(), params->numTiles);
+      /*transpose=*/isTransposed(op), params->numTiles);
   valueMapping[op] = newOp->getResult(0);
   return success();
 }


        


More information about the Mlir-commits mailing list