Skip to content

Adds: #510

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open

Adds: #510

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,44 @@ cc_library(
],
)

# C API wrapper library: compiles the C entry points (gemma/c_api.cc) together
# with the context implementation (gemma/context.cc).
cc_library(
name = "gemma_shared_lib",
srcs = [
"gemma/c_api.cc",
"gemma/context.cc",
],
# Headers exposed to targets that depend on this library.
# NOTE(review): activations.h and gemma.h have no matching source file in
# `srcs` here - confirm they are not already exported by another target.
hdrs = [
"gemma/activations.h",
"gemma/c_api.h",
"gemma/context.h",
"gemma/gemma.h",
],
exec_properties = {
# Avoid linker OOMs when building with sanitizer instrumentation.
"mem": "28g",
},
# In-repo targets first, then the external Highway targets.
deps = [
":allocator",
":app",
":basics",
":common",
":kv_cache",
":ops",
":threading",
":tokenizer",
":weights",
"//compression:compress",
"//compression:io",
"//compression:sfp",
"//paligemma:image",
"@highway//:bit_set",
"@highway//:hwy",
"@highway//:nanobenchmark", # timer
"@highway//:profiler",
"@highway//:thread_pool",
],
)

cc_library(
name = "cross_entropy",
srcs = ["evals/cross_entropy.cc"],
Expand Down
38 changes: 38 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,17 @@ set(SOURCES
util/threading.h
)

# The C API wrapper (context + c_api) is appended to SOURCES only for DLL
# builds; with BUILD_GEMMA_DLL off, SOURCES is left untouched.
if(BUILD_GEMMA_DLL)
  message(STATUS "Including C API files for DLL build")
  list(APPEND SOURCES
    gemma/context.h gemma/context.cc
    gemma/c_api.h gemma/c_api.cc)
endif()

# Default to a Release build when no build type was specified.
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release")
endif()
Expand All @@ -129,6 +140,33 @@ target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
install(TARGETS libgemma DESTINATION lib)

# Shared library target for C# interop.
#
# Built only when BUILD_GEMMA_DLL is enabled. The empty PREFIX together with
# OUTPUT_NAME "gemma" yields "gemma.<shared-library-suffix>" (e.g. gemma.dll),
# which is the name GemmaInterop.cs binds to via DllImport("gemma").
if(BUILD_GEMMA_DLL)
  add_library(gemma_shared SHARED ${SOURCES})

  set_target_properties(gemma_shared PROPERTIES
    CXX_STANDARD 17
    POSITION_INDEPENDENT_CODE ON
    PREFIX ""
    OUTPUT_NAME "gemma"
  )

  target_include_directories(gemma_shared PUBLIC
    ./
    ${sentencepiece_SOURCE_DIR}
  )

  # WHOLE_ARCHIVE folds the complete static dependencies into the shared
  # library instead of only the objects the linker considers referenced.
  target_link_libraries(gemma_shared PRIVATE
    $<LINK_LIBRARY:WHOLE_ARCHIVE,hwy>
    $<LINK_LIBRARY:WHOLE_ARCHIVE,hwy_contrib>
    $<LINK_LIBRARY:WHOLE_ARCHIVE,sentencepiece-static>
  )

  # GEMMA_EXPORTS switches GEMMA_API in gemma/c_api.h to dllexport on Windows.
  target_compile_definitions(gemma_shared
    PRIVATE
      GEMMA_EXPORTS
      $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>
  )
  target_compile_options(gemma_shared PRIVATE
    $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>
  )

  install(TARGETS gemma_shared DESTINATION lib)
  install(FILES gemma/c_api.h DESTINATION include/gemma)
  # GemmaInterop.cs is added at the repository root, not under gemma/;
  # pointing install() at gemma/GemmaInterop.cs makes the install step fail
  # because the file does not exist there.
  install(FILES GemmaInterop.cs DESTINATION include/gemma)
endif()

# Executable Target

add_executable(gemma gemma/run.cc)
Expand Down
175 changes: 175 additions & 0 deletions GemmaInterop.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
using System;
using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Text;
namespace GemmaCpp
{
    /// <summary>
    /// Error raised by the managed wrapper when the native Gemma library
    /// reports a failure (context creation or generation).
    /// </summary>
    public class GemmaException : Exception
    {
        public GemmaException() { }

        public GemmaException(string message) : base(message) { }

        public GemmaException(string message, Exception innerException)
            : base(message, innerException) { }
    }

    /// <summary>
    /// Managed wrapper around the native Gemma C API (GemmaCreate,
    /// GemmaGenerate, GemmaCountTokens, GemmaSetLogCallback, GemmaDestroy).
    /// Owns one native context and releases it on Dispose/finalization.
    /// </summary>
    public class Gemma : IDisposable
    {
        private IntPtr _context;
        private bool _disposed;

        /// <summary>
        /// Optional path of the native library to pre-load on Windows.
        /// Must be set before the first <see cref="Gemma"/> instance is
        /// created; the library is loaded lazily by the constructor.
        /// </summary>
        public static string DllPath { get; set; } = "gemma.dll";

        private static readonly object _loadLock = new object();
        private static bool _libraryLoaded;

        [DllImport("kernel32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
        private static extern IntPtr LoadLibrary(string lpFileName);

        /// <summary>
        /// Pre-loads the native library from <see cref="DllPath"/> once.
        /// This used to live in a static constructor, which ran before any
        /// assignment to DllPath could take effect and unconditionally called
        /// into kernel32.dll, breaking the type on non-Windows platforms.
        /// </summary>
        /// <exception cref="DllNotFoundException">The library could not be loaded.</exception>
        private static void EnsureLibraryLoaded()
        {
            lock (_loadLock)
            {
                if (_libraryLoaded)
                    return;

                // LoadLibrary only exists on Windows; elsewhere the runtime's
                // regular DllImport("gemma") probing is used.
                if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
                {
                    if (LoadLibrary(DllPath) == IntPtr.Zero)
                    {
                        // Read the error code immediately, before any other call can overwrite it.
                        int error = Marshal.GetLastWin32Error();
                        throw new DllNotFoundException($"Failed to load {DllPath}. Error: {error}");
                    }
                }

                _libraryLoaded = true;
            }
        }

        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
        private static extern IntPtr GemmaCreate(
            [MarshalAs(UnmanagedType.LPUTF8Str)] string tokenizerPath,
            [MarshalAs(UnmanagedType.LPUTF8Str)] string modelType,
            [MarshalAs(UnmanagedType.LPUTF8Str)] string weightsPath,
            [MarshalAs(UnmanagedType.LPUTF8Str)] string weightType,
            int maxLength);

        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
        private static extern void GemmaDestroy(IntPtr context);

        /// <summary>
        /// Callback invoked with each piece of generated text. The return
        /// value is forwarded unchanged to the native library.
        /// </summary>
        public delegate bool TokenCallback(string token);

        // Native signature: bool (*)(const char* text, void* user_data).
        // A C/C++ `bool` is one byte, whereas the default P/Invoke marshalling
        // for System.Boolean is a 4-byte Win32 BOOL, so the return value has
        // to be marshalled explicitly as I1.
        [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
        [return: MarshalAs(UnmanagedType.I1)]
        private delegate bool GemmaTokenCallback(
            [MarshalAs(UnmanagedType.LPUTF8Str)] string text,
            IntPtr userData);

        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
        private static extern int GemmaGenerate(
            IntPtr context,
            [MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
            [MarshalAs(UnmanagedType.LPUTF8Str)] StringBuilder output,
            int maxLength,
            GemmaTokenCallback callback,
            IntPtr userData);

        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
        private static extern int GemmaCountTokens(
            IntPtr context,
            [MarshalAs(UnmanagedType.LPUTF8Str)] string text);

        // Native callback delegate type
        [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
        private delegate void GemmaLogCallback(
            [MarshalAs(UnmanagedType.LPUTF8Str)] string message,
            IntPtr userData);

        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
        private static extern void GemmaSetLogCallback(
            IntPtr context,
            GemmaLogCallback callback,
            IntPtr userData);

        // Keeps the log delegate alive for the lifetime of the context
        // (only allocated when the optional logging below is enabled).
        private GCHandle _logCallbackHandle;

        /// <summary>Creates a native Gemma context.</summary>
        /// <param name="tokenizerPath">Forwarded to GemmaCreate as tokenizer_path.</param>
        /// <param name="modelType">Forwarded to GemmaCreate as model_type.</param>
        /// <param name="weightsPath">Forwarded to GemmaCreate as weights_path.</param>
        /// <param name="weightType">Forwarded to GemmaCreate as weight_type.</param>
        /// <param name="maxLength">Forwarded to GemmaCreate as max_length.</param>
        /// <exception cref="DllNotFoundException">The native library could not be loaded.</exception>
        /// <exception cref="GemmaException">The native library returned a null context.</exception>
        public Gemma(string tokenizerPath, string modelType, string weightsPath, string weightType, int maxLength = 8192)
        {
            EnsureLibraryLoaded();

            _context = GemmaCreate(tokenizerPath, modelType, weightsPath, weightType, maxLength);
            if (_context == IntPtr.Zero)
            {
                throw new GemmaException("Failed to create Gemma context");
            }

            // optionally: set up logging
            /*
            GemmaLogCallback logCallback = (message, _) =>
            {
#if UNITY_ENGINE
                Debug.Log($"Gemma: {message}");
#else
                Debug.WriteLine($"Gemma: {message}");
#endif
            };
            _logCallbackHandle = GCHandle.Alloc(logCallback);
            GemmaSetLogCallback(_context, logCallback, IntPtr.Zero);
            */
        }

        /// <summary>
        /// Returns the value reported by the native GemmaCountTokens for
        /// <paramref name="prompt"/>; the native side returns -1 when it is
        /// given a null context or null text.
        /// </summary>
        /// <exception cref="ObjectDisposedException">The instance was disposed.</exception>
        /// <exception cref="GemmaException">The native context is invalid.</exception>
        public int CountTokens(string prompt)
        {
            ThrowIfUnusable();
            return GemmaCountTokens(_context, prompt);
        }

        /// <summary>Generates text for <paramref name="prompt"/> without streaming.</summary>
        public string Generate(string prompt, int maxLength = 4096)
        {
            return Generate(prompt, null, maxLength);
        }

        /// <summary>
        /// Generates text for <paramref name="prompt"/>, optionally streaming
        /// each piece of text to <paramref name="callback"/>.
        /// </summary>
        /// <param name="prompt">Prompt text; must not be null.</param>
        /// <param name="callback">Optional per-token callback; may be null.</param>
        /// <param name="maxLength">Capacity of the output buffer passed to the native call.</param>
        /// <returns>The content of the output buffer filled by the native call.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="prompt"/> is null.</exception>
        /// <exception cref="ObjectDisposedException">The instance was disposed.</exception>
        /// <exception cref="GemmaException">The context is invalid or the native call returned a negative value.</exception>
        public string Generate(string prompt, TokenCallback callback, int maxLength = 4096)
        {
            ThrowIfUnusable();

            // A null string is marshalled as a NULL pointer, which the native
            // GemmaGenerate entry point does not check for.
            if (prompt == null)
                throw new ArgumentNullException(nameof(prompt));

            var output = new StringBuilder(maxLength);
            GemmaTokenCallback nativeCallback = null;

            // The handle is local so overlapping or reentrant Generate calls
            // cannot overwrite or free each other's handle.
            GCHandle callbackHandle = default(GCHandle);

            if (callback != null)
            {
                nativeCallback = (text, _) => callback(text);
                callbackHandle = GCHandle.Alloc(nativeCallback);
            }

            try
            {
                int length = GemmaGenerate(_context, prompt, output, maxLength,
                                           nativeCallback, IntPtr.Zero);

                if (length < 0)
                    throw new GemmaException("Generation failed");

                return output.ToString();
            }
            finally
            {
                if (callbackHandle.IsAllocated)
                    callbackHandle.Free();
            }
        }

        // Shared precondition check for every call that needs the native context.
        private void ThrowIfUnusable()
        {
            if (_disposed)
                throw new ObjectDisposedException(nameof(Gemma));

            if (_context == IntPtr.Zero)
                throw new GemmaException("Gemma context is invalid");
        }

        /// <summary>Destroys the native context. Safe to call more than once.</summary>
        public void Dispose()
        {
            Dispose(true);
            // The native context is already released; skip the finalizer.
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Releases the native context and the log-callback handle.
        /// </summary>
        /// <param name="disposing">
        /// True when called from <see cref="Dispose()"/>, false from the
        /// finalizer. Only unmanaged state is touched, so both paths do the
        /// same work.
        /// </param>
        protected virtual void Dispose(bool disposing)
        {
            if (_disposed)
                return;

            if (_context != IntPtr.Zero)
            {
                GemmaDestroy(_context);
                _context = IntPtr.Zero;
            }
            if (_logCallbackHandle.IsAllocated)
                _logCallbackHandle.Free();
            _disposed = true;
        }

        ~Gemma()
        {
            Dispose(false);
        }
    }
}
54 changes: 54 additions & 0 deletions gemma/c_api.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#ifndef GEMMA_EXPORTS
#define GEMMA_EXPORTS
#endif

#include "gemma/c_api.h"

// necessary as the C API and GemmaContext effectively wrap up and re-use the
// code for the Gemma executable
#include "util/app.h"

extern "C" {

// Creates a GemmaContext from the given tokenizer/weights description.
//
// All four strings must be non-null. Returns a context that must be released
// with GemmaDestroy, or nullptr if an argument is null or construction throws.
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
                                    const char* model_type,
                                    const char* weights_path,
                                    const char* weight_type, int max_length) {
  // Callers such as the C# wrapper marshal a null string as NULL; reject it
  // here instead of forwarding it into the C++ constructor. This mirrors the
  // argument check already done by GemmaCountTokens.
  if (!tokenizer_path || !model_type || !weights_path || !weight_type) {
    return nullptr;
  }

  try {
    // kludge
    gcpp::AppArgs app_args;
    app_args.Init();
    app_args.max_packages = 1;
    app_args.verbosity = 0;
    app_args.spin = gcpp::Tristate::kFalse;

    return new GemmaContext(tokenizer_path, model_type, weights_path,
                            weight_type, app_args, max_length);
  } catch (...) {
    // Exceptions must not cross the C ABI boundary.
    return nullptr;
  }
}

// Destroys a context previously returned by GemmaCreate. Passing nullptr is
// harmless because `delete` on a null pointer does nothing.
GEMMA_API void GemmaDestroy(GemmaContext* ctx) {
  gcpp::GemmaContext* const context = static_cast<gcpp::GemmaContext*>(ctx);
  delete context;
}

// Runs generation for `prompt` on `ctx`, forwarding `output`, `max_length`,
// `callback` and `user_data` to GemmaContext::Generate.
//
// Returns the value produced by GemmaContext::Generate, or -1 if `ctx`,
// `prompt` or `output` is null or if generation throws.
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
                            int max_length, GemmaTokenCallback callback,
                            void* user_data) {
  if (!ctx || !prompt || !output) return -1;
  try {
    return static_cast<gcpp::GemmaContext*>(ctx)->Generate(
        prompt, output, max_length, callback, user_data);
  } catch (...) {
    // Same policy as GemmaCreate: an exception must not unwind through the
    // C ABI into a foreign caller; report failure instead.
    return -1;
  }
}

// Returns the value of GemmaContext::CountTokens for `text`, or -1 if `ctx`
// or `text` is null or if counting throws.
GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text) {
  if (!ctx || !text) return -1;
  try {
    return static_cast<gcpp::GemmaContext*>(ctx)->CountTokens(text);
  } catch (...) {
    // Same policy as GemmaCreate: never let an exception cross the C ABI.
    return -1;
  }
}

// Forwards `callback` and `user_data` to GemmaContext::SetLogCallback.
// Does nothing when `ctx` is null.
GEMMA_API void GemmaSetLogCallback(GemmaContext* ctx, GemmaLogCallback callback,
                                   void* user_data) {
  if (ctx != nullptr) {
    gcpp::GemmaContext* const context = static_cast<gcpp::GemmaContext*>(ctx);
    context->SetLogCallback(callback, user_data);
  }
}
}
62 changes: 62 additions & 0 deletions gemma/c_api.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_C_API_H_
#define THIRD_PARTY_GEMMA_C_API_H_

// C API for Gemma. The header is written to be consumable from both C and
// C++: C++ sees the real gcpp::GemmaContext, C sees an opaque struct.
//
// gemma/context.h is a C++ header, so it is only pulled in for C++
// translation units. C translation units instead need <stdbool.h> for the
// `bool` used in the GemmaTokenCallback signature.
#ifdef __cplusplus
#include "gemma/context.h"
#else
#include <stdbool.h>
#endif

// GEMMA_API marks the exported entry points. On Windows it is dllexport when
// building the library (GEMMA_EXPORTS defined) and dllimport for consumers;
// elsewhere it gives the symbols default visibility.
#ifdef _WIN32
#ifdef GEMMA_EXPORTS
#define GEMMA_API __declspec(dllexport)
#else
#define GEMMA_API __declspec(dllimport)
#endif
#else
#define GEMMA_API __attribute__((visibility("default")))
#endif

#ifdef __cplusplus
extern "C" {
#endif

// Handle type passed to every API function.
#ifdef __cplusplus
typedef gcpp::GemmaContext GemmaContext;
#else
typedef struct GemmaContext GemmaContext;
#endif

// Callback receiving generated text; `user_data` is the pointer given to
// GemmaGenerate.
typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
// Callback receiving log messages; `user_data` is the pointer given to
// GemmaSetLogCallback.
typedef void (*GemmaLogCallback)(const char* message, void* user_data);

// Creates a context; returns NULL on failure. Release with GemmaDestroy.
GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
                                    const char* model_type,
                                    const char* weights_path,
                                    const char* weight_type, int max_length);

// Destroys a context returned by GemmaCreate.
GEMMA_API void GemmaDestroy(GemmaContext* ctx);

// Generates text for `prompt` into `output`; returns -1 when `ctx` is NULL.
GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
                            int max_length, GemmaTokenCallback callback,
                            void* user_data);

// Counts tokens in `text`; returns -1 when `ctx` or `text` is NULL.
GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text);

// Installs a log callback on `ctx`; does nothing when `ctx` is NULL.
GEMMA_API void GemmaSetLogCallback(GemmaContext* ctx, GemmaLogCallback callback,
                                   void* user_data);

#ifdef __cplusplus
}
#endif

#endif  // THIRD_PARTY_GEMMA_C_API_H_
Loading
Loading