Policy Iteration: How to implement the evaluation and improvement steps correctly?

I am following David Silver's 2015 lectures and this repo that I forked: Practical_RL. In week02 I implemented Value Iteration successfully, but at the end of the notebook, when I tested my Policy Iteration, the state values and the policy did not update correctly, and I don't know what I am missing.

I am also confused about two things. First, is my implementation of policy evaluation as a linear system correct? I ask because the iterative policy evaluation takes too long. Second, should the policy be deterministic or stochastic? The lecture states that the policy in policy iteration is deterministic, and it is initialized that way in the repo, but if we apply greedy policy improvement, won't it consider only one action even when several actions are equally optimal (multiple actions with the same expected value)? I have been trying to implement PI for days with different ideas, but to no avail. I hope the question is clear; thanks in advance.
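
To make the tie situation concrete, here is a minimal sketch of how one could list every greedy action for a state (the helper name and the tolerance are mine, and it only uses the MDP methods that appear in the code below); a deterministic greedy policy is free to pick any one of them:

def greedy_actions(mdp, vpi, s, gamma, tol=1e-9):
    """Return all actions whose one-step lookahead value is within tol of the best one."""
    # Q^pi(s, a) = sum_s' P(s' | s, a) * (r(s, a, s') + gamma * V^pi(s'))
    q = {
        a: sum(mdp.get_transition_prob(s, a, s_) *
               (mdp.get_reward(s, a, s_) + gamma * vpi[s_])
               for s_ in mdp.get_all_states())
        for a in mdp.get_possible_actions(s)
    }
    best = max(q.values())
    return [a for a, q_a in q.items() if q_a >= best - tol]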

Here is the relevant code; the full code is in the repo mentioned above, at the end of the week02 notebook.

import numpy as np

def compute_vpi_deterministic(mdp: MDP, policy, gamma):
    """
    Computes V^pi(s) FOR ALL STATES under given policy.
    :param policy: a dict of currently chosen actions {s : a}
    :param gamma: discount
    :returns: a dict {state : V^pi(state) for all states}
    """
    states = sorted(mdp.get_all_states())
    mapping = {s: i for i, s in enumerate(states)}

    P = np.zeros((mdp.n_states, mdp.n_states))
    for s in states:
        for s_ in states:
            if mdp.is_terminal(s):
                P[mapping[s], mapping[s_]] = 0
                continue

            P[mapping[s], mapping[s_]] = mdp.get_transition_prob(s, policy[s], s_)

    R = np.zeros(mdp.n_states)
    for s in states:
        if mdp.is_terminal(s):
            R[mapping[s]] = 0
            continue

        # plain average of the rewards over all next states
        # (not weighted by the transition probabilities)
        Rs = [mdp.get_reward(s, policy[s], s_) for s_ in states]
        R[mapping[s]] = sum(Rs) / len(Rs)

    V = np.linalg.solve(np.eye(mdp.n_states) - gamma * P, R)

    return {s: V[i] for i, s in enumerate(states)}
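
For reference, exact policy evaluation solves the Bellman expectation equation in matrix form, V^pi = R^pi + gamma * P^pi * V^pi, i.e. V^pi = (I - gamma * P^pi)^(-1) * R^pi, where R^pi(s) is normally the transition-weighted expected immediate reward rather than a plain average over next states. A sketch of the reward vector under that convention (the function name is mine; states and mapping are the same objects built above):

def expected_reward_vector(mdp, policy, states, mapping):
    """R^pi(s) = sum_s' P(s' | s, pi(s)) * r(s, pi(s), s'); zero at terminal states."""
    R = np.zeros(len(states))
    for s in states:
        if mdp.is_terminal(s):
            continue
        R[mapping[s]] = sum(mdp.get_transition_prob(s, policy[s], s_) *
                            mdp.get_reward(s, policy[s], s_)
                            for s_ in states)
    return R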
    
#! THIS IMPLEMENTATION (the iterative policy evaluation below) DOES NOT WORK
def compute_vpi_iter_deterministic(mdp: MDP, policy, gamma, tol=1e-2, max_iter=10**6):
    """
    Computes V^pi(s) FOR ALL STATES under given policy.
    :param policy: a dict of currently chosen actions {s : a}
    :param gamma: discount
    :param tol: tolerance
    :param max_iter: maximum number of iterations
    :returns: a dict {state : V^pi(state) for all states}
    """
    V = {s: 0 for s in mdp.get_all_states()}
    for _ in range(max_iter):
        new_V = {s: 0 for s in mdp.get_all_states()}

        for s in mdp.get_all_states():
            for s_ in mdp.get_all_states():
                if mdp.is_terminal(s):
                    new_V[s] += 0
                    continue

                # note: this backup reads from new_V, which was reset to 0
                # at the start of the current sweep
                new_V[s] += mdp.get_transition_prob(s, policy[s], s_) * \
                            (mdp.get_reward(s, policy[s], s_) + gamma * new_V[s_])

        diff = max(abs(new_V[s] - V[s]) for s in mdp.get_all_states())
        V = new_V

        if diff < tol:
            break
    else:
        print("Terminated: max_iter exceeded")
    
    return V
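
For comparison, the textbook synchronous sweep bootstraps from the values of the previous iteration (V) rather than from the dictionary that is being filled in during the current sweep. A minimal sketch of one such sweep (the function name is mine):

def policy_evaluation_sweep(mdp, policy, V, gamma):
    """One synchronous backup: V_{k+1}(s) = sum_s' P(s' | s, pi(s)) * (r + gamma * V_k(s'))."""
    new_V = {}
    for s in mdp.get_all_states():
        if mdp.is_terminal(s):
            new_V[s] = 0
            continue
        new_V[s] = sum(mdp.get_transition_prob(s, policy[s], s_) *
                       (mdp.get_reward(s, policy[s], s_) + gamma * V[s_])
                       for s_ in mdp.get_all_states())
    return new_V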

def compute_new_policy_deterministic(mdp: MDP, vpi, gamma):
    """
    Computes new policy as argmax of state values
    :param vpi: a dict {state : V^pi(state) for all states}
    :returns: a dict {state : optimal action for all states}
    """
    pi = {}

    for s in mdp.get_all_states():
        pi[s] = get_optimal_action(mdp, vpi, s, gamma)

    return pi
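
get_optimal_action is presumably the greedy one-step lookahead defined in the value-iteration part of the notebook. One detail related to the tie question above: if it ever breaks ties differently from one iteration to the next (for example because of tiny numerical differences in the evaluated values), the policy != new_policy check further down can keep firing even after the values have converged. A hedged variant that makes the choice reproducible, reusing the illustrative greedy_actions helper from above (and assuming the actions are sortable, e.g. strings):

def compute_new_policy_tiebreak(mdp, vpi, gamma):
    """Greedy improvement that always picks the same action among equally good ones."""
    pi = {}
    for s in mdp.get_all_states():
        if mdp.is_terminal(s):
            pi[s] = None
            continue
        # sort so the chosen maximizer is reproducible across iterations
        pi[s] = sorted(greedy_actions(mdp, vpi, s, gamma))[0]
    return pi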

def policy_iteration(mdp: MDP, *, compute_vpi, compute_new_policy, policy=None, gamma=0.9, num_iter=1000):
    """ 
    Run the policy iteration loop for num_iter iterations or till the policy is stable.
    If policy is not given, initialize it at random.
    """
    state_values = {s: 0 for s in mdp.get_all_states()}
    policy = policy or {s: None if mdp.is_terminal(s) else
                        np.random.choice(mdp.get_possible_actions(s)) for s in mdp.get_all_states()}
    
    for i in range(num_iter):
        # Implement policy evaluation
        state_values = compute_vpi(mdp, policy, gamma)

        policy_stable = True
        # Implement policy improvement
        new_policy = compute_new_policy(mdp, state_values, gamma)
        
        # If policy is stable
        if policy != new_policy:
            policy_stable = False

        policy = new_policy
        
        if policy_stable:
            break
            
    else:
        print("Iteration limit reached")
        
    return state_values, policy
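
A usage sketch, assuming an MDP instance named mdp has been constructed earlier in the notebook:

# illustrative call using the functions defined above
state_values, policy = policy_iteration(
    mdp,
    compute_vpi=compute_vpi_deterministic,            # exact evaluation via the linear solve
    compute_new_policy=compute_new_policy_deterministic,
    gamma=0.9,
    num_iter=1000,
)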

IMPORTANT EDIT:

So, I was able to get PI working, but I could not make it work with iterative policy evaluation. Interestingly, when I made the policy stochastic I got very good results, and it even worked with iterative policy evaluation. I will still leave this question open in case someone can tell me why it does not work for a deterministic policy even though it works for a stochastic one. What misconception do I have?

Finally, I changed the code a little bit to match the final code pushed to the repo.
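
For reference, the only change for a stochastic policy is the extra expectation over actions in the Bellman equation: V^pi(s) = sum_a pi(a|s) * sum_s' P(s'|s,a) * (r(s,a,s') + gamma * V^pi(s')). A sketch of one evaluation sweep under that convention; the nested-dict policy format {s: {a: prob}} is just an assumption for illustration:

def stochastic_policy_sweep(mdp, policy, V, gamma):
    """One synchronous backup for a stochastic policy given as {s: {a: prob(a | s)}}."""
    new_V = {}
    for s in mdp.get_all_states():
        if mdp.is_terminal(s):
            new_V[s] = 0
            continue
        new_V[s] = sum(p_a * mdp.get_transition_prob(s, a, s_) *
                       (mdp.get_reward(s, a, s_) + gamma * V[s_])
                       for a, p_a in policy[s].items()
                       for s_ in mdp.get_all_states())
    return new_V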
