FAIL: Using cache and enabling back past_key_values cache #58

Open
@msedalatzadeh

I tried various transformers.js models, and those that support past_key_values do not actually handle it. I face several issues:

  1. The default past_key_values are GPU-buffer tensors, while ONNX Runtime requires CPU tensors as input.
  2. Downloading past_key_values to the CPU with the downloader method and running again leads to dimension-inconsistency errors. Basically, we need to feed input_ids, attention_mask, and position_ids into model.generate(). I tried various shapes and all failed:
  • Assume past_key_value.dims[2] = past_length and input_ids.dims[1] = full_length. I tried every combination of each input being past_length, full_length, full_length - past_length, or simply 1 (a single token). None worked.
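
For context, the shape convention that HF-style decoder exports generally expect can be sketched as follows (a sketch only, assuming the standard past_key_values layout [batch, num_heads, past_len, head_dim]; cachedStepShapes is a hypothetical helper, not a transformers.js API):

```javascript
// Shape convention for a cached decoding step: feed only the NEW tokens,
// but let the attention mask cover past + new tokens.
function cachedStepShapes(past_len, full_len) {
  const new_len = full_len - past_len;          // tokens not yet in the cache
  return {
    input_ids: [1, new_len],                    // only the new tokens
    attention_mask: [1, past_len + new_len],    // covers cached AND new tokens
    position_ids: [1, new_len],                 // values past_len .. full_len - 1
  };
}
```

Under this convention, continuing a 5-token cache with one new token would use input_ids of shape [1, 1], attention_mask of shape [1, 6], and position_ids of shape [1, 1].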

Please share a working example of transformers.js with past_key_values enabled.

Here is my code:

const full_inputs = tokenizer.apply_chat_template(messages, {
  add_generation_prompt: true,
  return_dict: true
});


// Download each cached key/value tensor from the GPU before re-feeding it
const past_kv = {};
for (const key in past_gpu_kv) {
  if (past_gpu_kv[key]?.ort_tensor) {
    past_kv[key] = await convertToCPUTensor(past_gpu_kv[key].ort_tensor);
  }
}

const { past_key_values, sequences } = await model.generate({
  ...inputs,
  past_key_values: past_kv,
  use_cache: true,
  do_sample: false,    // greedy decoding; top_k / temperature have no effect here
  top_k: 3,
  temperature: 0.2,
  max_new_tokens: 1024,
  streamer,
  stopping_criteria,
  return_dict_in_generate: true,
});
async function convertToCPUTensor(ortTensor) {
  if (!ortTensor || typeof ortTensor.downloader !== 'function') {
    throw new Error('Invalid ort_tensor: missing downloader method');
  }

  // Download the data from GPU (usually a Float16Array or Float32Array)
  const rawData = await ortTensor.downloader();

  // Keep the original dtype; re-wrap float16 data in a typed array
  let data = rawData;
  const dtype = ortTensor.type;

  if (dtype === 'float16') {
    data = Float16Array.from(rawData); // keep the data as float16
  }

  return new Tensor(dtype, data, ortTensor.dims);
}
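
Note that Float16Array is a fairly recent JavaScript addition and is not available in older engines. A minimal sketch of a fallback that promotes half-precision data to float32 when the runtime lacks Float16Array (toCpuTypedArray is a hypothetical helper, not part of transformers.js or ONNX Runtime):

```javascript
// Hypothetical fallback: keep float16 where the runtime supports it,
// otherwise promote the downloaded values to float32.
function toCpuTypedArray(rawData, type) {
  if (type === "float16" && typeof Float16Array === "undefined") {
    // Older engines have no Float16Array; store the values as float32 instead.
    return { data: Float32Array.from(rawData), type: "float32" };
  }
  return { data: rawData, type };
}
```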
function buildInputsForGenerate(full_inputs, past_key_values_cache, modelKey) {
  const input_ids_tensor = full_inputs.input_ids;

  // No cache yet: first call, feed the full prompt
  if (!past_key_values_cache[modelKey]) {
    return full_inputs;
  }

  const seq_len = input_ids_tensor.dims[1];
  if (seq_len === 0) {
    throw new Error("input_ids is empty — can't slice last token.");
  }

  // Use the cached key dims to get the cached sequence length
  const past = past_key_values_cache[modelKey];
  const past_len = past['past_key_values.0.key'].dims[2];
  const new_len = seq_len - past_len;

  // Feed only the tokens not yet in the cache: shape [1, new_len]
  const input_ids = input_ids_tensor.slice([0, 1], [past_len, seq_len]);

  // The mask must cover the cached AND the new tokens: [1, past_len + new_len]
  const attention_mask = new Tensor(
    "int64",
    BigInt64Array.from(Array(seq_len).fill(BigInt(1))),
    [1, seq_len]
  );

  // New tokens continue numbering after the cache: past_len .. seq_len - 1
  const position_ids = new Tensor(
    "int64",
    BigInt64Array.from([...Array(new_len).keys()].map(i => BigInt(past_len + i))),
    [1, new_len]
  );

  return {
    input_ids,
    attention_mask,
    position_ids,
  };
}
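
As a sanity check, the three tensors can be validated against each other and against the cached length before calling model.generate(). A minimal sketch (checkCachedInputs is a hypothetical helper; it only inspects dims/data and has no transformers.js dependency):

```javascript
// Hypothetical sanity check: verify input_ids, attention_mask and
// position_ids agree with each other and with the cached length.
function checkCachedInputs({ input_ids, attention_mask, position_ids }, past_len) {
  const new_len = input_ids.dims[1];
  if (attention_mask.dims[1] !== past_len + new_len) {
    throw new Error(
      `attention_mask covers ${attention_mask.dims[1]} tokens, ` +
      `expected past + new = ${past_len + new_len}`
    );
  }
  if (position_ids.dims[1] !== new_len) {
    throw new Error("position_ids must have one entry per new token");
  }
  if (position_ids.data[0] !== BigInt(past_len)) {
    throw new Error(`positions must continue at past_len = ${past_len}`);
  }
  return true;
}
```

Running a check like this on each failed shape combination makes it easy to see which of the three inputs disagrees with the cache.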
