diff --git a/core/src/tokenization.rs b/core/src/tokenization.rs
index 71617daf..7636afa8 100644
--- a/core/src/tokenization.rs
+++ b/core/src/tokenization.rs
@@ -16,6 +16,16 @@ pub struct Tokenization {
     sender: async_channel::Sender<TokenizerRequest>,
 }
 
+#[derive(Debug)]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct SimpleToken {
+    pub id: u32,
+    pub text: String,
+    pub special: bool,
+    pub start: Option<usize>,
+    pub stop: Option<usize>,
+}
+
 impl Tokenization {
     pub fn new(
         workers: usize,
@@ -485,3 +495,155 @@ enum TokenizerRequest {
         Span,
     ),
 }
+
+pub fn into_tokens(encoding: tokenizers::Encoding, input: &str) -> Vec<SimpleToken> {
+    encoding
+        .get_ids()
+        .iter()
+        .zip(encoding.get_offsets())
+        .zip(encoding.get_special_tokens_mask())
+        .zip(encoding.get_tokens())
+        .map(|(((&id, &(start, stop)), special), token)| {
+            let special = *special == 1;
+            match special {
+                true => SimpleToken {
+                    id,
+                    text: token.clone(),
+                    special,
+                    start: None,
+                    stop: None,
+                },
+                false => {
+                    let text: Vec<u8> = input.bytes().skip(start).take(stop - start).collect();
+                    let text: String = String::from_utf8_lossy(&text).to_string();
+                    SimpleToken {
+                        id,
+                        text,
+                        special,
+                        start: Some(start),
+                        stop: Some(stop),
+                    }
+                }
+            }
+        })
+        .collect()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use hf_hub::api::sync::ApiBuilder;
+
+    #[test]
+    fn tokenizer() {
+        let api = ApiBuilder::from_env().build().unwrap();
+        let filename = api
+            .model("BAAI/bge-m3".to_string())
+            .get("tokenizer.json")
+            .unwrap();
+        let string = "这是一个文本向量化的测试句子";
+        let tokenizer = Tokenizer::from_file(filename).unwrap();
+
+        let encoded = tokenizer.encode(string, true).unwrap();
+        assert_eq!(
+            encoded.get_offsets(),
+            vec![
+                (0, 0),
+                (0, 3),
+                (0, 12),
+                (12, 18),
+                (18, 21),
+                (21, 24),
+                (24, 30),
+                (30, 36),
+                (36, 39),
+                (39, 42),
+                (0, 0)
+            ]
+        );
+
+        let tokens = into_tokens(encoded, &string);
+        assert_eq!(
+            tokens,
+            vec![
+                SimpleToken {
+                    id: 0,
+                    text: "<s>".to_string(),
+                    special: true,
+                    start: None,
+                    stop: None
+                },
+                SimpleToken {
+                    id: 6,
+                    text: "这".to_string(),
+                    special: false,
+                    start: Some(0),
+                    stop: Some(3)
+                },
+                SimpleToken {
+                    id: 100013,
+                    text: "这是一个".to_string(),
+                    special: false,
+                    start: Some(0),
+                    stop: Some(12)
+                },
+                SimpleToken {
+                    id: 189061,
+                    text: "文本".to_string(),
+                    special: false,
+                    start: Some(12),
+                    stop: Some(18)
+                },
+                SimpleToken {
+                    id: 2110,
+                    text: "向".to_string(),
+                    special: false,
+                    start: Some(18),
+                    stop: Some(21)
+                },
+                SimpleToken {
+                    id: 3272,
+                    text: "量".to_string(),
+                    special: false,
+                    start: Some(21),
+                    stop: Some(24)
+                },
+                SimpleToken {
+                    id: 41904,
+                    text: "化的".to_string(),
+                    special: false,
+                    start: Some(24),
+                    stop: Some(30)
+                },
+                SimpleToken {
+                    id: 49125,
+                    text: "测试".to_string(),
+                    special: false,
+                    start: Some(30),
+                    stop: Some(36)
+                },
+                SimpleToken {
+                    id: 27683,
+                    text: "句".to_string(),
+                    special: false,
+                    start: Some(36),
+                    stop: Some(39)
+                },
+                SimpleToken {
+                    id: 1344,
+                    text: "子".to_string(),
+                    special: false,
+                    start: Some(39),
+                    stop: Some(42)
+                },
+                SimpleToken {
+                    id: 2,
+                    text: "</s>".to_string(),
+                    special: true,
+                    start: None,
+                    stop: None
+                }
+            ]
+        );
+    }
+}
diff --git a/router/src/grpc/server.rs b/router/src/grpc/server.rs
index 8de706dd..3c98f8b8 100644
--- a/router/src/grpc/server.rs
+++ b/router/src/grpc/server.rs
@@ -15,7 +15,9 @@ use std::future::Future;
 use std::net::SocketAddr;
 use std::time::{Duration, Instant};
 use text_embeddings_core::infer::Infer;
-use text_embeddings_core::tokenization::EncodingInput;
+use text_embeddings_core::tokenization::{
+    into_tokens, EncodingInput, SimpleToken as CoreSimpleToken,
+};
 use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit};
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
@@ -340,32 +342,22 @@ impl TextEmbeddingsService {
             .map_err(ErrorResponse::from)?;
         let inputs = encoded_inputs.unwrap_or(inputs);
 
-        let tokens: Vec<SimpleToken> = encoding
-            .get_ids()
-            .iter()
-            .zip(encoding.get_offsets())
-            .zip(encoding.get_special_tokens_mask())
-            .zip(encoding.get_tokens())
-            .map(|(((&id, &(start, stop)), special), token)| {
-                let special = *special == 1;
-                match special {
-                    true => SimpleToken {
-                        id,
-                        text: token.clone(),
-                        special,
-                        start: None,
-                        stop: None,
-                    },
-                    false => {
-                        let text: String = inputs.chars().skip(start).take(stop - start).collect();
-                        SimpleToken {
-                            id,
-                            text,
-                            special,
-                            start: Some(start as u32),
-                            stop: Some(stop as u32),
-                        }
-                    }
+        let tokens: Vec<SimpleToken> = into_tokens(encoding, &inputs)
+            .into_iter()
+            .map(|t| {
+                let CoreSimpleToken {
+                    id,
+                    text,
+                    special,
+                    start,
+                    stop,
+                } = t;
+                SimpleToken {
+                    id,
+                    text,
+                    special,
+                    start: start.map(|s| s as u32),
+                    stop: stop.map(|s| s as u32),
                 }
             })
             .collect();
diff --git a/router/src/http/server.rs b/router/src/http/server.rs
index cadb6c18..2abd81e9 100644
--- a/router/src/http/server.rs
+++ b/router/src/http/server.rs
@@ -34,6 +34,7 @@ use text_embeddings_backend::BackendError;
 use text_embeddings_core::infer::{
     AllEmbeddingsInferResponse, Infer, InferMetadata, PooledEmbeddingsInferResponse,
 };
+use text_embeddings_core::tokenization::{into_tokens, SimpleToken as CoreSimpleToken};
 use text_embeddings_core::TextEmbeddingsError;
 use tokio::sync::OwnedSemaphorePermit;
 use tower_http::cors::{AllowOrigin, CorsLayer};
@@ -1295,32 +1296,22 @@ async fn tokenize(
             .map_err(ErrorResponse::from)?;
         let input = encoded_input.unwrap_or(input);
 
-        let tokens: Vec<SimpleToken> = encoding
-            .get_ids()
-            .iter()
-            .zip(encoding.get_offsets())
-            .zip(encoding.get_special_tokens_mask())
-            .zip(encoding.get_tokens())
-            .map(|(((&id, &(start, stop)), special), token)| {
-                let special = *special == 1;
-                match special {
-                    true => SimpleToken {
-                        id,
-                        text: token.clone(),
-                        special,
-                        start: None,
-                        stop: None,
-                    },
-                    false => {
-                        let text: String = input.chars().skip(start).take(stop - start).collect();
-                        SimpleToken {
-                            id,
-                            text,
-                            special,
-                            start: Some(start),
-                            stop: Some(stop),
-                        }
-                    }
+        let tokens: Vec<SimpleToken> = into_tokens(encoding, &input)
+            .into_iter()
+            .map(|t| {
+                let CoreSimpleToken {
+                    id,
+                    text,
+                    special,
+                    start,
+                    stop,
+                } = t;
+                SimpleToken {
+                    id,
+                    text,
+                    special,
+                    start,
+                    stop,
                 }
             })
            .collect();
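For context on what this patch fixes: the `tokenizers` crate reports token offsets in bytes, while the replaced router code sliced the input with `.chars().skip(start).take(stop - start)`, which treats byte offsets as character indices and so returns the wrong token text for any multibyte input. The standalone sketch below is not part of the patch; it uses only `std`, the variable names and `main` wrapper are mine, and the offsets are taken from the BAAI/bge-m3 test case in the diff above.

```rust
// Minimal sketch: byte offsets vs. char indices on multibyte text.
// The token "这" in the test string covers bytes 0..3 (one 3-byte UTF-8 char).
fn main() {
    let input = "这是一个文本向量化的测试句子";
    let (start, stop) = (0usize, 3usize); // byte offsets reported by `tokenizers`

    // Old behaviour: interprets byte offsets as char indices -> grabs 3 chars.
    let by_chars: String = input.chars().skip(start).take(stop - start).collect();
    assert_eq!(by_chars, "这是一"); // wrong token text

    // New behaviour (as in `into_tokens`): slice the underlying bytes.
    let bytes: Vec<u8> = input.bytes().skip(start).take(stop - start).collect();
    let by_bytes = String::from_utf8_lossy(&bytes).to_string();
    assert_eq!(by_bytes, "这"); // matches the token's actual text
}
```

The `from_utf8_lossy` in `into_tokens` also means a malformed offset pair can at worst produce replacement characters rather than panic, which `.chars()`-based slicing avoided only by accident.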