Skip to content

Commit f493e48

Browse files
committed
Add F4E2M1FN type: import mxfloat.h
1 parent 25c8c9f commit f493e48

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

third_party/tsl/tsl/platform/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,6 +1060,7 @@ cc_library(
10601060
deps = [
10611061
"@ml_dtypes//:float8",
10621062
"@ml_dtypes//:intn",
1063+
"@ml_dtypes//:mxfloat",
10631064
],
10641065
)
10651066

third_party/tsl/tsl/platform/ml_dtypes.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@ limitations under the License.
1616
#ifndef TENSORFLOW_TSL_PLATFORM_ML_DTYPES_H_
1717
#define TENSORFLOW_TSL_PLATFORM_ML_DTYPES_H_
1818

19-
#include "ml_dtypes/include/float8.h" // from @ml_dtypes
20-
#include "ml_dtypes/include/intn.h" // from @ml_dtypes
19+
#include "ml_dtypes/include/float8.h" // from @ml_dtypes
20+
#include "ml_dtypes/include/intn.h" // from @ml_dtypes
21+
#include "ml_dtypes/include/mxfloat.h" // from @ml_dtypes
2122

2223
namespace tsl {
24+
using float4_e2m1fn = ::ml_dtypes::float4_e2m1fn;
2325
using float8_e3m4 = ::ml_dtypes::float8_e3m4;
2426
using float8_e4m3 = ::ml_dtypes::float8_e4m3;
2527
using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn;

0 commit comments

Comments
 (0)