forked from l-xin/rhms
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
unknown
committed
Mar 14, 2024
1 parent
8d1e45e
commit 6862c8e
Showing
43 changed files
with
144,752 additions
and
140,964 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,43 @@ | ||
# HMS: A Hierarchical Solver with Dependency-Enhanced Understanding for Math Word Problem | ||
Source code for paper *HMS: A Hierarchical Solver with Dependency-Enhanced Understanding for Math Word Problem*. | ||
The source code for paper "Learning Relation-Enhanced Hierarchical Solver for Math Word Problems" is coming soon. | ||
|
||
## Dependencies | ||
- python >= 3.6 | ||
|
||
- stanfordcorenlp | ||
- torch | ||
|
||
## Usage | ||
Preprocess dataset | ||
```bash | ||
python3 src/dataprocess/math23k.py | ||
``` | ||
Train and test model | ||
```bash | ||
python3 src/main.py | ||
``` | ||
For running arguments, please refer to [src/config.py](src/config.py). | ||
# Learning Relation-Enhanced Hierarchical Solver for Math Word Problems | ||
Source code for paper *Learning Relation-Enhanced Hierarchical Solver for Math Word Problems*. | ||
|
||
## Dependencies | ||
- python >= 3.6 | ||
|
||
- stanfordcorenlp | ||
- torch | ||
|
||
## Usage | ||
- Preprocess dataset | ||
```bash | ||
python3 src/dataprocess/math23k.py | ||
python3 src/dataprocess/similarity.py | ||
``` | ||
- Train and test model | ||
```bash | ||
python3 src/rhms/main.py | ||
``` | ||
For running arguments, please refer to [src/rhms/config.py](src/rhms/config.py). | ||
|
||
## Citation | ||
If you find our work helpful, please consider citing our paper. | ||
``` | ||
@article{lin2023learning, | ||
title={Learning Relation-Enhanced Hierarchical Solver for Math Word Problems}, | ||
author={Lin, Xin and Huang, Zhenya and Zhao, Hongke and Chen, Enhong and Liu, Qi and Lian, Defu and Li, Xin and Wang, Hao}, | ||
journal={IEEE Transactions on Neural Networks and Learning Systems}, | ||
year={2023}, | ||
publisher={IEEE} | ||
} | ||
``` | ||
``` | ||
@inproceedings{lin2021hms, | ||
title={HMS: A Hierarchical Solver with Dependency-Enhanced Understanding for Math Word Problem}, | ||
author={Lin, Xin and Huang, Zhenya and Zhao, Hongke and Chen, Enhong and Liu, Qi and Wang, Hao and Wang, Shijin}, | ||
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, | ||
volume={35}, | ||
number={5}, | ||
pages={4232--4240}, | ||
year={2021} | ||
} | ||
``` |
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,105 +1,105 @@ | ||
# -*- encoding:utf-8 -*- | ||
|
||
def infix_to_postfix(expression): | ||
st = list() | ||
res = list() | ||
priority = {"+": 0, "-": 0, "*": 1, "/": 1, "^": 2} | ||
for e in expression: | ||
if e in ["(", "["]: | ||
st.append(e) | ||
elif e == ")": | ||
c = st.pop() | ||
while c != "(": | ||
res.append(c) | ||
c = st.pop() | ||
elif e == "]": | ||
c = st.pop() | ||
while c != "[": | ||
res.append(c) | ||
c = st.pop() | ||
elif e in priority: | ||
while len(st) > 0 and st[-1] not in ["(", "["] and priority[e] <= priority[st[-1]]: | ||
res.append(st.pop()) | ||
st.append(e) | ||
else: | ||
res.append(e) | ||
while len(st) > 0: | ||
res.append(st.pop()) | ||
return res | ||
|
||
def postfix_to_prefix(post_equ, check=False): | ||
op_list = set(["+", "-", "*", "/", "^"]) | ||
stack = [] | ||
for elem in post_equ: | ||
sub_stack = [] | ||
if elem not in op_list: | ||
sub_stack.append(elem) | ||
stack.append(sub_stack) | ||
else: | ||
if len(stack) >= 2: | ||
opnds = reversed([stack.pop() for i in range(2)]) | ||
sub_stack.append(elem) | ||
for opnd in opnds: | ||
sub_stack.extend(opnd) | ||
stack.append(sub_stack) | ||
if check and len(stack) != 1: | ||
pre_equ = None | ||
else: | ||
pre_equ = stack.pop() | ||
return pre_equ | ||
|
||
def post_solver(post_equ): | ||
op_list = set(['+', '-', '/', '*', '^']) | ||
status = True | ||
stack = [] | ||
for elem in post_equ: | ||
if elem in op_list: | ||
if len(stack) >= 2: | ||
op = elem | ||
opnd2 = stack.pop() | ||
opnd1 = stack.pop() | ||
if op == '+': | ||
answer = opnd1 + opnd2 | ||
elif op == '-': | ||
answer = opnd1 - opnd2 | ||
elif op == '*': | ||
answer = opnd1 * opnd2 | ||
elif op == '/': | ||
answer = opnd1 / opnd2 | ||
elif op == '^': | ||
answer = opnd1 ** opnd2 | ||
else: | ||
status = False | ||
break | ||
stack.append(answer) | ||
else: | ||
status = False | ||
break | ||
else: | ||
elem = float(elem) | ||
stack.append(elem) | ||
if status and len(stack) == 1: | ||
answer = stack.pop() | ||
else: | ||
answer = None | ||
status = False | ||
return status, answer | ||
|
||
def number_map(equ, num_list): | ||
num_equ = [] | ||
for token in equ: | ||
if "temp_" in token: | ||
token = num_list[ord(token[-1]) - ord('a')] | ||
elif token == "PI": | ||
token = 3.14 | ||
num_equ.append(token) | ||
return num_equ | ||
|
||
def eval_num_list(str_num_list): | ||
num_list = list() | ||
for item in str_num_list: | ||
if item[-1] == "%": | ||
num_list.append(float(item[:-1]) / 100) | ||
else: | ||
num_list.append(eval(item)) | ||
return num_list | ||
# -*- encoding:utf-8 -*- | ||
|
||
def infix_to_postfix(expression): | ||
st = list() | ||
res = list() | ||
priority = {"+": 0, "-": 0, "*": 1, "/": 1, "^": 2} | ||
for e in expression: | ||
if e in ["(", "["]: | ||
st.append(e) | ||
elif e == ")": | ||
c = st.pop() | ||
while c != "(": | ||
res.append(c) | ||
c = st.pop() | ||
elif e == "]": | ||
c = st.pop() | ||
while c != "[": | ||
res.append(c) | ||
c = st.pop() | ||
elif e in priority: | ||
while len(st) > 0 and st[-1] not in ["(", "["] and priority[e] <= priority[st[-1]]: | ||
res.append(st.pop()) | ||
st.append(e) | ||
else: | ||
res.append(e) | ||
while len(st) > 0: | ||
res.append(st.pop()) | ||
return res | ||
|
||
def postfix_to_prefix(post_equ, check=False): | ||
op_list = set(["+", "-", "*", "/", "^"]) | ||
stack = [] | ||
for elem in post_equ: | ||
sub_stack = [] | ||
if elem not in op_list: | ||
sub_stack.append(elem) | ||
stack.append(sub_stack) | ||
else: | ||
if len(stack) >= 2: | ||
opnds = reversed([stack.pop() for i in range(2)]) | ||
sub_stack.append(elem) | ||
for opnd in opnds: | ||
sub_stack.extend(opnd) | ||
stack.append(sub_stack) | ||
if check and len(stack) != 1: | ||
pre_equ = None | ||
else: | ||
pre_equ = stack.pop() | ||
return pre_equ | ||
|
||
def post_solver(post_equ): | ||
op_list = set(['+', '-', '/', '*', '^']) | ||
status = True | ||
stack = [] | ||
for elem in post_equ: | ||
if elem in op_list: | ||
if len(stack) >= 2: | ||
op = elem | ||
opnd2 = stack.pop() | ||
opnd1 = stack.pop() | ||
if op == '+': | ||
answer = opnd1 + opnd2 | ||
elif op == '-': | ||
answer = opnd1 - opnd2 | ||
elif op == '*': | ||
answer = opnd1 * opnd2 | ||
elif op == '/': | ||
answer = opnd1 / opnd2 | ||
elif op == '^': | ||
answer = opnd1 ** opnd2 | ||
else: | ||
status = False | ||
break | ||
stack.append(answer) | ||
else: | ||
status = False | ||
break | ||
else: | ||
elem = float(elem) | ||
stack.append(elem) | ||
if status and len(stack) == 1: | ||
answer = stack.pop() | ||
else: | ||
answer = None | ||
status = False | ||
return status, answer | ||
|
||
def number_map(equ, num_list): | ||
num_equ = [] | ||
for token in equ: | ||
if "temp_" in token: | ||
token = num_list[ord(token[-1]) - ord('a')] | ||
elif token == "PI": | ||
token = 3.14 | ||
num_equ.append(token) | ||
return num_equ | ||
|
||
def eval_num_list(str_num_list): | ||
num_list = list() | ||
for item in str_num_list: | ||
if item[-1] == "%": | ||
num_list.append(float(item[:-1]) / 100) | ||
else: | ||
num_list.append(eval(item)) | ||
return num_list |
Oops, something went wrong.