Skip to content

Commit

Permalink
add top_p
Browse files Browse the repository at this point in the history
Signed-off-by: SumanthRH <[email protected]>
  • Loading branch information
SumanthRH committed Feb 9, 2025
1 parent d796ceb commit c5f7c88
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
21 changes: 19 additions & 2 deletions skythought/skythought_evals/inference_and_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,12 @@ def fetch_responses_ray(conversations, max_tokens, temp, args):
ds = ds.repartition(num_partitions=num_replicas)
workload = EvalWorkload(
dataset=ds,
sampling_params={"n": args.n, "max_tokens": max_tokens, "temperature": temp},
sampling_params={
"n": args.n,
"max_tokens": max_tokens,
"temperature": temp,
"top_p": args.top_p,
},
)
pipeline = Pipeline(
engine_cfg,
Expand Down Expand Up @@ -138,7 +143,7 @@ def inference(llm, conversations, max_tokens, temp, args):
responses = [Response.from_openai_response(response) for response in responses]
else:
sampling_params = SamplingParams(
max_tokens=max_tokens, temperature=temp, n=args.n
max_tokens=max_tokens, temperature=temp, n=args.n, top_p=args.top_p
)
responses = llm.chat(
messages=conversations, sampling_params=sampling_params, use_tqdm=True
Expand Down Expand Up @@ -629,6 +634,12 @@ def main():
"'auto' refers to automatically inferring dtype for the model",
default="float32",
)
parser.add_argument(
"--top_p",
type=float,
default=1,
help="Sampling parameter `top_p`",
)
args = parser.parse_args()
# load ray config
if args.use_ray:
Expand Down Expand Up @@ -660,6 +671,12 @@ def main():

temperatures = [1] if args.model.startswith("openai/o1") else args.temperatures

if args.top_p < 1 and args.model.startswith("openai/o1"):
print(
"OpenAI o1 models do not support `top_p` sampling. Resetting `top_p` to 1"
)
args.top_p = 1

print(f"Temperature: {temperatures}")
max_tokens = args.max_tokens
if temperatures == [0] and args.n > 1:
Expand Down
2 changes: 1 addition & 1 deletion skythought/skythought_evals/util/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def _pass_at_k(n, c, k):
"""
if n - c < k:
return 1.0
return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
return float(1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))


def pass_at_k(N: int, temp_to_scores: Dict[str, Dict[str, Any]]):
Expand Down

0 comments on commit c5f7c88

Please sign in to comment.