Skip to content

Commit

Permalink
add src and data
Browse files Browse the repository at this point in the history
  • Loading branch information
unknown committed Mar 14, 2024
1 parent 8d1e45e commit 6862c8e
Show file tree
Hide file tree
Showing 43 changed files with 144,752 additions and 140,964 deletions.
63 changes: 43 additions & 20 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,20 +1,43 @@
# HMS: A Hierarchical Solver with Dependency-Enhanced Understanding for Math Word Problem
Source code for paper *HMS: A Hierarchical Solver with Dependency-Enhanced Understanding for Math Word Problem*.
The source code for paper "Learning Relation-Enhanced Hierarchical Solver for Math Word Problems" is coming soon.

## Dependencies
- python >= 3.6

- stanfordcorenlp
- torch

## Usage
Preprocess dataset
```bash
python3 src/dataprocess/math23k.py
```
Train and test model
```bash
python3 src/main.py
```
For running arguments, please refer to [src/config.py](src/config.py).
# Learning Relation-Enhanced Hierarchical Solver for Math Word Problems
Source code for paper *Learning Relation-Enhanced Hierarchical Solver for Math Word Problems*.

## Dependencies
- python >= 3.6

- stanfordcorenlp
- torch

## Usage
- Preprocess dataset
```bash
python3 src/dataprocess/math23k.py
python3 src/dataprocess/similarity.py
```
- Train and test model
```bash
python3 src/rhms/main.py
```
For running arguments, please refer to [src/rhms/config.py](src/rhms/config.py).

## Citation
If you find our work helpful, please consider citing our paper.
```
@article{lin2023learning,
title={Learning Relation-Enhanced Hierarchical Solver for Math Word Problems},
author={Lin, Xin and Huang, Zhenya and Zhao, Hongke and Chen, Enhong and Liu, Qi and Lian, Defu and Li, Xin and Wang, Hao},
journal={IEEE Transactions on Neural Networks and Learning Systems},
year={2023},
publisher={IEEE}
}
```
```
@inproceedings{lin2021hms,
title={HMS: A Hierarchical Solver with Dependency-Enhanced Understanding for Math Word Problem},
author={Lin, Xin and Huang, Zhenya and Zhao, Hongke and Chen, Enhong and Liu, Qi and Wang, Hao and Wang, Shijin},
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
volume={35},
number={5},
pages={4232--4240},
year={2021}
}
```
12,002 changes: 6,001 additions & 6,001 deletions data/test23k.json

Large diffs are not rendered by default.

253,946 changes: 126,973 additions & 126,973 deletions data/train23k.json

Large diffs are not rendered by default.

12,002 changes: 6,001 additions & 6,001 deletions data/valid23k.json

Large diffs are not rendered by default.

210 changes: 105 additions & 105 deletions src/dataprocess/equ_tools.py
Original file line number Diff line number Diff line change
@@ -1,105 +1,105 @@
# -*- encoding:utf-8 -*-

def infix_to_postfix(expression):
st = list()
res = list()
priority = {"+": 0, "-": 0, "*": 1, "/": 1, "^": 2}
for e in expression:
if e in ["(", "["]:
st.append(e)
elif e == ")":
c = st.pop()
while c != "(":
res.append(c)
c = st.pop()
elif e == "]":
c = st.pop()
while c != "[":
res.append(c)
c = st.pop()
elif e in priority:
while len(st) > 0 and st[-1] not in ["(", "["] and priority[e] <= priority[st[-1]]:
res.append(st.pop())
st.append(e)
else:
res.append(e)
while len(st) > 0:
res.append(st.pop())
return res

def postfix_to_prefix(post_equ, check=False):
op_list = set(["+", "-", "*", "/", "^"])
stack = []
for elem in post_equ:
sub_stack = []
if elem not in op_list:
sub_stack.append(elem)
stack.append(sub_stack)
else:
if len(stack) >= 2:
opnds = reversed([stack.pop() for i in range(2)])
sub_stack.append(elem)
for opnd in opnds:
sub_stack.extend(opnd)
stack.append(sub_stack)
if check and len(stack) != 1:
pre_equ = None
else:
pre_equ = stack.pop()
return pre_equ

def post_solver(post_equ):
op_list = set(['+', '-', '/', '*', '^'])
status = True
stack = []
for elem in post_equ:
if elem in op_list:
if len(stack) >= 2:
op = elem
opnd2 = stack.pop()
opnd1 = stack.pop()
if op == '+':
answer = opnd1 + opnd2
elif op == '-':
answer = opnd1 - opnd2
elif op == '*':
answer = opnd1 * opnd2
elif op == '/':
answer = opnd1 / opnd2
elif op == '^':
answer = opnd1 ** opnd2
else:
status = False
break
stack.append(answer)
else:
status = False
break
else:
elem = float(elem)
stack.append(elem)
if status and len(stack) == 1:
answer = stack.pop()
else:
answer = None
status = False
return status, answer

def number_map(equ, num_list):
num_equ = []
for token in equ:
if "temp_" in token:
token = num_list[ord(token[-1]) - ord('a')]
elif token == "PI":
token = 3.14
num_equ.append(token)
return num_equ

def eval_num_list(str_num_list):
num_list = list()
for item in str_num_list:
if item[-1] == "%":
num_list.append(float(item[:-1]) / 100)
else:
num_list.append(eval(item))
return num_list
# -*- encoding:utf-8 -*-

def infix_to_postfix(expression):
st = list()
res = list()
priority = {"+": 0, "-": 0, "*": 1, "/": 1, "^": 2}
for e in expression:
if e in ["(", "["]:
st.append(e)
elif e == ")":
c = st.pop()
while c != "(":
res.append(c)
c = st.pop()
elif e == "]":
c = st.pop()
while c != "[":
res.append(c)
c = st.pop()
elif e in priority:
while len(st) > 0 and st[-1] not in ["(", "["] and priority[e] <= priority[st[-1]]:
res.append(st.pop())
st.append(e)
else:
res.append(e)
while len(st) > 0:
res.append(st.pop())
return res

def postfix_to_prefix(post_equ, check=False):
op_list = set(["+", "-", "*", "/", "^"])
stack = []
for elem in post_equ:
sub_stack = []
if elem not in op_list:
sub_stack.append(elem)
stack.append(sub_stack)
else:
if len(stack) >= 2:
opnds = reversed([stack.pop() for i in range(2)])
sub_stack.append(elem)
for opnd in opnds:
sub_stack.extend(opnd)
stack.append(sub_stack)
if check and len(stack) != 1:
pre_equ = None
else:
pre_equ = stack.pop()
return pre_equ

def post_solver(post_equ):
op_list = set(['+', '-', '/', '*', '^'])
status = True
stack = []
for elem in post_equ:
if elem in op_list:
if len(stack) >= 2:
op = elem
opnd2 = stack.pop()
opnd1 = stack.pop()
if op == '+':
answer = opnd1 + opnd2
elif op == '-':
answer = opnd1 - opnd2
elif op == '*':
answer = opnd1 * opnd2
elif op == '/':
answer = opnd1 / opnd2
elif op == '^':
answer = opnd1 ** opnd2
else:
status = False
break
stack.append(answer)
else:
status = False
break
else:
elem = float(elem)
stack.append(elem)
if status and len(stack) == 1:
answer = stack.pop()
else:
answer = None
status = False
return status, answer

def number_map(equ, num_list):
num_equ = []
for token in equ:
if "temp_" in token:
token = num_list[ord(token[-1]) - ord('a')]
elif token == "PI":
token = 3.14
num_equ.append(token)
return num_equ

def eval_num_list(str_num_list):
num_list = list()
for item in str_num_list:
if item[-1] == "%":
num_list.append(float(item[:-1]) / 100)
else:
num_list.append(eval(item))
return num_list
Loading

0 comments on commit 6862c8e

Please sign in to comment.