-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsample.py
167 lines (141 loc) · 6.24 KB
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import json
import random
import torch
from time import sleep
from transformers import MarkupLMProcessor, MarkupLMForQuestionAnswering
from playwright.sync_api import sync_playwright
import os
class MarkupLMExtractor:
    """Question-answering extractor that runs a fine-tuned MarkupLM model
    over live HTML fetched with Playwright.

    Pipeline: fetch_html() renders the page in headless Chromium with
    anti-bot countermeasures, then process_url() feeds the raw HTML plus a
    query string through the MarkupLM QA head and decodes the answer span.
    """

    def __init__(self, model_path):
        """Load the MarkupLM processor and QA model from *model_path*.

        Args:
            model_path: Local directory containing a saved checkpoint
                (processor/tokenizer files plus model weights).

        Raises:
            ValueError: If *model_path* does not exist on disk.
            Exception: Re-raises whatever `from_pretrained` fails with,
                after printing the error.
        """
        print("🔄 Loading model and processor...")
        # Verify model path exists
        if not os.path.exists(model_path):
            raise ValueError(f"Model path {model_path} does not exist!")
        # Load processor and model
        try:
            self.processor = MarkupLMProcessor.from_pretrained(model_path)
            # Let the processor parse raw HTML strings itself (extracts
            # nodes + xpaths) rather than expecting pre-parsed input.
            self.processor.parse_html = True
            self.model = MarkupLMForQuestionAnswering.from_pretrained(model_path)
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print(f"📱 Using device: {self.device}")
            self.model.to(self.device)
            # Inference only — disable dropout etc.
            self.model.eval()
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise

    def fetch_html(self, url):
        """
        Fetch HTML content using Playwright with better anti-bot handling.

        Returns the page HTML as a string, or None on fetch error or when
        an Amazon CAPTCHA interstitial is detected.
        """
        with sync_playwright() as p:
            browser = p.chromium.launch(headless=True)
            context = browser.new_context(
                # Add realistic browser configuration
                viewport={'width': 1920, 'height': 1080},
                user_agent='Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/121.0.0.0 Safari/537.36',
                locale='en-US',
                timezone_id='America/New_York',
                geolocation={'latitude': 40.7128, 'longitude': -74.0060},
                permissions=['geolocation']
            )
            # Add common cookies (Amazon-specific session/preference cookies
            # so the page renders like a returning US visitor)
            context.add_cookies([
                {'name': 'session-id', 'value': '123-1234567-1234567', 'domain': '.amazon.com', 'path': '/'},
                {'name': 'i18n-prefs', 'value': 'USD', 'domain': '.amazon.com', 'path': '/'},
                {'name': 'sp-cdn', 'value': 'L5Z9:US', 'domain': '.amazon.com', 'path': '/'}
            ])
            page = context.new_page()
            # Set realistic headers
            page.set_extra_http_headers({
                'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8',
                'Accept-Language': 'en-US,en;q=0.9',
                'Accept-Encoding': 'gzip, deflate, br',
                'Connection': 'keep-alive',
                'Upgrade-Insecure-Requests': '1',
                'Sec-Fetch-Dest': 'document',
                'Sec-Fetch-Mode': 'navigate',
                'Sec-Fetch-Site': 'none',
                'Sec-Fetch-User': '?1'
            })
            try:
                print(f"\n🌐 Fetching URL: {url}")
                page.goto(url, timeout=90000, wait_until='domcontentloaded')
                # Simulate human scrolling behavior: three viewport-height
                # scrolls with randomized pauses, then back to the top.
                for _ in range(3):
                    page.evaluate("window.scrollBy(0, window.innerHeight)")
                    page.wait_for_timeout(random.randint(800, 1200))
                page.evaluate("window.scrollTo(0, 0)")
                page.wait_for_timeout(random.randint(800, 1200))
                html = page.content()
                # Check for CAPTCHA (Amazon's robot-check interstitial)
                if "Robot Check" in html or "Enter the characters you see below" in html:
                    print("⚠️ CAPTCHA detected!")
                    return None
                return html
            except Exception as e:
                print(f"❌ Error fetching URL: {str(e)}")
                return None
            finally:
                # Always release the browser, even on error paths.
                browser.close()

    def process_url(self, url, query_text):
        """Fetch *url* and answer *query_text* against its HTML.

        Args:
            url: Page to fetch and query.
            query_text: Question string passed to the QA model
                (here a JSON-like extraction schema the model was
                fine-tuned on — see main()).

        Returns:
            A dict {"input": {"text", "html"}, "output": answer_text},
            or None when the fetch failed / was blocked.
        """
        html_content = self.fetch_html(url)
        if html_content is None:
            return None
        print("\n🔍 Processing query with MarkupLM model...")
        encoding = self.processor(
            html_strings=html_content,
            questions=query_text,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        # Move every tensor in the encoding onto the model's device.
        encoding = {k: v.to(self.device) for k, v in encoding.items()}
        with torch.no_grad():
            outputs = self.model(**encoding)
        # Extractive QA: pick the highest-scoring start/end token positions
        # (batch size is 1, hence index [0]).
        start_logits = outputs.start_logits[0]
        end_logits = outputs.end_logits[0]
        start_idx = torch.argmax(start_logits).item()
        end_idx = torch.argmax(end_logits).item()
        # Guard against a degenerate prediction where end precedes start.
        if end_idx < start_idx:
            start_idx, end_idx = end_idx, start_idx
        input_ids = encoding["input_ids"][0]
        answer_ids = input_ids[start_idx : end_idx + 1]
        answer_text = self.processor.tokenizer.decode(answer_ids, skip_special_tokens=True)
        return {
            "input": {
                "text": query_text,
                "html": html_content
            },
            "output": answer_text
        }
def main():
    """Entry point: load the fine-tuned MarkupLM checkpoint and extract
    product prices from an Amazon search-results page, saving the result
    to results.json.
    """
    model_path = "./markuplm_amazon_qa_token_lora_final"
    # Preflight: verify the checkpoint directory has the files needed by
    # from_pretrained before paying the cost of loading.
    required_files = ['config.json', 'special_tokens_map.json', 'tokenizer_config.json']
    missing_files = [f for f in required_files if not os.path.exists(os.path.join(model_path, f))]
    # Fix: newer `transformers` checkpoints store weights as
    # 'model.safetensors' instead of 'pytorch_model.bin'; requiring the
    # .bin file rejected valid checkpoints.  Accept either weight file.
    weight_files = ['pytorch_model.bin', 'model.safetensors']
    if not any(os.path.exists(os.path.join(model_path, f)) for f in weight_files):
        missing_files.append(' or '.join(weight_files))
    if missing_files:
        print(f"Missing required model files: {missing_files}")
        return
    try:
        extractor = MarkupLMExtractor(model_path)
        url = ("https://www.amazon.com/s?k=faber+castell+colored+pencils&"
               "crid=ADRA090J7SD4&sprefix=%2Caps%2C211&ref=nb_sb_ss_recent_2_0_recent")
        # Extraction schema the model was fine-tuned to answer: list the
        # price of every product on the results page.
        query_text = """{
    products[] {
        product_price
    }
}"""
        results = extractor.process_url(url, query_text)
        if results:
            print("\n✅ Results:")
            print(json.dumps(results, indent=2))
            with open('results.json', 'w') as f:
                json.dump(results, f, indent=2)
            print("\n💾 Results saved to 'results.json'")
        else:
            print("No results obtained.")
    except Exception as e:
        # Top-level boundary: report and exit rather than traceback.
        print(f"Error in main: {str(e)}")
# Run the extraction pipeline only when executed as a script (not on import).
if __name__ == "__main__":
    main()