[Mlir-commits] [mlir] [mlir][SPIRV] Add support for dense_resource in arith to spirv (PR #91318)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 17 01:28:06 PDT 2024


https://github.com/maxbartel updated https://github.com/llvm/llvm-project/pull/91318

>From 1e242dbb691c8585f309ae06f17bf7301831fa7c Mon Sep 17 00:00:00 2001
From: Maximilian Bartel <bartel at roofline.ai>
Date: Tue, 7 May 2024 12:18:26 +0200
Subject: [PATCH 1/2] [mlir][SPIRV] Add support for dense_resource in arith to
 spirv

---
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  | 22 +++++++++--
 .../arith-to-spirv-le-specific.mlir           | 38 +++++++++++++++++++
 2 files changed, 56 insertions(+), 4 deletions(-)
 create mode 100644 mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-le-specific.mlir

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 806981728561c..9a1808354e161 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
@@ -229,16 +230,29 @@ struct ConstantCompositeOpPattern final
     if (!srcType || srcType.getNumElements() == 1)
       return failure();
 
-    // arith.constant should only have vector or tenor types.
+    // arith.constant should only have vector or tensor types.
     assert((isa<VectorType, RankedTensorType>(srcType)));
 
     Type dstType = getTypeConverter()->convertType(srcType);
     if (!dstType)
       return failure();
 
-    auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
-    if (!dstElementsAttr)
-      return failure();
+    // Import the resource into the IR to make use of the special handling of
+    // element types later on.
+    mlir::DenseElementsAttr dstElementsAttr;
+    if (auto denseElementsAttr =
+            dyn_cast<DenseElementsAttr>(constOp.getValue())) {
+      dstElementsAttr = denseElementsAttr;
+    } else if (auto resourceAttr =
+                   dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
+
+      ArrayRef<char> ptr = resourceAttr.getRawHandle().getBlob()->getData();
+      dstElementsAttr =
+          DenseElementsAttr::getFromRawBuffer(resourceAttr.getType(), ptr);
+    } else {
+      return rewriter.notifyMatchFailure(constOp,
+                                         "Could not decode ElementsAttr");
+    }
 
     ShapedType dstAttrType = dstElementsAttr.getType();
 
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-le-specific.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-le-specific.mlir
new file mode 100644
index 0000000000000..7233a8bfffa9d
--- /dev/null
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-le-specific.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt -split-input-file -convert-arith-to-spirv -verify-diagnostics %s | FileCheck %s
+
+
+//===----------------------------------------------------------------------===//
+// arith.constant dense_resource
+//
+// The decoding of dense_resource differs between little and big endian
+// machines. At the moment only little endian is supported.
+// See https://github.com/llvm/llvm-project/issues/63469 for more information.
+//
+//===----------------------------------------------------------------------===//
+
+// XFAIL: target=s390x-{{.*}}
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64], []>, #spirv.resource_limits<>>
+} {
+func.func @constant_dense_resource() {
+  // CHECK:    %{{.*}} = spirv.Constant dense<[0.203224242, -0.254296064, -0.365104556, -0.469196141, 0.466041982]> : tensor<5xf32> : !spirv.array<5 x f32>
+  %0 = arith.constant dense_resource<dense_resource_test_5xf32> : tensor<5xf32>  
+  // CHECK:    %{{.*}} = spirv.Constant dense<[1, 2]> : vector<2xi32>
+  %1 = arith.constant dense_resource<dense_resource_test_2xi32> : vector<2xi32>  
+  // CHECK:    %{{.*}} = spirv.Constant dense<[0.35476172, 0.351080596, -0.0795008316, 0.366843373]> : tensor<4xf32> : !spirv.array<4 x f32>
+  %2 = arith.constant dense_resource<dense_resource_test_2x2xf32> : tensor<1x2x2xf32>  
+  return
+  }
+}
+// Resources are kept at end of file. New tests should be added above this.
+{-#
+  dialect_resources: {
+    builtin: {
+      dense_resource_test_2xi32: "0x400000000100000002000000",
+      dense_resource_test_5xf32: "0x08000000041A503E183382BEFCEEBABE7A3AF0BE0E9DEE3E",
+      dense_resource_test_2x2xf32: "0x0800000054A3B53ED6C0B33E55D1A2BDE5D2BB3E"
+    }
+  }
+#-}

>From fdbadbe5bd64b280d08254e3d22bae8a9fa00445 Mon Sep 17 00:00:00 2001
From: Maximilian Bartel <bartel at roofline.ai>
Date: Wed, 15 May 2024 18:22:15 +0200
Subject: [PATCH 2/2] refactor: check for buffer validity and emit errors
 instead of match failures

---
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  | 22 +++++++++++----
 .../arith-to-spirv-unsupported.mlir           | 28 +++++++++++++++++++
 2 files changed, 45 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9a1808354e161..4c3237b24b786 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -230,8 +230,10 @@ struct ConstantCompositeOpPattern final
     if (!srcType || srcType.getNumElements() == 1)
       return failure();
 
-    // arith.constant should only have vector or tensor types.
-    assert((isa<VectorType, RankedTensorType>(srcType)));
+    // arith.constant should only have vector or tensor types. This is an
+    // MLIR-wide problem at the moment.
+    if (!isa<VectorType, RankedTensorType>(srcType))
+      return rewriter.notifyMatchFailure(constOp, "unsupported ShapedType");
 
     Type dstType = getTypeConverter()->convertType(srcType);
     if (!dstType)
@@ -246,12 +248,22 @@ struct ConstantCompositeOpPattern final
     } else if (auto resourceAttr =
                    dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
 
-      ArrayRef<char> ptr = resourceAttr.getRawHandle().getBlob()->getData();
+      AsmResourceBlob *blob = resourceAttr.getRawHandle().getBlob();
+      if (!blob)
+        return constOp->emitError("could not find resource blob");
+
+      ArrayRef<char> ptr = blob->getData();
+
+      // Check that the buffer meets the requirements to get converted to a
+      // DenseElementsAttr
+      bool detectedSplat = false;
+      if (!DenseElementsAttr::isValidRawBuffer(srcType, ptr, detectedSplat))
+        return constOp->emitError("resource is not a valid buffer");
+
       dstElementsAttr =
           DenseElementsAttr::getFromRawBuffer(resourceAttr.getType(), ptr);
     } else {
-      return rewriter.notifyMatchFailure(constOp,
-                                         "Could not decode ElementsAttr");
+      return constOp->emitError("unsupported elements attribute");
     }
 
     ShapedType dstAttrType = dstElementsAttr.getType();
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 0d92a8e676d85..2512254b443db 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -95,6 +95,34 @@ func.func @unsupported_constant_tensor_2xf64_0() {
   return
 }
 
+// -----
+
+func.func @constant_dense_resource_non_existant() {
+  // expected-error @+2 {{failed to legalize operation 'arith.constant'}}
+  // expected-error @+1 {{could not find resource blob}}
+  %0 = arith.constant dense_resource<non_existant> : tensor<5xf32>  
+  return
+}
+
+// -----
+
+module {
+func.func @constant_dense_resource_invalid_buffer() {
+  // expected-error @+2 {{failed to legalize operation 'arith.constant'}}
+  // expected-error @+1 {{resource is not a valid buffer}}
+  %0 = arith.constant dense_resource<dense_resource_test_2xi32> : vector<2xi32>  
+  return
+  }
+}
+// This is a buffer of wrong type and shape
+{-#
+  dialect_resources: {
+    builtin: {
+      dense_resource_test_2xi32: "0x0800000054A3B53ED6C0B33E55D1A2BDE5D2BB3E"
+    }
+  }
+#-}
+
 ///===----------------------------------------------------------------------===//
 // Type emulation
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list