refactor to sample before tokenize (#26)
morgandu committed Apr 2, 2024
1 parent 2245876 commit 5820c95
Showing 1 changed file with 62 additions and 72 deletions.
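In outline, the commit splits the monolithic dataset loaders into load-only helpers plus shared tokenize_dataset and filter_dataset steps, and has sample_requests draw an oversampled subset before any tokenization happens. Below is a minimal, self-contained sketch of why that ordering is cheaper; it is not code from the commit — the toy dataset and mock_tokenize stand in for the real data and for the batch tokenizer.tokenize() call the benchmark assumes.

import random
from typing import List, Tuple

def mock_tokenize(texts: List[str]) -> List[List[str]]:
  # Stand-in for the batch tokenizer.tokenize() call used in the diff below.
  return [t.split() for t in texts]

# 10,000 raw (prompt, output) pairs, the shape the refactored loaders now return.
dataset: List[Tuple[str, str]] = [(f"prompt {i}", f"output {i}") for i in range(10_000)]
num_requests, oversample_multiplier = 100, 1.2

# Old order: tokenize all 10,000 pairs, then sample 100 of them.
# New order (this commit): sample ~120 pairs first, then tokenize only those.
k = min(int(num_requests * oversample_multiplier), len(dataset))
sampled = [dataset[i] for i in random.sample(range(len(dataset)), k)]
prompt_ids = mock_tokenize([p for p, _ in sampled])
output_ids = mock_tokenize([o for _, o in sampled])
print(f"tokenized {len(sampled)} of {len(dataset)} pairs")  # 120 of 10000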
134 changes: 62 additions & 72 deletions benchmarks/benchmark_serving.py
@@ -131,10 +131,8 @@ def get_tokenizer(tokenizer_name: str) -> Any:
 
 def load_sharegpt_dataset(
     dataset_path: str,
-    tokenizer: Any,
     conversation_starter: str,
-    max_output_length: Optional[int] = None,
-) -> List[InputRequest]:
+) -> List[tuple[str]]:
   # Load the dataset.
   with open(dataset_path) as f:
     dataset = json.load(f)
@@ -150,116 +148,112 @@ def load_sharegpt_dataset(
       for data in dataset
   ]
 
+  return dataset
+
+
+def load_openorca_dataset(
+    dataset_path: str
+) -> List[tuple[str]]:
+  # Load the dataset.
+  with open(dataset_path) as f:
+    dataset = json.load(f)
+
+  # Tokenize the prompts and completions.
+  prompts = dataset["prompts"]
+  outputs = dataset["results"]
+
+  return [(prompt, output) for prompt, output in zip(prompts, outputs)]
+
+
+def tokenize_dataset(
+    dataset: List[tuple[str]],
+    tokenizer: Any,
+) -> List[tuple[Any]]:
+
+  n = len(dataset)
+
   prompts = [prompt for prompt, _ in dataset]
+  outputs = [output for _, output in dataset]
+
   prompt_token_ids = tokenizer.tokenize(
       prompts
   )  # adjust this code based on tokenizer method
-  completions = [completion for _, completion in dataset]
-  completion_token_ids = tokenizer.tokenize(
-      completions
+  outputs_token_ids = tokenizer.tokenize(
+      outputs
   )  # adjust this code based on tokenizer method
 
   tokenized_dataset = []
-  for i in range(len(dataset)):
+  for i in range(n):
     prompt_len = len(prompt_token_ids[i])
-    completion_len = len(completion_token_ids[i])
+    output_len = len(outputs_token_ids[i])
     tokenized_dataset.append(
-        (prompts[i], prompt_token_ids[i], completions[i], prompt_len, completion_len)
+        (prompts[i], prompt_token_ids[i], outputs[i], prompt_len, output_len)
     )
+  return tokenized_dataset
 
-  # Filter out too long sequences.
-  filtered_dataset: List[InputRequest] = []
-
-  for prompt, prompt_token_ids, completion, prompt_len, completion_len in tokenized_dataset:
-    if prompt_len < 4 or completion_len < 4:
-      # Prune too short sequences.
-      # This is because TGI causes errors when the input or output length
-      # is too short.
-      continue
-    if prompt_len > 1024 or prompt_len + completion_len > 2048:
-      # Prune too long sequences.
-      continue
-    request = InputRequest(prompt, prompt_len, completion, max_output_length or completion_len)
-    filtered_dataset.append(request)
 
+def filter_dataset(
+    tokenized_dataset: List[tuple[Any]],
+    max_output_length: Optional[int] = None
+) -> List[InputRequest]:
   if max_output_length is None:
     print("In InputRequest, pass in actual output_length for each sample")
   else:
     print(f"In InputRequest, pass in max_output_length: {max_output_length} for each sample")
 
-  print(f"The dataset contains {len(tokenized_dataset)} samples.")
-  print(f"The filtered dataset contains {len(filtered_dataset)} samples.")
-
-  return filtered_dataset
-
-
-def load_openorca_dataset(
-    dataset_path: str,
-    tokenizer: Any,
-    max_output_length: Optional[int] = None,
-) -> List[InputRequest]:
-
-  # Load the dataset.
-  with open(dataset_path) as f:
-    dataset = json.load(f)
-
-  # Tokenize the prompts and completions.
-  prompts = dataset["prompts"]
-  outputs = dataset["results"]
-  n = len(prompts)
-  prompt_token_ids = tokenizer.tokenize(prompts)
-  output_token_ids = tokenizer.tokenize(outputs)
-
-  tokenized_dataset = []
-  for i in range(n):
-    prompt_len = len(prompt_token_ids[i])
-    output_len = len(output_token_ids[i])
-    tokenized_dataset.append((prompts[i], prompt_token_ids[i], outputs[i], prompt_len, output_len))
-
   # Filter out too long sequences.
   filtered_dataset: List[InputRequest] = []
 
   for prompt, prompt_token_ids, output, prompt_len, output_len in tokenized_dataset:
     if prompt_len < 4 or output_len < 4:
       # Prune too short sequences.
       # This is because TGI causes errors when the input or output length
       # is too short.
       continue
     if prompt_len > 1024 or prompt_len + output_len > 2048:
       # Prune too long sequences.
       continue
     request = InputRequest(prompt, prompt_len, output, max_output_length or output_len)
     filtered_dataset.append(request)
 
-  if max_output_length is None:
-    print("In InputRequest, pass in actual output_length for each sample")
-  else:
-    print(f"In InputRequest, pass in max_output_length: {max_output_length} for each sample")
-
   print(f"The dataset contains {len(tokenized_dataset)} samples.")
   print(f"The filtered dataset contains {len(filtered_dataset)} samples.")
 
   return filtered_dataset
 
 
 def sample_requests(
-    dataset: List[InputRequest],
+    dataset: List[tuple[str]],
+    tokenizer: Any,
     num_requests: int,
+    max_output_length: Optional[int] = None,
     oversample_multiplier: float=1.2,
 ) -> List[InputRequest]:
 
+  # Original dataset size
+  n = len(dataset)
+
   # Create necessary number of requests even if bigger than dataset size
   sampled_indices = random.sample(
-      range(len(dataset)), min(int(num_requests * oversample_multiplier), len(dataset)))
+      range(n), min(int(num_requests * oversample_multiplier), n))
 
   if num_requests > len(sampled_indices):
-    print(f"Number of requests {num_requests} is larger than size of dataset {len(dataset)}.\n",
+    print(f"Number of requests {num_requests} is larger than size of dataset {n}.\n",
           f"Repeating data to meet number of requests.\n")
     sampled_indices = sampled_indices * int(np.ceil(num_requests / len(sampled_indices)))
 
   print(f"{len(sampled_indices)=}")
+  # some of these will be filtered out, so sample more than we need
+  dataset = [dataset[i] for i in sampled_indices]
+
+  tokenized_dataset = tokenize_dataset(dataset, tokenizer)
+
+  input_requests = filter_dataset(tokenized_dataset, max_output_length)
 
-  # Sample the requests.
-  sampled_requests = random.sample(dataset, num_requests)
+  if len(input_requests) > num_requests:
+    input_requests = random.sample(input_requests, num_requests)
 
-  return sampled_requests
+  return input_requests


async def get_request(
@@ -490,25 +484,21 @@ def main(args: argparse.Namespace):
     input_requests = mock_requests(args.total_mock_requests)  # e.g. [("AB", 2, "AB", 3)]
   else:
     if args.dataset == "openorca":
-      dataset = load_openorca_dataset(
-          args.dataset_path,
-          tokenizer,
-          args.max_output_length
-      )
+      dataset = load_openorca_dataset(args.dataset_path)
     elif args.dataset == "sharegpt":
       dataset = load_sharegpt_dataset(
           args.dataset_path,
-          tokenizer,
           args.conversation_starter,
-          args.max_output_length
       )
 
+    # A given args.max_output_length value is the max generation step;
+    # when args.max_output_length defaults to None, each sample's golden
+    # output length is used to decide the generation step.
     input_requests = sample_requests(
-        dataset,
-        args.num_prompts,
+        dataset=dataset,
+        tokenizer=tokenizer,
+        num_requests=args.num_prompts,
+        max_output_length=args.max_output_length
    )
 
   if args.warmup_first:
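For reference, a hedged sketch of how the refactored helpers compose after this commit, mirroring the main() call sites above. The dataset path and tokenizer name are placeholders, and it assumes benchmark_serving.py is importable (e.g. run from the benchmarks/ directory) and that get_tokenizer returns an object with the batch tokenize() method the in-diff comments mention.

# Hypothetical driver; file path and tokenizer name are placeholders,
# not values from the commit.
from benchmark_serving import get_tokenizer, load_openorca_dataset, sample_requests

tokenizer = get_tokenizer("some-tokenizer")        # placeholder tokenizer name
dataset = load_openorca_dataset("open_orca.json")  # raw (prompt, output) tuples
input_requests = sample_requests(
    dataset=dataset,
    tokenizer=tokenizer,
    num_requests=1000,
    max_output_length=None,  # None => each sample's golden output length is used
)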
