Open
Description
I tried various models with transformers.js, and those that support past_key_values do not actually handle it. I face several issues:
- The default past_key_values are gpuBuffer tensors and ONNX requires cpu tensors as input
- Downloading past_key_values to the CPU using the downloader method and running again leads to dimension-inconsistency problems. Basically we need to feed input_ids, attention_mask, and position_ids into
model.generate()
; I tried various shapes and all failed:
- Assume that
past_key_value.dims[2] = past_length
and that
input_ids.dims[1] = full_length
. I tweaked all combinations of each input being past_length, full_length, full_length - past_length, or simply 1 (one token). None worked.
Please share a working example of transformers.js with past_key_values enabled.
Here is my code:
// Tokenize the full chat history; with return_dict: true this yields an
// object like { input_ids, attention_mask } rather than a bare tensor.
// add_generation_prompt appends the assistant-turn prompt tokens.
const full_inputs = tokenizer.apply_chat_template(messages, {
add_generation_prompt: true,
return_dict: true
});
// Copy each GPU-resident KV entry to a CPU-backed Tensor before reuse,
// since ONNX Runtime rejects gpuBuffer tensors as model inputs here.
// NOTE(review): past_kv is never declared in this snippet — presumably
// initialized as {} earlier; confirm against the full source.
for (const key in past_gpu_kv) {
if (past_gpu_kv[key]?.ort_tensor) {
past_kv[key] = await convertToCPUTensor(past_gpu_kv[key].ort_tensor);
}
}
// NOTE(review): this spreads `inputs`, but the tokenized prompt above is
// named `full_inputs` — verify which object is intended; feeding the wrong
// (or full-length) inputs alongside past_key_values would explain the
// dimension-inconsistency errors described in this issue.
// NOTE(review): with do_sample: false decoding is greedy, so top_k and
// temperature below are presumably ignored — confirm against the
// transformers.js generation docs.
const { past_key_values, sequences } = await model.generate({
...inputs,
past_key_values: past_kv,
use_cache: true,
do_sample: false,
top_k: 3,
temperature: 0.2,
max_new_tokens: 1024,
streamer,
stopping_criteria,
return_dict_in_generate: true,
});
/**
 * Copy an ONNX Runtime GPU tensor into a CPU-backed transformers.js Tensor.
 *
 * @param {object} ortTensor - ort tensor exposing downloader(), type, dims.
 * @returns {Promise<Tensor>} CPU tensor with the same dtype, data, and dims.
 * @throws {Error} if ortTensor is missing or has no downloader method.
 */
async function convertToCPUTensor(ortTensor) {
  if (!ortTensor || typeof ortTensor.downloader !== 'function') {
    throw new Error('Invalid ort_tensor: missing downloader method');
  }
  // downloader() reads the buffer back from GPU memory. The returned typed
  // array is already in the tensor's native layout, so it can be wrapped
  // directly. The original code re-copied float16 data through
  // Float16Array.from(rawData): a no-op for the dtype (it stayed
  // 'float16', despite the comment claiming a Float32Array conversion),
  // a dead self-assignment of dtype, and a dependency on the
  // not-yet-universal Float16Array global — and if downloader() returns a
  // Uint16Array of raw half bits, .from() would reinterpret the bit
  // patterns as numeric values, corrupting the data.
  // NOTE(review): assumes downloader() returns data matching
  // ortTensor.type — confirm against the onnxruntime-web Tensor docs.
  const data = await ortTensor.downloader();
  return new Tensor(ortTensor.type, data, ortTensor.dims);
}
/**
 * Build { input_ids, attention_mask, position_ids } for model.generate()
 * when a KV cache exists: feed only the tokens the cache has not yet seen.
 *
 * @param {object} full_inputs - tokenized full prompt ({ input_ids, ... }).
 * @param {object} past_key_values_cache - cached KV tensors keyed by model.
 * @param {string} modelKey - key into past_key_values_cache.
 * @returns {object} inputs for generate(); full_inputs unchanged if no cache.
 * @throws {Error} if input_ids is empty, or no tokens are newer than cache.
 */
function buildInputsForGenerate(full_inputs, past_key_values_cache, modelKey) {
  const input_ids_tensor = full_inputs.input_ids;

  // First call for this model: no cache, send the whole prompt as-is.
  if (!past_key_values_cache[modelKey]) {
    return full_inputs;
  }

  const seq_len = input_ids_tensor.dims[1];
  if (seq_len === 0) {
    throw new Error("input_ids is empty — can't slice last token.");
  }

  // Cached sequence length lives in dim 2 of any past key tensor
  // ([batch, num_heads, past_len, head_dim]).
  const past = past_key_values_cache[modelKey];
  const past_len = past['past_key_values.0.key'].dims[2];
  const new_len = seq_len - past_len;
  if (new_len <= 0) {
    // Guard added: the original silently produced empty/negative slices.
    throw new Error(`No new tokens to feed: seq_len=${seq_len}, past_len=${past_len}`);
  }

  // BUG FIX: slice ALL unseen tokens [past_len, seq_len), not just the
  // final one. The original sliced [seq_len - 1, seq_len] (one token),
  // which disagreed with the new_len-sized position_ids whenever
  // new_len > 1.
  const input_ids = input_ids_tensor.slice([0, 1], [past_len, seq_len]);

  // BUG FIX: the attention mask covers past + new tokens exactly
  // (past_len + new_len === seq_len). The original used seq_len + 1, an
  // off-by-one that breaks the shapes the model expects.
  const attention_mask = new Tensor(
    "int64",
    BigInt64Array.from({ length: seq_len }, () => 1n),
    [1, seq_len]
  );

  // Positions of the new tokens continue from the cached length.
  const position_ids = new Tensor(
    "int64",
    BigInt64Array.from({ length: new_len }, (_, i) => BigInt(past_len + i)),
    [1, new_len]
  );

  return {
    input_ids,
    attention_mask,
    position_ids,
  };
}
Metadata
Metadata
Assignees
Labels
No labels