[Mlir-commits] [mlir] 91594b5 - [mlir][nvgpu] Prevent F32ToTF32 pattern from generating illegal IR

Thomas Raoux llvmlistbot at llvm.org
Mon Aug 15 09:46:29 PDT 2022


Author: Thomas Raoux
Date: 2022-08-15T16:46:18Z
New Revision: 91594b5b985cf144f9cde64cddc534684d644665

URL: https://github.com/llvm/llvm-project/commit/91594b5b985cf144f9cde64cddc534684d644665
DIFF: https://github.com/llvm/llvm-project/commit/91594b5b985cf144f9cde64cddc534684d644665.diff

LOG: [mlir][nvgpu] Prevent F32ToTF32 pattern from generating illegal IR

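We shouldn't apply this pattern to mma.sync operations that are not F32->F32 (for example, f16 mma.sync); the pattern now bails out unless the element type of matrix A is f32.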

Differential Revision: https://reviews.llvm.org/D131902

Added: 
    

Modified: 
    mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
    mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
index 4ef93b30978a4..d24001c4d28ae 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
@@ -42,7 +42,8 @@ struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
                                 PatternRewriter &rewrite) const override {
     Location location = op->getLoc();
 
-    if (op->hasAttr(op.getTf32EnabledAttrName()))
+    if (op->hasAttr(op.getTf32EnabledAttrName()) ||
+        !op.getMatrixA().getType().cast<VectorType>().getElementType().isF32())
       return failure();
 
     if (precision == MmaSyncF32Lowering::Unkown)

diff  --git a/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir b/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
index a8c72262f101b..80de11fad711b 100644
--- a/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
+++ b/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
@@ -18,3 +18,12 @@ func.func @m16n8k8_tf32(%arg0: vector<4x1xf32>, %arg1: vector<2x1xf32>, %arg2: v
   return %d : vector<2x2xf32>
 }
 // -----
+
+// Negative test for non f32 case.
+// CHECK-LABEL: mma_sync_f16
+//   CHECK-NOT: tf32Enabled
+//       CHECK: return
+func.func @mma_sync_f16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}


        


More information about the Mlir-commits mailing list