Skip to content

Commit

Permalink
Fix exploitable regexes in Nougat and GPTSan/GPTNeoXJapanese (#36121)
Browse files Browse the repository at this point in the history
* Fix potential regex catastrophic backtracking in NougatTokenizerFast

The original regex pattern in tokenization_nougat_fast.py was vulnerable to
catastrophic backtracking due to greedy quantifiers and nested alternations.
This commit replaces it with a more efficient pattern that:

1. Uses explicit character classes instead of the dot (`.`) wildcard
2. Handles whitespace more precisely
3. Avoids unnecessary backtracking
4. Supports both lowercase and uppercase roman numerals
5. Maintains the same functionality while being more robust

* Try another regex

* Trying deepseek's answer

* Start with a simplification

* Another simplification

* Just rewrite the whole function myself

* Fix gptneox and gptsan

* Simplify the regex even further

* Tighten up the price regex a little

* Add possessive version of the regex

* Fix regex

* Much cleaner regexes

---------

Co-authored-by: openhands <[email protected]>
  • Loading branch information
Rocketknight1 and openhands-agent authored Feb 21, 2025
1 parent 547911e commit 92c5ca9
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import json
import os
import re
import sys
from typing import List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -407,9 +408,23 @@ def __init__(self, vocab, ids_to_tokens, emoji):
self.content_repatter5 = re.compile(
r"(明治|大正|昭和|平成|令和|㍾|㍽|㍼|㍻|\u32ff)\d{1,2}年(0?[1-9]|1[0-2])月(0?[1-9]|[12][0-9]|3[01])日(\d{1,2}|:|\d{1,2}時|\d{1,2}分|\(日\)|\(月\)|\(火\)|\(水\)|\(木\)|\(金\)|\(土\)|㈰|㈪|㈫|㈬|㈭|㈮|㈯)*"
)
self.content_repatter6 = re.compile(
r"((0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*億)*((0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*万)*((0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*千)*(0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*(千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+(\(税込\)|\(税抜\)|\+tax)*"
)
# The original version of this regex displays catastrophic backtracking behaviour. We avoid this using
# possessive quantifiers in Py >= 3.11. In versions below this, we avoid the vulnerability using a slightly
# different regex that should generally have the same behaviour in most non-pathological cases.
if sys.version_info >= (3, 11):
self.content_repatter6 = re.compile(
r"(?:\d,\d{3}|[\d億])*+"
r"(?:\d,\d{3}|[\d万])*+"
r"(?:\d,\d{3}|[\d千])*+"
r"(?:千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+"
r"(?:\(税込\)|\(税抜\)|\+tax)*"
)
else:
self.content_repatter6 = re.compile(
r"(?:\d,\d{3}|[\d億万千])*"
r"(?:千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+"
r"(?:\(税込\)|\(税抜\)|\+tax)*"
)
keisen = "─━│┃┄┅┆┇┈┉┊┋┌┍┎┏┐┑┒┓└┕┖┗┘┙┚┛├┝┞┟┠┡┢┣┤┥┦┧┨┩┪┫┬┭┮┯┰┱┲┳┴┵┶┷┸┹┺┻┼┽┾┿╀╁╂╃╄╅╆╇╈╉╊╋╌╍╎╏═║╒╓╔╕╖╗╘╙╚╛╜╝╞╟╠╡╢╣╤╥╦╧╨╩╪╫╬╭╮╯╰╱╲╳╴╵╶╷╸╹╺╻╼╽╾╿"
blocks = "▀▁▂▃▄▅▆▇█▉▊▋▌▍▎▏▐░▒▓▔▕▖▗▘▙▚▛▜▝▞▟"
self.content_trans1 = str.maketrans({k: "<BLOCK>" for k in keisen + blocks})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import json
import os
import re
import sys
from typing import Optional, Tuple

import numpy as np
Expand Down Expand Up @@ -230,9 +231,23 @@ def __init__(self, vocab, ids_to_tokens, emoji):
self.content_repatter5 = re.compile(
r"(明治|大正|昭和|平成|令和|㍾|㍽|㍼|㍻|\u32ff)\d{1,2}年(0?[1-9]|1[0-2])月(0?[1-9]|[12][0-9]|3[01])日(\d{1,2}|:|\d{1,2}時|\d{1,2}分|\(日\)|\(月\)|\(火\)|\(水\)|\(木\)|\(金\)|\(土\)|㈰|㈪|㈫|㈬|㈭|㈮|㈯)*"
)
self.content_repatter6 = re.compile(
r"((0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*億)*((0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*万)*((0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*千)*(0|[1-9]\d*|[1-9]\d{0,2}(,\d{3})+)*(千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+(\(税込\)|\(税抜\)|\+tax)*"
)
# The original version of this regex displays catastrophic backtracking behaviour. We avoid this using
# possessive quantifiers in Py >= 3.11. In versions below this, we avoid the vulnerability using a slightly
# different regex that should generally have the same behaviour in most non-pathological cases.
if sys.version_info >= (3, 11):
self.content_repatter6 = re.compile(
r"(?:\d,\d{3}|[\d億])*+"
r"(?:\d,\d{3}|[\d万])*+"
r"(?:\d,\d{3}|[\d千])*+"
r"(?:千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+"
r"(?:\(税込\)|\(税抜\)|\+tax)*"
)
else:
self.content_repatter6 = re.compile(
r"(?:\d,\d{3}|[\d億万千])*"
r"(?:千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+"
r"(?:\(税込\)|\(税抜\)|\+tax)*"
)
keisen = "─━│┃┄┅┆┇┈┉┊┋┌┍┎┏┐┑┒┓└┕┖┗┘┙┚┛├┝┞┟┠┡┢┣┤┥┦┧┨┩┪┫┬┭┮┯┰┱┲┳┴┵┶┷┸┹┺┻┼┽┾┿╀╁╂╃╄╅╆╇╈╉╊╋╌╍╎╏═║╒╓╔╕╖╗╘╙╚╛╜╝╞╟╠╡╢╣╤╥╦╧╨╩╪╫╬╭╮╯╰╱╲╳╴╵╶╷╸╹╺╻╼╽╾╿"
blocks = "▀▁▂▃▄▅▆▇█▉▊▋▌▍▎▏▐░▒▓▔▕▖▗▘▙▚▛▜▝▞▟"
self.content_trans1 = str.maketrans({k: "<BLOCK>" for k in keisen + blocks})
Expand Down
39 changes: 15 additions & 24 deletions src/transformers/models/nougat/tokenization_nougat_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,26 +113,17 @@ def normalize_list_like_lines(generation):
normalization adjusts the bullet point style and nesting levels based on the captured patterns.
"""

# This matches lines starting with - or *, not followed by - or * (lists)
# that are then numbered by digits \d or roman numerals (one or more)
# and then, optional additional numbering of this line is captured
# this is then fed to re.finditer.
pattern = r"(?:^)(-|\*)?(?!-|\*) ?((?:\d|[ixv])+ )?.+? (-|\*) (((?:\d|[ixv])+)\.(\d|[ixv]) )?.*(?:$)"

for match in reversed(list(re.finditer(pattern, generation, flags=re.I | re.M))):
start, stop = match.span()
delim = match.group(3) + " "
splits = match.group(0).split(delim)
lines = generation.split("\n")
output_lines = []
for line_no, line in enumerate(lines):
match = re.search(r". ([-*]) ", line)
if not match or line[0] not in ("-", "*"):
output_lines.append(line)
continue # Doesn't fit the pattern we want, no changes
delim = match.group(1) + " "
splits = line.split(delim)[1:]
replacement = ""

if match.group(1) is not None:
splits = splits[1:]
delim1 = match.group(1) + " "
else:
delim1 = ""
continue # Skip false positives

pre, post = generation[:start], generation[stop:]
delim1 = line[0] + " "

for i, item in enumerate(splits):
level = 0
Expand All @@ -144,15 +135,15 @@ def normalize_list_like_lines(generation):
level = potential_numeral.count(".")

replacement += (
("\n" if i > 0 else "") + ("\t" * level) + (delim if i > 0 or start == 0 else delim1) + item.strip()
("\n" if i > 0 else "") + ("\t" * level) + (delim if i > 0 or line_no == 0 else delim1) + item.strip()
)

if post == "":
post = "\n"
if line_no == len(lines) - 1: # If this is the last line in the generation
replacement += "\n" # Add an empty line to the end of the generation

generation = pre + replacement + post
output_lines.append(replacement)

return generation
return "\n".join(output_lines)


def find_next_punctuation(text: str, start_idx=0):
Expand Down

0 comments on commit 92c5ca9

Please sign in to comment.