Improve A2C with better state
mvanaltvorst committed Jul 6, 2024
1 parent 6e62c97 commit 0515255
Showing 19 changed files with 11,275 additions and 12,961 deletions.
snake-a2c/agent.py: 32 changes (17 additions & 15 deletions)
@@ -29,7 +29,7 @@ def remember(self, state, action, reward, next_state, done):
         self.memory.append((state, action, reward, next_state, done))
 
     def act(self, state):
-        _, policy_dist = self.model(torch.tensor(state, dtype = torch.float32))
+        _, policy_dist = self.model(torch.tensor(state, dtype=torch.float32))
         dist = torch.distributions.Categorical(probs=policy_dist)
         action = dist.sample().detach().numpy().item()
         return action
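For context on the act path above: the policy head's probabilities feed a categorical distribution, and the chosen action is a single sample from it. A minimal standalone sketch of that sampling step, with made-up probabilities over three actions:

    import torch

    # Hypothetical policy output: probabilities over 3 actions (values are illustrative).
    policy_dist = torch.tensor([0.7, 0.2, 0.1])

    dist = torch.distributions.Categorical(probs=policy_dist)
    action = dist.sample().item()  # most often 0, since that action has probability 0.7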
@@ -40,12 +40,14 @@ def replay(self, batch_size):
 
         minibatch = random.sample(self.memory, batch_size)
 
-        states = torch.tensor([x[0] for x in minibatch], dtype = torch.float32).to(device)
+        states = torch.tensor([x[0] for x in minibatch], dtype=torch.float32).to(device)
         actions = torch.tensor([x[1] for x in minibatch]).unsqueeze(1).to(device)
         rewards = torch.tensor([x[2] for x in minibatch]).to(
             device, dtype=torch.float32
         )
-        next_states = torch.tensor([x[3] for x in minibatch], dtype = torch.float32).to(device)
+        next_states = torch.tensor([x[3] for x in minibatch], dtype=torch.float32).to(
+            device
+        )
         dones = torch.tensor([x[4] for x in minibatch]).to(device, dtype=torch.float32)
 
         values, policy_dists = self.model(states)
@@ -54,7 +56,7 @@ def replay(self, batch_size):
         returns = rewards + self.gamma * next_values.squeeze() * (1 - dones)
         advantages = returns - values.squeeze()
 
-        log_probs = torch.log(policy_dists.gather(1, actions))
+        log_probs = torch.log(policy_dists.gather(-1, actions))
         actor_loss = -(log_probs * advantages.detach()).mean()
 
         critic_loss = (advantages**2).mean()
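A note on the gather(1, ...) to gather(-1, ...) change: for the 2-D (batch, num_actions) policy output used here, dim 1 and dim -1 are the same axis, so the behaviour is unchanged; -1 simply always selects the action axis regardless of leading dimensions. A toy-shaped sketch of the surrounding loss computation, assuming a batch of two and three actions (the squeeze is added in this sketch so log-probs and advantages line up elementwise):

    import torch

    policy_dists = torch.tensor([[0.7, 0.2, 0.1],
                                 [0.1, 0.1, 0.8]])            # (batch, num_actions)
    actions = torch.tensor([[0], [2]])                        # (batch, 1) actions taken
    advantages = torch.tensor([0.5, -0.3])                    # (batch,) returns - values

    # Probability of each taken action, then its log.
    log_probs = torch.log(policy_dists.gather(-1, actions))   # (batch, 1)
    actor_loss = -(log_probs.squeeze(1) * advantages.detach()).mean()
    critic_loss = (advantages ** 2).mean()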
@@ -73,22 +75,22 @@ def save(self, path: Path | str):
 
         path.parent.mkdir(parents=True, exist_ok=True)
         state = {
-            'model_state_dict': self.model.state_dict(),
-            'optimizer_state_dict': self.optimizer.state_dict(),
-            'gamma': self.gamma,
-            'lr': self.lr,
-            'maxlen': self.maxlen,
-            'memory': list(self.memory)
+            "model_state_dict": self.model.state_dict(),
+            "optimizer_state_dict": self.optimizer.state_dict(),
+            "gamma": self.gamma,
+            "lr": self.lr,
+            "maxlen": self.maxlen,
+            "memory": list(self.memory),
         }
         torch.save(state, path)
         print(f"Agent saved to {path}")
 
     @classmethod
     def load(cls, path: Path | str):
         state = torch.load(path)
-        agent = cls(gamma=state['gamma'], lr=state['lr'], maxlen = state['maxlen'])
-        agent.model.load_state_dict(state['model_state_dict'])
-        agent.optimizer.load_state_dict(state['optimizer_state_dict'])
-        agent.memory = deque(state['memory'], maxlen=agent.memory.maxlen)
+        agent = cls(gamma=state["gamma"], lr=state["lr"], maxlen=state["maxlen"])
+        agent.model.load_state_dict(state["model_state_dict"])
+        agent.optimizer.load_state_dict(state["optimizer_state_dict"])
+        agent.memory = deque(state["memory"], maxlen=agent.memory.maxlen)
         print(f"Agent loaded from {path}")
-        return agent
+        return agent
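The save/load pair above checkpoints the model weights, optimizer state, hyperparameters, and replay memory in a single torch.save call. A hedged usage sketch of the round trip; the class name Agent and the concrete hyperparameter values are assumptions, since neither appears in this hunk:

    from pathlib import Path

    # Agent is assumed to be the class defined in snake-a2c/agent.py.
    agent = Agent(gamma=0.99, lr=1e-3, maxlen=100_000)        # illustrative hyperparameters
    agent.save(Path("checkpoints/agent.state"))               # parent dirs created if needed

    restored = Agent.load(Path("checkpoints/agent.state"))    # same weights, optimizer, memory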
Binary file modified snake-a2c/agent_10x10.state
Binary file not shown.
snake-a2c/network.py: 2 changes (1 addition & 1 deletion)
@@ -4,7 +4,7 @@
 
 
 class ActorCritic(nn.Module):
-    def __init__(self, input_size: int = 10, output_size: int = 4):
+    def __init__(self, input_size: int = 11, output_size: int = 3):
         super().__init__()
         self.input_size = input_size
         self.output_size = output_size
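The default sizes change from 10 inputs / 4 outputs to 11 inputs / 3 outputs, which matches the commit message's "better state": presumably a richer observation vector and relative actions (straight / turn left / turn right) instead of four absolute directions. The encoding itself is not shown in this diff; the sketch below is only one common Snake layout for an 11-value state and is an assumption, not the repository's actual code:

    # Hypothetical 11-feature state (not taken from this repository):
    # [danger straight, danger left, danger right,
    #  moving up, moving down, moving left, moving right,
    #  food up, food down, food left, food right]
    state = [0, 0, 1,  0, 1, 0, 0,  1, 0, 0, 1]   # 11 binary features
    assert len(state) == 11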
Binary file modified snake-a2c/rewards_10x10.state
Binary file not shown.
