brainscore_language.model_helpers.modeling_suma

PyTorch SUMA model adapted from LLaMA.

Functions

custom_init_weights(module, method[, variance])

repeat_kv(hidden_states, n_rep)

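Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep): the hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).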
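The equivalence stated for repeat_kv can be checked with a short sketch. The function below, repeat_kv_sketch, is a hypothetical stand-in that follows the expand-and-reshape pattern used by the upstream Hugging Face LLaMA helper; the version in this module may differ in detail. The tensor sizes are arbitrary illustration values.

```python
import torch


def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head n_rep times along dim 1 (hypothetical sketch)."""
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Add an axis right after the head axis and broadcast it to n_rep copies.
    expanded = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    # Merge (num_kv_heads, n_rep) so copies of the same head are adjacent.
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


# Illustration values: batch=2, 2 key/value heads, 3 positions, head_dim=4.
x = torch.arange(2 * 2 * 3 * 4, dtype=torch.float32).reshape(2, 2, 3, 4)
out = repeat_kv_sketch(x, n_rep=3)
reference = torch.repeat_interleave(x, repeats=3, dim=1)
```

Note that the copies are interleaved (head 0 three times, then head 1 three times), not tiled, which is what lets each group of query heads share one key/value head in grouped-query attention.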

Classes

BaseModelOutputWithPast([last_hidden_state, ...])

Base class for model outputs that may also contain past key/values (to speed up sequential decoding).

Cache()

Abstract base class for all caches.

DynamicCache()

A cache that grows dynamically as more tokens are generated.

LlamaAttention(config[, layer_idx])

Multi-headed attention from the 'Attention Is All You Need' paper.

LlamaPreTrainedModel(config, *inputs, **kwargs)

The bare LLaMA Model outputting raw hidden-states without any specific head on top.

LlamaRMSNorm(hidden_size[, eps])

SUMAConfig([vocab_size, hidden_size, ...])

This is the configuration class to store the configuration of a LlamaModel.

SUMADecoderLayer(config, layer_idx)

SUMAForCausalLM(config)

SUMAModel(config)

The bare LLaMA Model outputting raw hidden-states without any specific head on top.
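To illustrate what "a cache that grows dynamically as more tokens are generated" means for DynamicCache, here is a minimal sketch. ToyGrowingCache is a hypothetical, simplified class written for this example only; it is not the class exported by this module, and its method names are assumptions chosen to resemble the usual per-layer update pattern. It stores one key tensor and one value tensor per layer and concatenates new states along the sequence axis.

```python
import torch


class ToyGrowingCache:
    """Hypothetical, simplified key/value cache that grows with each decoding step."""

    def __init__(self):
        self.keys = []    # one tensor per layer: (batch, heads, seq_len, head_dim)
        self.values = []

    def update(self, key_states, value_states, layer_idx):
        # Assumes layers are visited in order, so a new layer is always the next index.
        if layer_idx == len(self.keys):
            self.keys.append(key_states)
            self.values.append(value_states)
        else:
            # Append the new positions along the sequence axis (dim=-2).
            self.keys[layer_idx] = torch.cat([self.keys[layer_idx], key_states], dim=-2)
            self.values[layer_idx] = torch.cat([self.values[layer_idx], value_states], dim=-2)
        return self.keys[layer_idx], self.values[layer_idx]

    def get_seq_length(self, layer_idx=0):
        if layer_idx >= len(self.keys):
            return 0
        return self.keys[layer_idx].shape[-2]


cache = ToyGrowingCache()
# Prompt pass: 5 positions are cached at once for layer 0.
cache.update(torch.zeros(1, 2, 5, 8), torch.zeros(1, 2, 5, 8), layer_idx=0)
# Decoding step: a single new position is appended to the same layer.
step_k, step_v = torch.ones(1, 2, 1, 8), torch.ones(1, 2, 1, 8)
k, v = cache.update(step_k, step_v, layer_idx=0)
```

With a cache like this, each decoding step only computes keys and values for the newest token and reuses the stored ones for attention over earlier positions.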