[Feature] ISS-60: Implement Self Extend #431

Open
wants to merge 11 commits into base: dev
9 changes: 9 additions & 0 deletions gemma/configs.h
@@ -148,6 +148,15 @@ struct LayerConfig {
size_t conv1d_width = 0; // griffin only
bool ff_biases = false;
bool softmax_attn_output_biases = false;
/**
 * Self-extend
 * Jin, Hongye, et al. "LLM Maybe LongLM: Self-Extend LLM Context Window
 * Without Tuning." arXiv preprint arXiv:2401.01325 (2024).
 */
bool self_extend = false;
// Self-extend neighbor size
size_t se_neighbor_size = std::numeric_limits<size_t>::max();
// Self-extend group window size
size_t se_group_size = 1;
bool optimized_gating = true;
PostNormType post_norm = PostNormType::None;
LayerAttentionType type = LayerAttentionType::kGemma;
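
Note: the two position mappings these fields drive (applied in gemma-inl.h below) can be sketched as standalone functions. This is a minimal illustration of the SelfExtend remapping, assuming the field semantics above; `GroupedKeyPos` and `GroupedQueryPos` are illustrative names, not identifiers from this PR.

```cpp
#include <cstddef>

// Keys beyond the neighbor window collapse to their group index.
size_t GroupedKeyPos(size_t pos, size_t group, size_t neighbor) {
  return pos > neighbor ? pos / group : pos;
}

// Queries additionally shift by (neighbor - neighbor / group) so that
// distances within the neighbor window are unchanged.
size_t GroupedQueryPos(size_t pos, size_t group, size_t neighbor) {
  if (pos <= neighbor) return pos;
  return pos / group + neighbor - neighbor / group;
}
```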
32 changes: 28 additions & 4 deletions gemma/gemma-inl.h
@@ -300,28 +300,39 @@ class GemmaAttention {
}
} // !is_mha_

// Self-extend: precomputed divisor for grouped key positions.
const hwy::Divisor div_grp_size(
static_cast<uint32_t>(layer_config_.se_group_size));
// Apply positional encodings for K (and copy KV to cache if MHA).
pool_.Run(0, kv_heads * num_interleaved,
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t head = task % kv_heads;
const size_t interleaved_idx = task / kv_heads;
const size_t query_idx = interleaved_idx % num_queries_;
const size_t batch_idx = interleaved_idx / num_queries_;
size_t pos = queries_pos_[query_idx] + batch_idx;
const size_t cache_pos = div_seq_len_.Remainder(pos);
const size_t kv_offset = cache_pos * cache_pos_size_ +
layer_ * cache_layer_size_ +
head * qkv_dim * 2;
KVCache& kv_cache = kv_caches_[query_idx];

const size_t se_neighbor_size = layer_config_.se_neighbor_size;
const bool enable_self_extend = layer_config_.self_extend;

float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
const float* HWY_RESTRICT mha_kv =
activations_.q.Batch(interleaved_idx) + head * q_stride_ +
qkv_dim;

// With self-extend, keys beyond the neighbor window are
// encoded at their grouped position.
if (enable_self_extend && pos > se_neighbor_size) {
pos = div_grp_size.Divide(pos);
}
// Copy from `q` if MHA, or apply in-place.
PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f,
kv);

// If MHA, also copy V into KVCache.
if (is_mha_) {
hwy::CopyBytes(mha_kv + qkv_dim, kv + qkv_dim,
@@ -405,12 +416,25 @@ class GemmaAttention {
const size_t batch_idx = interleaved_idx / num_queries_;
const size_t head_offset =
(head / kHeadGroups) * layer_config_.qkv_dim * 2;

const size_t se_group_size = layer_config_.se_group_size;
const size_t se_neighbor_size = layer_config_.se_neighbor_size;
const bool enable_self_extend =
layer_config_.self_extend;

KVCache& kv_cache = kv_caches_[query_idx];
float* HWY_RESTRICT q =
activations_.q.Batch(interleaved_idx) + head * q_stride_;

// Apply RoPE and scaling to Q.
size_t pos = queries_pos_[query_idx] + batch_idx;
// With self-extend, queries beyond the neighbor window use the
// grouped position, shifted so that distances inside the window
// are preserved.
if (enable_self_extend && pos > se_neighbor_size) {
const size_t grp_pos = pos / se_group_size;
const size_t shift =
se_neighbor_size - se_neighbor_size / se_group_size;
pos = grp_pos + shift;
}
PositionalEncodingQK(q, pos, layer_, query_scale, q);

const size_t start_pos = StartPos(pos, layer_);
@@ -1408,7 +1432,7 @@ void GenerateBatchT(const ModelWeightsStorage& model,
qbatch_size);
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
qbatch_size);
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
GenerateT<T>(model, activations, runtime_config, qbatch_prompts, qbatch_pos,
qbatch_prefix_end, qbatch_start, qbatch_kv, timing_info);
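
Note: to make the remapping above concrete, here is a throwaway check with illustrative values (group size 4, neighbor window 512; neither is a default proposed by this PR):

```cpp
#include <cassert>
#include <cstddef>

int main() {
  const size_t group = 4, neighbor = 512;

  // Within the neighbor window, positions pass through unchanged.
  size_t pos = 300;
  assert(pos <= neighbor);

  // Beyond it, the key takes its group index, the query adds the shift.
  pos = 1000;
  const size_t key_pos = pos / group;                                  // 250
  const size_t query_pos = pos / group + neighbor - neighbor / group;  // 634
  assert(key_pos == 250 && query_pos == 634);
  return 0;
}
```

The relative distance seen by attention stays bounded as `pos` grows, which is what lets the pretrained RoPE range cover a longer context.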
1 change: 1 addition & 0 deletions gemma/gemma.h
@@ -198,6 +198,7 @@ class Gemma {
~Gemma();

const ModelConfig& GetModelConfig() const { return model_.Config(); }
ModelConfig& GetMutableModelConfig() { return model_.MutableConfig(); }
const ModelInfo& Info() const { return info_; }
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
const ModelWeightsStorage& Weights() const { return model_; }
21 changes: 21 additions & 0 deletions gemma/run.cc
@@ -77,6 +77,26 @@ std::string GetPrompt(std::istream& input, int verbosity,
return prompt_string;
}

// If requested via loader flags, enable self-extend on the model config.
void ApplySelfExtendIfGiven(Gemma& model, const LoaderArgs& loader) {
if (loader.self_extend != Tristate::kTrue) {
return;
}

// Modify layer configs in-place.
ModelConfig& config = model.GetMutableModelConfig();
for (LayerConfig& layer_config : config.layer_configs) {
layer_config.self_extend = true;
layer_config.se_group_size = loader.se_group_size;
layer_config.se_neighbor_size = loader.se_neighbor_size;
}
}

// The main Read-Eval-Print Loop.
void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
const InferenceArgs& args, const AcceptFunc& accept_token,
@@ -206,6 +226,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
Allocator::Init(pools.Topology());

Gemma model = CreateGemma(loader, pools);
ApplySelfExtendIfGiven(model, loader);
KVCache kv_cache =
KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size);

1 change: 1 addition & 0 deletions gemma/weights.h
@@ -549,6 +549,7 @@ class ModelWeightsStorage {
void CopyWithTranspose(hwy::ThreadPool& pool);
void LogWeightStats();
const ModelConfig& Config() const { return config_; }
ModelConfig& MutableConfig() { return config_; }

template <typename T>
ModelWeightsPtrs<T>* GetWeightsOfType() const {
11 changes: 11 additions & 0 deletions util/app.h
@@ -171,6 +171,11 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
std::string model_type_str;
std::string weight_type_str;

// Self-extend
Tristate self_extend;
size_t se_group_size;
size_t se_neighbor_size;

template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(tokenizer, "tokenizer", Path(),
@@ -189,6 +194,12 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
visitor(weight_type_str, "weight_type", std::string("sfp"),
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit FP\n"
" Required argument.");
visitor(self_extend, "self_extend", Tristate::kDefault,
"Apply self-extend? -1 = auto, 0 = no, 1 = yes.", 2);
visitor(se_group_size, "se_group_size", size_t{1},
"Group size for self-extend.");
visitor(se_neighbor_size, "se_neighbor_size",
std::numeric_limits<size_t>::max(),
"Neighbor window size for self-extend.");
}

// Uninitialized before Validate, must call after that.
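
Note: in the paper's notation, --se_group_size is the group size G and --se_neighbor_size the neighbor window w_n; Jin et al. report the maximum extended context as roughly (L - w_n) * G + w_n for a model pretrained on window L. A hypothetical helper (not part of this PR) for sizing the flags:

```cpp
#include <cstddef>

// Maximum context reachable with SelfExtend per Jin et al. (2024),
// given pretrained window L, group size G and neighbor window w_n.
constexpr size_t MaxSelfExtendContext(size_t L, size_t G, size_t w_n) {
  return (L - w_n) * G + w_n;
}

// e.g. an 8K-trained model with G = 4, w_n = 1024:
// (8192 - 1024) * 4 + 1024 = 29696 tokens.
static_assert(MaxSelfExtendContext(8192, 4, 1024) == 29696, "sizing example");
```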