train_tweet_like_count_predictor.py
import os
from datetime import datetime
import pandas as pd
import json
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding
from datasets import Dataset
# Disable tokenizer parallelism to avoid fork-related warnings from the tokenizers library
os.environ["TOKENIZERS_PARALLELISM"] = "false"
PKL_FILE_NAME = 'like_predictor_data.pkl'
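# Labels bucket like counts into 11 classes: 0-9 are exact counts, 10 means "10 or more"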
ID2LABEL = {0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9', 10: '10+'}
LABEL2ID = {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, '10+': 10}
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-cased", num_labels=11, id2label=ID2LABEL, label2id=LABEL2ID
)
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")
def get_data():
    # Load the cached DataFrame if it exists
    if os.path.exists(PKL_FILE_NAME):
        print('Loading cached tweet data from file...')
        df = pd.read_pickle(PKL_FILE_NAME)
    else:
        print('Loading and parsing tweets...')
        # Load your tweets.js file from the Twitter archive export
        with open('tweets.js', 'r') as f:
            tweets = f.read()
        # Strip the JS assignment prefix so the rest parses as JSON
        tweets = tweets.replace('window.YTD.tweets.part0 = ', '')
        data = json.loads(tweets)
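        # Expected shape of tweets.js (illustrative sketch; only the fields
        # read below are shown):
        #   window.YTD.tweets.part0 = [
        #     {"tweet": {"created_at": "Mon Jan 01 12:00:00 +0000 2024",
        #                "full_text": "...", "favorite_count": "3", "lang": "en"}},
        #     ...
        #   ]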
        # Extract relevant fields and convert 'created_at' to a timestamp
        tweets_data = []
        for tweet in data:
            created_at = datetime.strptime(tweet['tweet']['created_at'], '%a %b %d %H:%M:%S +0000 %Y')
            year = created_at.year
            month = created_at.month
            day = created_at.day
            hour = created_at.hour
            minute = created_at.minute
            timestamp = created_at.timestamp()
            # Cap the label at 10; 10 means 10 or more likes
            label = 10 if int(tweet['tweet']['favorite_count']) > 10 else int(tweet['tweet']['favorite_count'])
            tweets_data.append({
                'full_text': tweet['tweet']['full_text'],
                'timestamp': timestamp,
                'lang': tweet['tweet']['lang'],
                'like_count': tweet['tweet']['favorite_count'],
                'year': year,
                'month': month,
                'day': day,
                'hour': hour,
                'minute': minute,
                'labels': label
            })
        df = pd.DataFrame(tweets_data)
        # Cache the parsed DataFrame for future runs
        print('Saving...', PKL_FILE_NAME)
        df.to_pickle(PKL_FILE_NAME)
    return df
if __name__ == '__main__':
    data = get_data()
    # Tokenize the tweet texts (truncating to the model's max length as a safeguard)
    texts = data['full_text'].tolist()
    encoded_inputs = tokenizer(texts, truncation=True)
    # Labels
    labels = data['labels'].tolist()
    ds = Dataset.from_dict({'input_ids': encoded_inputs['input_ids'], 'attention_mask': encoded_inputs['attention_mask'], 'labels': labels})
# print("ds[0]", ds[0])
# Split the dataset into a training and a validation dataset
train_dataset = ds.train_test_split(test_size=0.2, seed=42)['train']
val_dataset = ds.train_test_split(test_size=0.2, seed=42)['test']
print("train_dataset[0]", train_dataset[0])
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    # Define the training arguments
    training_args = TrainingArguments(
        output_dir='./tweet-like-count-predictor-hf',  # output directory
        hub_model_id='gaborcselle/tweet-like-count-predictor',  # model name on the Hub
        logging_dir='./logs',  # directory for storing logs
        push_to_hub=True
    )
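    # Epochs, batch size, and learning rate are left at the Trainer defaults
    # (3 epochs, per-device batch size 8, learning rate 5e-5 at the time of writing).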
    # Define the trainer parameters
    trainer = Trainer(
        model=model,                  # the instantiated 🤗 Transformers model to be trained
        args=training_args,           # training arguments, defined above
        train_dataset=train_dataset,  # training dataset
        eval_dataset=val_dataset,     # evaluation dataset
        data_collator=data_collator
    )
    # Train the model
    trainer.train()
    # Evaluate the model
    trainer.evaluate()
    # Save the model locally
    trainer.save_model('./tweet-like-count-predictor-hf-model')
    # Upload to the Hugging Face Hub (already handled by push_to_hub=True above)
    # trainer.push_to_hub()
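    # Example (a sketch): load the saved checkpoint and predict the
    # like-count bucket for a sample tweet; the sample text is hypothetical.
    from transformers import pipeline
    classifier = pipeline(
        "text-classification",
        model="./tweet-like-count-predictor-hf-model",
        tokenizer=tokenizer,
    )
    print(classifier("Just shipped a new side project!"))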