causal_lm_flops_counter.py
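"""Count the FLOPs of a Hugging Face causal language model with fvcore.

Builds a dummy token tensor, traces the model with
fvcore.nn.FlopCountAnalysis, and pretty-prints the resulting FLOP
counts by module, by operator, or both (the default).
"""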
from fvcore.nn import FlopCountAnalysis
from transformers import AutoModelForCausalLM
import argparse
import torch
from pprint import PrettyPrinter

# Pretty-printer for the (potentially deeply nested) FLOP dictionaries.
pp = PrettyPrinter(indent=4)


def get_dummy_input(batch_size=1, sequence_length=512):
    # Create a dummy input-ids tensor with the desired shape.
    input_tensor = torch.ones(batch_size, sequence_length, dtype=torch.long)
    return input_tensor


def parse_args():
    parser = argparse.ArgumentParser(description="Count FLOPs of a causal language model.")
    parser.add_argument("--model_stub", default="HuggingFaceH4/tiny-random-LlamaForCausalLM", type=str, help="The model stub to count FLOPs for.")
    parser.add_argument("--batch_size", default=1, type=int, help="The batch size to use for the input tensor.")
    parser.add_argument("--sequence_length", default=512, type=int, help="The sequence length to use for the input tensor.")
    parser.add_argument("--by_module", default=False, action="store_true", help="Count FLOPs by module.")
    parser.add_argument("--by_operator", default=False, action="store_true", help="Count FLOPs by operator.")
    return parser.parse_args()


def main():
    args = parse_args()
    model = AutoModelForCausalLM.from_pretrained(args.model_stub)
    model.eval()  # inference mode; FlopCountAnalysis traces a forward pass
    input_tensor = get_dummy_input(args.batch_size, args.sequence_length)
    flops = FlopCountAnalysis(model, input_tensor)
    print("FLOPs:")
    if args.by_module:
        pp.pprint(flops.by_module())
    elif args.by_operator:
        pp.pprint(flops.by_operator())
    else:
        pp.pprint(flops.by_module_and_operator())


if __name__ == "__main__":
    main()
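
# Example invocations (a sketch; the flag names and the default model
# stub are the ones defined in parse_args above):
#
#   python causal_lm_flops_counter.py
#   python causal_lm_flops_counter.py --by_operator
#   python causal_lm_flops_counter.py --batch_size 4 --sequence_length 1024 --by_module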