Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
jamescalam committed Jan 27, 2024
1 parent 96574d7 commit d37d176
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 572 deletions.
91 changes: 26 additions & 65 deletions arxiv_bot/knowledge_base/constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import tiktoken
from tqdm.auto import tqdm

import logging
from arxiv_bot.utils.logger import logger
from typing import Union, Any, Optional

paper_id_re = re.compile(r'https://arxiv.org/abs/(\d+\.\d+)')
Expand Down Expand Up @@ -82,7 +82,7 @@ def init_extractor(
openai_api_key = openai_api_key or os.environ['OPENAI_API_KEY']
# instantiate the OpenAI API wrapper
llm = OpenAI(
model_name='text-davinci-003',
model_name='gpt-3.5-turbo-instruct',
openai_api_key=openai_api_key,
max_tokens=max_tokens,
temperature=0.0
Expand Down Expand Up @@ -164,6 +164,7 @@ class Arxiv:
Extracted:
"""
llm = None
get_id = re.compile(r'(?<=arxiv:)\d{4}.\d{5}')

def __init__(self, paper_id: str):
"""Object to handle the extraction of an ArXiv paper and its
Expand All @@ -188,7 +189,7 @@ def load(self, save: bool = False):
"""
# check if pdf already exists
if os.path.exists(f'papers/{self.id}.json'):
print(f'Loading papers/{self.id}.json from file')
logger.info(f'Loading papers/{self.id}.json from file')
with open(f'papers/{self.id}.json', 'r') as fp:
attributes = json.loads(fp.read())
for key, value in attributes.items():
Expand All @@ -204,7 +205,7 @@ def load(self, save: bool = False):
if save:
self.save()

def get_refs(self, extractor, text_splitter):
def get_refs(self):
"""Get the references for the paper.
:param extractor: The LLMChain extractor model
Expand All @@ -214,41 +215,14 @@ def get_refs(self, extractor, text_splitter):
:return: The references for the paper
:rtype: list
"""
logger.info(f'Extracting references for {self.id}')
if len(self.references) == 0:
self._download_refs(extractor, text_splitter)
content = self.content.lower()
matches = self.get_id.findall(content)
matches = list(set(matches))
self.references = [{"id": m} for m in matches]
logger.info(f'Found {len(self.references)} references')
return self.references

def _download_refs(self, extractor, text_splitter):
"""Download the references for the paper. Stores them in
the self.references attribute.
:param extractor: The LLMChain extractor model
:type extractor: LLMChain
:param text_splitter: The text splitter to use
:type text_splitter: TokenTextSplitter
"""
# get references section of paper
refs = self.refs_re.split(self.content)[-1]
# we don't need the full thing, just the first page
refs_page = text_splitter.split_text(refs)[0]
# use LLM extractor to extract references
out = extractor.run(refs=refs_page)
out = out.split('\n')
out = [o for o in out if o != '']
# with list of references, find the paper IDs
ids = [get_paper_id(o) for o in out]
# clean up into JSONL type format
out = [o.split(' | ') for o in out]
# in case we're missing some fields
out = [o for o in out if len(o) == 3]
meta = [{
'id': _id,
'title': o[0],
'authors': o[1],
'year': o[2]
} for o, _id in zip(out, ids) if _id is not None]
logging.debug(f"Extracted {len(meta)} references")
self.references = meta

def _convert_pdf_to_text(self):
"""Convert the PDF to text and store it in the self.content
Expand Down Expand Up @@ -292,7 +266,7 @@ def _download_meta(self):
self.summary = result.summary
self.title = result.title
self.updated = result.updated.strftime('%Y%m%d')
logging.debug(f"Downloaded metadata for paper '{self.id}'")
logger.info(f"Downloaded metadata for paper '{self.id}'")

def save(self):
"""Save the paper to a local JSON file.
Expand All @@ -316,12 +290,14 @@ def save_chunks(
"""
if not os.path.exists(path):
os.makedirs(path)
with open(f'{path}/{self.id}.jsonl', 'w') as fp:
with open(f'{path}/{self.id}-chunks.jsonl', 'w') as fp:
for chunk in self.dataset:
if include_metadata:
chunk.update(self.get_meta())
fp.write(json.dumps(chunk) + '\n')
logging.debug(f"Saved paper to '{path}/{self.id}.jsonl'")
logger.info(f"Saved paper to '{path}/{self.id}.jsonl'")
with open(f"{path}/{self.id}.jsonl", "w") as fp:
fp.write(json.dumps(self.__dict__()))

def get_meta(self):
"""Returns the meta information for the paper.
Expand All @@ -348,7 +324,7 @@ def chunker(self, chunk_size=300):
'chunk-id': str(i),
'chunk': chunk
})
logging.debug(f"Split paper into {len(paper_chunks)} chunks")
logger.info(f"Split paper into {len(paper_chunks)} chunks")
self.dataset = langchain_dataset

def _clean_text(self, text):
Expand Down Expand Up @@ -380,8 +356,6 @@ class ArxivGraphScraper:
def __init__(
self,
paper_id: str,
extractor: Any,
text_splitter: Any,
levels: int = 3,
save_location: str = 'chunks',
verbose: bool = False
Expand All @@ -402,23 +376,9 @@ def __init__(
if not os.path.exists(self.save_location):
os.mkdir(self.save_location)
# save objects required for ref extraction
self.extractor = extractor
self.text_splitter = text_splitter
ids = [paper_id]
for level in tqdm(range(levels)):
ids = self._build_papers(ids)
# set logging level
if verbose:
logging.basicConfig(
format='[%(filename)s:%(lineno)d] %(message)s',
level=logging.DEBUG
)
# further logging options for when working with ipython+jupyter
console = logging.StreamHandler()
console.setLevel(logging.DEBUG)
formatter = logging.Formatter('[%(filename)s:%(lineno)d] %(message)s')
console.setFormatter(formatter)
logging.getLogger('').addHandler(console)

def _create_paper(self, paper_id: str):
"""Create a paper object from a paper ID.
Expand All @@ -428,15 +388,14 @@ def _create_paper(self, paper_id: str):
:return: The paper object
:rtype: ArxivPaper
"""
print(f"Loading '{paper_id}'")
logger.info(f"Loading '{paper_id}'")
paper = Arxiv(paper_id)
paper.load()
paper.get_meta()
logger.info(f"Getting references for paper '{paper.title}'")
# get references
refs = paper.get_refs(
extractor=self.extractor,
text_splitter=self.text_splitter
)
refs = paper.get_refs()
logger.info(f"Found {len(refs)} references")
paper.chunker()
paper.save_chunks(include_metadata=True, path=self.save_location)
return paper
Expand All @@ -451,10 +410,12 @@ def _build_papers(self, paper_ids: list):
"""
ids = []
for _id in tqdm(paper_ids):
paper = self._create_paper(_id)
ids.extend([r['id'] for r in paper.references])
try:
paper = self._create_paper(_id)
ids.extend([r['id'] for r in paper.references])
except Exception as e:
logger.warning(e)
original_ids = set(paper_ids)
new_ids = set(ids)
new_ids = list(new_ids - original_ids)
logging.debug(f"Found {len(new_ids)} new papers")
return new_ids
44 changes: 44 additions & 0 deletions arxiv_bot/utils/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import logging

import colorlog


class CustomFormatter(colorlog.ColoredFormatter):
def __init__(self):
super().__init__(
"%(log_color)s%(asctime)s %(levelname)s %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
log_colors={
"DEBUG": "cyan",
"INFO": "green",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "bold_red",
},
reset=True,
style="%",
)


def setup_custom_logger(name):
formatter = CustomFormatter()

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)

logging.basicConfig(
level=logging.INFO,
datefmt="%Y-%m-%d %H:%M:%S",
format="%(asctime)s %(levelname)s %(message)s",
force=True,
)

logger = logging.getLogger(name)
logger.handlers = []
logger.addHandler(console_handler)
logger.propagate = False

return logger


logger = setup_custom_logger("__name__")
Loading

0 comments on commit d37d176

Please sign in to comment.