diff --git a/BUILD.bazel b/BUILD.bazel
index a3c956e2..cd53d7e3 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -356,6 +356,44 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "gemma_shared_lib",
+    srcs = [
+        "gemma/c_api.cc",
+        "gemma/context.cc",
+    ],
+    hdrs = [
+        "gemma/activations.h",
+        "gemma/c_api.h",
+        "gemma/context.h",
+        "gemma/gemma.h",
+    ],
+    exec_properties = {
+        # Avoid linker OOMs when building with sanitizer instrumentation.
+        "mem": "28g",
+    },
+    deps = [
+        ":allocator",
+        ":app",
+        ":basics",
+        ":common",
+        ":kv_cache",
+        ":ops",
+        ":threading",
+        ":tokenizer",
+        ":weights",
+        "//compression:compress",
+        "//compression:io",
+        "//compression:sfp",
+        "//paligemma:image",
+        "@highway//:bit_set",
+        "@highway//:hwy",
+        "@highway//:nanobenchmark",  # timer
+        "@highway//:profiler",
+        "@highway//:thread_pool",
+    ],
+)
+
 cc_library(
     name = "cross_entropy",
     srcs = ["evals/cross_entropy.cc"],
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 46bac38f..cc6135d0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -110,6 +110,17 @@ set(SOURCES
   util/threading.h
 )
 
+# Add C API sources only when building DLL
+if(BUILD_GEMMA_DLL)
+  list(APPEND SOURCES
+    gemma/context.h
+    gemma/context.cc
+    gemma/c_api.h
+    gemma/c_api.cc
+  )
+  message(STATUS "Including C API files for DLL build")
+endif()
+
 if(NOT CMAKE_BUILD_TYPE)
   set(CMAKE_BUILD_TYPE "Release")
 endif()
@@ -129,6 +140,33 @@ target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE
 target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
 install(TARGETS libgemma DESTINATION lib)
 
+# Shared Library Target for C# interop
+if(BUILD_GEMMA_DLL)
+  add_library(gemma_shared SHARED ${SOURCES})
+  set_property(TARGET gemma_shared PROPERTY CXX_STANDARD 17)
+  set_target_properties(gemma_shared PROPERTIES
+    PREFIX ""
+    OUTPUT_NAME "gemma"
+  )
+  set_property(TARGET gemma_shared PROPERTY POSITION_INDEPENDENT_CODE ON)
+  target_include_directories(gemma_shared PUBLIC ./)
+  target_link_libraries(gemma_shared PRIVATE
+    $<LINK_LIBRARY:WHOLE_ARCHIVE,hwy>
+    $<LINK_LIBRARY:WHOLE_ARCHIVE,hwy_contrib>
+    $<LINK_LIBRARY:WHOLE_ARCHIVE,sentencepiece-static>
+  )
+  target_include_directories(gemma_shared PUBLIC ${sentencepiece_SOURCE_DIR})
+  target_compile_definitions(gemma_shared
+    PRIVATE
+      GEMMA_EXPORTS
+      $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>
+  )
+  target_compile_options(gemma_shared PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
+  install(TARGETS gemma_shared DESTINATION lib)
+  install(FILES gemma/c_api.h DESTINATION include/gemma)
+  install(FILES GemmaInterop.cs DESTINATION include/gemma)
+endif()
+
 # Executable Target
 
 add_executable(gemma gemma/run.cc)
diff --git a/GemmaInterop.cs b/GemmaInterop.cs
new file mode 100644
index 00000000..842f66c2
--- /dev/null
+++ b/GemmaInterop.cs
@@ -0,0 +1,202 @@
+using System;
+using System.Diagnostics;
+using System.Runtime.InteropServices;
+using System.Text;
+
+namespace GemmaCpp
+{
+    /// <summary>Thrown when the native Gemma library reports a failure.</summary>
+    public class GemmaException : Exception
+    {
+        public GemmaException(string message) : base(message) { }
+    }
+
+    /// <summary>
+    /// Managed wrapper around the native Gemma C API declared in gemma/c_api.h.
+    /// Each instance owns one native context; call Dispose to release it.
+    /// </summary>
+    public class Gemma : IDisposable
+    {
+        private IntPtr _context;
+        private bool _disposed;
+
+        // Optional: Allow setting DLL path
+        public static string DllPath { get; set; } = "gemma.dll";
+
+        [DllImport("kernel32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
+        private static extern IntPtr LoadLibrary(string lpFileName);
+
+        static Gemma()
+        {
+            // kernel32.dll only exists on Windows. Elsewhere, leave resolution of
+            // the "gemma" library to the runtime instead of failing type
+            // initialization on the LoadLibrary P/Invoke itself.
+            if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
+            {
+                return;
+            }
+
+            // Load DLL from specified path
+            if (LoadLibrary(DllPath) == IntPtr.Zero)
+            {
+                throw new DllNotFoundException($"Failed to load {DllPath}. Error: {Marshal.GetLastWin32Error()}");
+            }
+        }
+
+        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
+        private static extern IntPtr GemmaCreate(
+            [MarshalAs(UnmanagedType.LPUTF8Str)] string tokenizerPath,
+            [MarshalAs(UnmanagedType.LPUTF8Str)] string modelType,
+            [MarshalAs(UnmanagedType.LPUTF8Str)] string weightsPath,
+            [MarshalAs(UnmanagedType.LPUTF8Str)] string weightType,
+            int maxLength);
+
+        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
+        private static extern void GemmaDestroy(IntPtr context);
+
+        // Delegate type for token callbacks
+        public delegate bool TokenCallback(string token);
+
+        // Native token callback. The C API declares a one-byte `bool` return, so
+        // marshal it as I1 instead of the default four-byte Win32 BOOL.
+        [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
+        [return: MarshalAs(UnmanagedType.I1)]
+        private delegate bool GemmaTokenCallback(
+            [MarshalAs(UnmanagedType.LPUTF8Str)] string text,
+            IntPtr userData);
+
+        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
+        private static extern int GemmaGenerate(
+            IntPtr context,
+            [MarshalAs(UnmanagedType.LPUTF8Str)] string prompt,
+            [MarshalAs(UnmanagedType.LPUTF8Str)] StringBuilder output,
+            int maxLength,
+            GemmaTokenCallback callback,
+            IntPtr userData);
+
+        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
+        private static extern int GemmaCountTokens(
+            IntPtr context,
+            [MarshalAs(UnmanagedType.LPUTF8Str)] string text);
+
+        // Native callback delegate type
+        [UnmanagedFunctionPointer(CallingConvention.Cdecl)]
+        private delegate void GemmaLogCallback(
+            [MarshalAs(UnmanagedType.LPUTF8Str)] string message,
+            IntPtr userData);
+
+        [DllImport("gemma", CallingConvention = CallingConvention.Cdecl)]
+        private static extern void GemmaSetLogCallback(
+            IntPtr context,
+            GemmaLogCallback callback,
+            IntPtr userData);
+
+        // Keeps the log delegate alive while native code may still call it
+        private GCHandle _logCallbackHandle;
+
+        /// <summary>Creates the native context; throws GemmaException on failure.</summary>
+        public Gemma(string tokenizerPath, string modelType, string weightsPath, string weightType, int maxLength = 8192)
+        {
+            _context = GemmaCreate(tokenizerPath, modelType, weightsPath, weightType, maxLength);
+            if (_context == IntPtr.Zero)
+            {
+                throw new GemmaException("Failed to create Gemma context");
+            }
+
+            // optionally: set up logging
+            /*
+            GemmaLogCallback logCallback = (message, _) =>
+            {
+#if UNITY_ENGINE
+                Debug.Log($"Gemma: {message}");
+#else
+                Debug.WriteLine($"Gemma: {message}");
+#endif
+            };
+            _logCallbackHandle = GCHandle.Alloc(logCallback);
+            GemmaSetLogCallback(_context, logCallback, IntPtr.Zero);
+            */
+        }
+
+        /// <summary>Returns the native token count for the prompt (negative on native error).</summary>
+        public int CountTokens(string prompt)
+        {
+            ThrowIfUnusable();
+            return GemmaCountTokens(_context, prompt);
+        }
+
+        /// <summary>Generates a completion for the prompt without streaming.</summary>
+        public string Generate(string prompt, int maxLength = 4096)
+        {
+            return Generate(prompt, null, maxLength);
+        }
+
+        /// <summary>
+        /// Generates a completion for the prompt. When a callback is given it is
+        /// invoked per generated token; returning false asks the native side to stop.
+        /// </summary>
+        public string Generate(string prompt, TokenCallback callback, int maxLength = 4096)
+        {
+            ThrowIfUnusable();
+
+            var output = new StringBuilder(maxLength);
+
+            // Held in a local rather than an instance field so overlapping calls
+            // cannot free or overwrite each other's handle.
+            GemmaTokenCallback nativeCallback = null;
+            if (callback != null)
+            {
+                nativeCallback = (text, _) => callback(text);
+            }
+
+            int length = GemmaGenerate(_context, prompt, output, maxLength,
+                nativeCallback, IntPtr.Zero);
+
+            // Keep the delegate reachable until the native call has returned.
+            GC.KeepAlive(nativeCallback);
+
+            if (length < 0)
+                throw new GemmaException("Generation failed");
+
+            return output.ToString();
+        }
+
+        public void Dispose()
+        {
+            ReleaseNativeResources();
+            // Already cleaned up, so the finalizer has nothing left to do.
+            GC.SuppressFinalize(this);
+        }
+
+        ~Gemma()
+        {
+            ReleaseNativeResources();
+        }
+
+        // Throws if the wrapper was disposed or never obtained a native context.
+        private void ThrowIfUnusable()
+        {
+            if (_disposed)
+                throw new ObjectDisposedException(nameof(Gemma));
+
+            if (_context == IntPtr.Zero)
+                throw new GemmaException("Gemma context is invalid");
+        }
+
+        // Destroys the native context and frees the log handle; safe to call twice.
+        private void ReleaseNativeResources()
+        {
+            if (_disposed)
+                return;
+
+            if (_context != IntPtr.Zero)
+            {
+                GemmaDestroy(_context);
+                _context = IntPtr.Zero;
+            }
+            if (_logCallbackHandle.IsAllocated)
+                _logCallbackHandle.Free();
+            _disposed = true;
+        }
+    }
+}
diff --git a/gemma/c_api.cc b/gemma/c_api.cc
new file mode 100644
index 00000000..18454158
--- /dev/null
+++ b/gemma/c_api.cc
@@ -0,0 +1,64 @@
+#ifndef GEMMA_EXPORTS
+#define GEMMA_EXPORTS
+#endif
+
+#include "gemma/c_api.h"
+
+// necessary as the C API and GemmaContext effectively wrap up and re-use the
+// code for the Gemma executable
+#include "util/app.h"
+
+extern "C" {
+
+// Creates a context, or returns nullptr if an argument is null or construction
+// throws.
+GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
+                                    const char* model_type,
+                                    const char* weights_path,
+                                    const char* weight_type, int max_length) {
+  // These are forwarded to the GemmaContext constructor, which does not
+  // null-check them, so reject null here.
+  if (!tokenizer_path || !model_type || !weights_path || !weight_type) {
+    return nullptr;
+  }
+
+  try {
+    // kludge
+    gcpp::AppArgs app_args;
+    app_args.Init();
+    app_args.max_packages = 1;
+    app_args.verbosity = 0;
+    app_args.spin = gcpp::Tristate::kFalse;
+
+    return new GemmaContext(tokenizer_path, model_type, weights_path,
+                            weight_type, app_args, max_length);
+  } catch (...) {
+    return nullptr;
+  }
+}
+
+// Destroys a context returned by GemmaCreate; a null ctx is a no-op.
+GEMMA_API void GemmaDestroy(GemmaContext* ctx) { delete ctx; }
+
+// Returns the generated text length, or -1 on error.
+GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
+                            int max_length, GemmaTokenCallback callback,
+                            void* user_data) {
+  if (!ctx) return -1;
+  return ctx->Generate(prompt, output, max_length, callback, user_data);
+}
+
+// Returns the number of tokens in text, or -1 on error.
+GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text) {
+  if (!ctx || !text) return -1;
+  return ctx->CountTokens(text);
+}
+
+// Installs the log callback. It is stored in static GemmaContext state, so it
+// is shared by every context, not only ctx.
+GEMMA_API void GemmaSetLogCallback(GemmaContext* ctx, GemmaLogCallback callback,
+                                   void* user_data) {
+  if (!ctx) return;
+  gcpp::GemmaContext::SetLogCallback(callback, user_data);
+}
+}  // extern "C"
diff --git a/gemma/c_api.h b/gemma/c_api.h
new file mode 100644
index 00000000..0dc23b12
--- /dev/null
+++ b/gemma/c_api.h
@@ -0,0 +1,74 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_C_API_H_
+#define THIRD_PARTY_GEMMA_C_API_H_
+
+// context.h is a C++ header; C callers only see the opaque handle below and
+// need <stdbool.h> for the `bool` in GemmaTokenCallback.
+#ifdef __cplusplus
+#include "gemma/context.h"
+#else
+#include <stdbool.h>
+#endif
+
+#ifdef _WIN32
+#ifdef GEMMA_EXPORTS
+#define GEMMA_API __declspec(dllexport)
+#else
+#define GEMMA_API __declspec(dllimport)
+#endif
+#else
+#define GEMMA_API __attribute__((visibility("default")))
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#ifdef __cplusplus
+typedef gcpp::GemmaContext GemmaContext;
+#else
+typedef struct GemmaContext GemmaContext;
+#endif
+
+// Called with each generated token's text; return false to stop generation.
+typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
+typedef void (*GemmaLogCallback)(const char* message, void* user_data);
+
+// Returns a new context, or NULL on failure. Release with GemmaDestroy.
+GEMMA_API GemmaContext* GemmaCreate(const char* tokenizer_path,
+                                    const char* model_type,
+                                    const char* weights_path,
+                                    const char* weight_type, int max_length);
+GEMMA_API void GemmaDestroy(GemmaContext* ctx);
+
+// Writes the NUL-terminated result into output, which must hold max_length
+// bytes. Returns the text length, or -1 on error.
+GEMMA_API int GemmaGenerate(GemmaContext* ctx, const char* prompt, char* output,
+                            int max_length, GemmaTokenCallback callback,
+                            void* user_data);
+
+// Returns the number of tokens in text, or -1 on error.
+GEMMA_API int GemmaCountTokens(GemmaContext* ctx, const char* text);
+
+GEMMA_API void GemmaSetLogCallback(GemmaContext* ctx, GemmaLogCallback callback,
+                                   void* user_data);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // THIRD_PARTY_GEMMA_C_API_H_
diff --git a/gemma/context.cc b/gemma/context.cc
new file mode 100644
index 00000000..650f9bd5
--- /dev/null
+++ b/gemma/context.cc
@@ -0,0 +1,142 @@
+#include "gemma/context.h"
+
+#include <cstring>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+namespace gcpp {
+
+void InitializeGemmaLibrary() {
+  AppArgs app;
+  app.Init();
+  app.max_packages = 1;
+  NestedPools pools = CreatePools(app);
+  Allocator::Init(pools.Topology());
+}
+
+// Initialize static members
+GemmaLogCallback GemmaContext::s_log_callback = nullptr;
+void* GemmaContext::s_log_user_data = nullptr;
+
+GemmaContext::GemmaContext(const char* tokenizer_path, const char* model_type,
+                           const char* weights_path, const char* weight_type,
+                           const AppArgs& app_args, int max_length)
+    : pools(CreatePools(app_args)) {
+  LoaderArgs loader(tokenizer_path, weights_path, model_type);
+  loader.weight_type_str = weight_type;
+
+  // Throw instead of aborting: GemmaCreate catches this and returns nullptr,
+  // so a bad path or model name does not terminate the host application.
+  if (const char* error = loader.Validate()) {
+    throw std::invalid_argument(
+        std::string("Invalid loader configuration: ") + error);
+  }
+
+  // Initialize cached args
+  inference_args.Init();
+  inference_args.max_generated_tokens = max_length;
+  inference_args.temperature = 0.7f;
+  inference_args.top_k = 1;
+  inference_args.deterministic = false;
+
+  Allocator::Init(pools.Topology());
+  model = AllocateGemma(loader, pools);
+  kv_cache =
+      std::make_unique<KVCache>(KVCache::Create(model->GetModelConfig(), 2048));
+}
+
+int GemmaContext::Generate(const char* prompt, char* output, int max_length,
+                           GemmaTokenCallback callback, void* user_data) {
+  if (!model || !kv_cache || !prompt || !output || max_length <= 0) {
+    return -1;
+  }
+
+  try {
+    // Clear and reuse buffers
+    result_buffer.clear();
+    prompt_buffer.assign(prompt);
+    token_buffer.clear();
+
+    // The prompt is assumed to be already wrapped in the appropriate control
+    // tokens if necessary for an instruction tuned model, so we don't use
+    // WrapAndTokenize here. A tokenizer failure is reported as -1 rather than
+    // aborting through HWY_ASSERT.
+    if (!model->Tokenizer().Encode(prompt_buffer, &token_buffer)) {
+      return -1;
+    }
+
+    // Both pre-trained and instruction-tuned require BOS as first token. The
+    // empty() check also covers a prompt that encodes to no tokens.
+    if (token_buffer.empty() || token_buffer.front() != BOS_ID) {
+      token_buffer.insert(token_buffer.begin(), BOS_ID);
+    }
+
+    const size_t prompt_tokens = token_buffer.size();
+    // stream_token calls so far; the first prompt_tokens of them are skipped.
+    size_t tokens_streamed = 0;
+
+    auto stream_token = [this, callback, user_data, prompt_tokens,
+                         &tokens_streamed](int token, float) {
+      // don't re-output the prompt tokens
+      if (tokens_streamed < prompt_tokens) {
+        ++tokens_streamed;
+        return true;
+      }
+      // skip the end of turn token, this way we don't have to do string
+      // comparisons at the application level (is this a good idea?)
+      if (token == END_OF_TURN_ID) {
+        return false;
+      }
+
+      std::string token_text;
+      if (!model->Tokenizer().Decode(std::vector<int>{token}, &token_text)) {
+        return false;
+      }
+      if (callback && !callback(token_text.c_str(), user_data)) {
+        return false;
+      }
+      result_buffer.append(token_text);
+      ++tokens_streamed;
+      return true;
+    };
+
+    RuntimeConfig runtime_config = {.gen = &gen,
+                                    .verbosity = 0,
+                                    .stream_token = stream_token,
+                                    .use_spinning = Tristate::kFalse};
+    inference_args.max_generated_tokens = max_length;
+    inference_args.CopyTo(runtime_config);
+
+    TimingInfo timing_info = {.verbosity = 0};
+    hwy::Span<const int> prompt_span(token_buffer.data(), token_buffer.size());
+
+    // Pass prompt_tokens to properly utilize KV cache for subsequent tokens
+    model->Generate(runtime_config, prompt_span, prompt_tokens, 0, *kv_cache,
+                    timing_info);
+
+    // The text plus its terminating NUL must fit in the caller's buffer.
+    if (result_buffer.length() >= static_cast<size_t>(max_length)) {
+      return -1;
+    }
+    std::memcpy(output, result_buffer.c_str(), result_buffer.length() + 1);
+    return static_cast<int>(result_buffer.length());
+  } catch (...) {
+    return -1;
+  }
+}
+
+int GemmaContext::CountTokens(const char* text) {
+  if (!model || !text) return -1;
+  try {
+    std::vector<int> tokens;
+    if (!model->Tokenizer().Encode(std::string(text), &tokens)) {
+      return -1;
+    }
+    return static_cast<int>(tokens.size());
+  } catch (...) {
+    return -1;
+  }
+}
+
+}  // namespace gcpp
diff --git a/gemma/context.h b/gemma/context.h
new file mode 100644
index 00000000..e56d3104
--- /dev/null
+++ b/gemma/context.h
@@ -0,0 +1,95 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
+
+#include <memory>
+#include <random>
+#include <string>
+#include <vector>
+
+#ifdef _WIN32
+#include <windows.h>
+#endif
+
+#include "gemma/gemma.h"
+#include "util/app.h"
+#include "util/threading.h"
+
+namespace gcpp {
+
+// Initialize global state needed by the library.
+// Must be called before creating any Gemma instances.
+void InitializeGemmaLibrary();
+
+typedef bool (*GemmaTokenCallback)(const char* text, void* user_data);
+typedef void (*GemmaLogCallback)(const char* message, void* user_data);
+
+// Bundles a loaded model, its KV cache and the inference settings behind the
+// C API in gemma/c_api.h.
+class GemmaContext {
+ public:
+  // Throws std::invalid_argument if the loader arguments fail validation.
+  GemmaContext(const char* tokenizer_path, const char* model_type,
+               const char* weights_path, const char* weight_type,
+               const AppArgs& app_args, int max_length = 2048);
+
+  // Returns length of generated text, or -1 on error
+  int Generate(const char* prompt, char* output, int max_length,
+               GemmaTokenCallback callback, void* user_data);
+
+  // Returns number of tokens in text, or -1 on error
+  int CountTokens(const char* text);
+
+  // Sets the logger. The callback is static, i.e. shared by all contexts.
+  static void SetLogCallback(GemmaLogCallback callback, void* user_data) {
+    s_log_callback = callback;
+    s_log_user_data = user_data;
+  }
+
+ private:
+  NestedPools pools;
+  std::unique_ptr<Gemma> model;
+  std::unique_ptr<KVCache> kv_cache;
+  std::string prompt_buffer;
+  std::string result_buffer;
+  std::vector<int> token_buffer;
+
+  // Cached args
+  InferenceArgs inference_args;
+  AppArgs app_args;
+  std::mt19937 gen;
+
+  // Static members for logging
+  static GemmaLogCallback s_log_callback;
+  static void* s_log_user_data;
+
+  // Use logging helper method to print messages into a managed callback if
+  // necessary
+  static void LogDebug(const char* message) {
+    if (s_log_callback) {
+      s_log_callback(message, s_log_user_data);
+    } else {
+#ifdef _WIN32
+      OutputDebugStringA(message);
+#endif
+    }
+  }
+};
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_CONTEXT_H_
diff --git a/gemma/tokenizer.h b/gemma/tokenizer.h
index e2bb6115..2ed4de09 100644
--- a/gemma/tokenizer.h
+++ b/gemma/tokenizer.h
@@ -31,6 +31,9 @@ namespace gcpp {
 constexpr int EOS_ID = 1;
 constexpr int BOS_ID = 2;
 
+// The tokenizer's end of turn token id.
+constexpr int END_OF_TURN_ID = 107;
+
 class GemmaTokenizer {
  public:
   GemmaTokenizer();