The following is a small piece of code that extracts embeddings from a given layer of an LLM:
def process_row(prompt: str, model, tokenizer, layers_to_use: list, remove_period: bool) -> dict:
    """
    Processes a row of data and returns the embeddings.

    The prompt is tokenized and run through the model in a single forward
    pass; for every requested layer the hidden state of the last prompt
    token is collected.

    Args:
        prompt: Text to embed.
        model: A Hugging Face style causal (decoder-only) language model that
            accepts ``output_hidden_states=True`` and returns an object with
            a ``hidden_states`` tuple.
        tokenizer: Tokenizer matching ``model``; called with
            ``return_tensors="pt"``.
        layers_to_use: Indices into ``hidden_states``. In Hugging Face models
            index 0 is the embedding output and index ``i`` is the output of
            layer ``i``; negative indices count from the end.
        remove_period: When true, strip every trailing ``"."`` and ``" "``
            character from the prompt before tokenizing.

    Returns:
        A dict mapping each layer index to a one-element list holding the
        last-token hidden state as a plain list of floats.
    """
    if remove_period:
        prompt = prompt.rstrip(". ")

    inputs = tokenizer(prompt, return_tensors="pt")
    # Keep the inputs on the same device as the model so GPU models work.
    device = getattr(model, "device", None)
    if device is not None:
        inputs = inputs.to(device)

    # A plain forward pass yields the same prompt hidden states that
    # generate() reports for its first step, without the cost of sampling a
    # token. Passing the whole encoding also forwards the attention mask.
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    embeddings = {}
    for layer in layers_to_use:
        # [0] selects the single batch row, [-1] the last prompt token.
        last_hidden_state = outputs.hidden_states[layer][0][-1]
        # Tensor.tolist() avoids the numpy round trip, which raises for CUDA
        # and bfloat16 tensors.
        embeddings[layer] = [last_hidden_state.detach().cpu().tolist()]
    return embeddings
This is a pretty standard approach, but it's quite slow. Is there any way to use vLLM to make it faster without needing to call the generate function every time? I've tried batching, but that is slow too. Any help is appreciated!
One way to get the last hidden state values using vLLM is as follows:
from vllm import LLM, SamplingParams
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, SequenceData,
SequenceGroupMetadata, SequenceStatus)
from transformers import LlamaModel, LlamaTokenizer
from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
# Build the engine and keep a handle on its first worker; every step below
# goes through that worker's (private, version-specific) internals.
llm = LLM(model=path_to_llama2)
worker = llm.llm_engine.workers[0]

# Enable top-k sampling to reflect the accurate memory usage.
vocab_size = worker.model.config.vocab_size
sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)

# Scheduler limits are read here but not used further down in this snippet.
scheduler_config = worker.scheduler_config
max_num_batched_tokens = scheduler_config.max_num_batched_tokens
max_num_seqs = scheduler_config.max_num_seqs

# Tokenize the first training example.
prompt = train[0]
prompt_token_ids = llm.llm_engine.tokenizer.encode(prompt)  # e.g. [2, 100, 524, 10]

# Describe the prompt as a one-sequence group without any KV-cache blocks.
group_id = 1
seq_data = SequenceData(prompt_token_ids)
seq = SequenceGroupMetadata(
    request_id=str(group_id),
    is_prompt=True,
    seq_data={group_id: seq_data},
    sampling_params=sampling_params,
    block_tables=None,
)
seqs = [seq]

# Let the worker turn the sequence group into model inputs, then keep only
# the first prompt_len entries.
# NOTE(review): presumably this trims padding added by _prepare_inputs --
# confirm against the installed vLLM version.
input_tokens, input_positions, input_metadata = worker._prepare_inputs(seqs)
prompt_len = len(seq_data.prompt_token_ids)
input_tokens, input_positions = input_tokens[:prompt_len], input_positions[:prompt_len]

# Execute the model.
num_layers = worker.model_config.get_num_layers(worker.parallel_config)
forward_kwargs = dict(
    input_ids=input_tokens,
    positions=input_positions,
    # One empty (key, value) cache slot per layer: nothing is cached.
    kv_caches=[(None, None)] * num_layers,
    input_metadata=input_metadata,
    cache_events=None,
)
# Call the inner model directly rather than going through the engine.
tempOut = worker.model.model(**forward_kwargs)
print(tempOut.size())
However, this doesn't give me the hidden state embeddings of all layers. Is there any other way to get these values faster?