forked from lewisfogden/heavylight
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
26a049d
commit da18ee6
Showing
3 changed files
with
214 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .heavylight import Model, Table | ||
from .memory_optimized_model import LightModel | ||
from .memory_optimized_cache import CacheGraph |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
import warnings | ||
import types | ||
from inspect import signature, getmembers | ||
import pandas as pd | ||
|
||
class Table: | ||
"""A Table has one or more keys, and a single column of values""" | ||
|
||
def __init__(self, series): | ||
"""Initialise a table with a series - for multiple keys use a multikey index""" | ||
self.series = series | ||
|
||
def __getitem__(self, key): | ||
return self.series.at[key] | ||
|
||
def __call__(self, key): | ||
return self.series.at[key] | ||
|
||
def values(self, key): | ||
return self.series.loc[key].values | ||
|
||
@staticmethod | ||
def read_csv(filename, sep=",", type_identifier = "|"): | ||
"""Read in a table from a csv file, type encoding is via the header: | ||
<keyname1>|<type1>,<keyname2>|<type2>....<value>|<typeN> | ||
where type is one of "str", "int", "float" | ||
""" | ||
column_types = {"str": str, "int": int, "float": float} | ||
with open(filename, "r") as csv_file: | ||
header = next(csv_file).strip("\n").split(sep) | ||
tid = type_identifier | ||
header_mapper_str = {item:item.split(tid)[1] for item in header} | ||
header_mapper_types = {col:column_types[val] for col, val in header_mapper_str.items()} | ||
df = pd.read_csv(filename, sep=sep, dtype=header_mapper_types) | ||
# strip `tid` from column names | ||
df.columns = [col.split("|")[0] for col in df.columns] | ||
df.set_index(list(df.columns[:-1]), inplace=True) | ||
series = df[df.columns[0]] | ||
return Table(series=series) | ||
|
||
|
||
class _Cache: | ||
"""Cache provides controllable memoization for model methods""" | ||
|
||
def __init__(self, func, param_len): | ||
self.func = func | ||
self.param_len = param_len | ||
self.has_one_param = self.param_len == 1 | ||
self._store = dict() | ||
self.__name__ = "Cache: " + func.__name__ | ||
|
||
def __call__(self, *arg): | ||
if arg in self._store: | ||
return self._store[arg] | ||
else: | ||
result = self.func(*arg) | ||
self._store[arg] = result | ||
return result | ||
|
||
def __repr__(self): | ||
return f"<Cache Function: {self.func.__name__} Size: {len(self._store)}>" | ||
|
||
def sum(self): | ||
"""return the sum of all values in the Cache Function""" | ||
return sum(self._store.values()) | ||
|
||
@property | ||
def values(self): | ||
return list(self._store.values()) | ||
|
||
|
||
class Model: | ||
def __init__(self, *, do_run = False, verbose = False, proj_len = None, **kwargs,): | ||
"""Base Class to subclass for user models. | ||
All variables/methods in user models should be lower case, using underscore as spaces. | ||
Class level methods: | ||
RunModel(proj_len): | ||
if the model has not been auto-run at initialisation, run it for projection length. | ||
Special user methods: | ||
BeforeRun(self): | ||
If this is specified in the user model it called before the projection starts, e.g. to set up some specific variables | ||
AfterRun(self): | ||
user method, called after Run is completed, e.g. can use to calculate NPVs of variables | ||
methods/variables to avoid: | ||
methods/variables starting with an underscore `_` are treated as internal. You may break functionality if you create your own. | ||
""" | ||
|
||
self._cached = False | ||
self._is_run = False | ||
if verbose: | ||
print("== Run Parameters ==") | ||
print(" do_run:", do_run) | ||
print(" proj_len:", proj_len) | ||
print() | ||
|
||
if do_run: | ||
if not isinstance(proj_len, int): | ||
raise ValueError("proj_len must be an integer") | ||
elif proj_len <= 0: | ||
raise ValueError("proj_len must have value greater than 0") | ||
self.proj_len = proj_len | ||
else: | ||
if verbose: print("== Not Running - call Run() manually ==") | ||
|
||
if verbose: print("== Storing Arguments ==") | ||
for k, v in kwargs.items(): | ||
if k in dir(self): | ||
warnings.warn("Warning: Duplicate Item: "+str(k)) | ||
setattr(self, k, v) | ||
if verbose: | ||
print(" Updated: ", k, " : ", v) | ||
|
||
# cacheify | ||
if verbose: print("== Caching Functions ==") | ||
self._cache_funcs(verbose) | ||
|
||
if do_run and proj_len > 0: | ||
self.RunModel(proj_len, verbose) | ||
if verbose: print("== Run complete == ") | ||
|
||
|
||
def RunModel(self, proj_len, verbose = False): | ||
if self._is_run: | ||
raise ValueError("Run has already been completed.") | ||
|
||
if verbose: print(f"== Running Projection | length: {proj_len} ==") | ||
|
||
if hasattr(self, "BeforeRun"): | ||
if verbose: print(" Calling BeforeRun") | ||
self.BeforeRun() | ||
|
||
if not self._cached: | ||
raise ValueError("Functions have not been cached") # NB: this shouldn't occur as now caching in instance | ||
for t in range(proj_len): | ||
for name, func in self._funcs.items(): | ||
#func = getattr(self, var) | ||
if func.has_one_param: # skip functions with more than one parameter | ||
func(t) #call each function in turn, starting from t==0 | ||
self._is_run = True | ||
if hasattr(self, "AfterRun"): | ||
if verbose: print(" Calling AfterRun") | ||
return self.AfterRun() | ||
|
||
def _cache_funcs(self, verbose: bool = False): | ||
if self._cached: | ||
raise ValueError("Cache has already been set-up, please create a new instance") | ||
|
||
self._funcs = {} | ||
|
||
for method_name, method in getmembers(self): | ||
#method = getattr(self, method_name) | ||
|
||
if method_name[0] != "_" and method_name[0].islower() and isinstance(method, types.MethodType): | ||
param_count = len(signature(method).parameters) # count the parameters in the function. | ||
cached_method = _Cache(method, param_count) | ||
setattr(self, method_name, cached_method) | ||
self._funcs[method_name] = cached_method | ||
if verbose: print(f" Cached: {method_name}") | ||
|
||
|
||
self._cached = True | ||
|
||
def ToDataFrame(self): | ||
"""return a pandas dataframe of all single parameter columns""" | ||
df = pd.DataFrame() | ||
for func in self._funcs: | ||
if self._funcs[func].has_one_param: | ||
df[func] = pd.Series(self._funcs[func].values) | ||
|
||
# if t is in the dataframe, move it to first position | ||
if "t" in df.columns: | ||
df.insert(0, "t", df.pop("t")) | ||
return df |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from heavylight import Model | ||
|
||
class SimpleModel(Model): | ||
def t(self, t): | ||
return t | ||
|
||
def num_pols_if(self, t): | ||
if t == 0: | ||
return 1 | ||
else: | ||
return self.num_pols_if(t - 1) * 0.98 | ||
|
||
def cashflow(self, t): | ||
return self.num_pols_if(t) * 100 | ||
|
||
def v(self, t): | ||
"""discount factor for time t --> time 0""" | ||
if t == 0: | ||
return 1 | ||
else: | ||
return self.v(t - 1) / (1 + self.forward_rate(t)) | ||
|
||
def forward_rate(self, t): | ||
return 0.04 | ||
|
||
def pv_cashflow(self, t): | ||
"""present value of the cashflow occuring at time t""" | ||
return self.cashflow(t) * self.v(t) | ||
|
||
def test_heavylight(): | ||
"""Tests be improved upon later""" | ||
model = SimpleModel(do_run = True, proj_len = 10) | ||
assert model.pv_cashflow.sum() > 0 |