Reinforcement Learning as Probabilistic Inference - Part 3

From the optimal action conditionals, we recover the optimal policy through backward messages, relate it to value functions in RL, and connect probabilistic inference to maximum entropy reinforcement learning.

Recovering the Optimal Policy (cont.)

We ended part 2 with an introduction to the optimal action conditional p(a_t \mid s_t, O_{t:T} = 1). While it is not quite the action conditional given the optimal policy p(a_t \mid s_t, \theta^\star) that we are interested in, it is attainable from our PGM and closely related: both p(a_t \mid s_t, O_{t:T} = 1) and p(a_t \mid s_t, \theta^\star) aim to select actions that lead to optimal outcomes. We will walk through deriving the optimal action conditional from our probabilistic graphical model, which will in the end reveal the subtle difference.

The sum-product inference algorithm

We will be using the standard sum-product inference algorithm, which is analogous to inference in Hidden Markov Models (HMMs) or dynamic Bayesian networks (DBNs) [1]. The sum-product algorithm computes marginal probabilities in a graphical model. In a sequential model like a DBN, it involves forward and backward passes: each message represents a partial computation of a marginal probability, based on the information available in the neighborhood of the node (more here [2]). In the forward pass, messages propagate forward in time, capturing how past states influence the present. In the backward pass, messages propagate backward in time, capturing how future states constrain the present.
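To make the message-passing picture concrete, here is a minimal numpy sketch of the forward and backward passes for a toy two-state HMM. The transition matrix, emission matrix, and observation sequence are made up purely for illustration, not taken from any reference.

```python
import numpy as np

# Minimal sum-product (forward-backward) sketch for a toy 2-state discrete HMM.
# All numbers here are made up for illustration.
T_mat = np.array([[0.7, 0.3],    # transition probabilities p(z_{t+1} | z_t)
                  [0.4, 0.6]])
E_mat = np.array([[0.9, 0.1],    # emission probabilities p(x_t | z_t)
                  [0.2, 0.8]])
prior = np.array([0.5, 0.5])     # p(z_1)
obs = [0, 1, 1, 0]               # an observation sequence x_{1:T}

T = len(obs)
alpha = np.zeros((T, 2))         # forward messages:  alpha_t(z) = p(x_{1:t}, z_t = z)
beta = np.zeros((T, 2))          # backward messages: beta_t(z) = p(x_{t+1:T} | z_t = z)

# Forward pass: past observations influence the present.
alpha[0] = prior * E_mat[:, obs[0]]
for t in range(1, T):
    alpha[t] = (alpha[t - 1] @ T_mat) * E_mat[:, obs[t]]

# Backward pass: future observations constrain the present.
beta[T - 1] = 1.0
for t in range(T - 2, -1, -1):
    beta[t] = T_mat @ (E_mat[:, obs[t + 1]] * beta[t + 1])

# Combining both passes gives the smoothed marginals p(z_t | x_{1:T}).
posterior = alpha * beta
posterior /= posterior.sum(axis=1, keepdims=True)
print(posterior)
```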

The backward message

In our PGM, the backward message is

\beta_t(s_t, a_t) = p(O_{t:T} \mid s_t, a_t) \quad \text{or} \quad \beta_t(s_t) = p(O_{t:T} \mid s_t).

That is, the probability that the trajectory from time step t to T can be optimal if it begins in state s_t with action a_t. Alternatively, the state-only message \beta_t(s_t) indicates the probability that the trajectory from t to T is optimal if it begins in state s_t. The state-only message can be derived from the state-action message by integrating out the action:

\beta_t(s_t) = p(O_{t:T} \mid s_t) = \int_A p(O_{t:T} \mid s_t, a_t) \, p(a_t \mid s_t) \, da_t = \int_A \beta_t(s_t, a_t) \, p(a_t \mid s_t) \, da_t.

* Note that p(a_t \mid s_t) is the action prior. Our PGM doesn’t actually contain this factor, and for simplicity we can assume it is a constant corresponding to a uniform distribution over the set of actions, p(a_t \mid s_t) = \frac{1}{|A|}.

The recursive message-passing algorithm for computing \beta_t(s_t, a_t) proceeds from the last time step t = T backward through time to t = 1:

\begin{aligned} \fcolorbox{red}{white}{$\beta_t(s_t, a_t)$} &= p(O_{t:T} \mid s_t, a_t) \\ &= \int_{S} \fcolorbox{red}{white}{$\beta_{t+1}(s_{t+1})$} \, p(s_{t+1} \mid s_t, a_t) \, p(O_t \mid s_t, a_t) \, ds_{t+1}. \end{aligned}

* In the base case, note that p(O_T \mid s_T, a_T) is simply proportional to \exp(r(s_T, a_T)), since there is only one factor to consider.

Combining the two expressions above (and dropping the constant uniform action prior), the state-only backward message can be computed recursively:

\begin{aligned} \beta_t(s_t) &= \sum_{a_t} \sum_{s_{t+1}} \fcolorbox{blue}{white}{$p(O_t = 1 \mid s_t, a_t)$} \, \fcolorbox{green}{white}{$p(s_{t+1} \mid s_t, a_t)$} \, \beta_{t+1}(s_{t+1}) \end{aligned}

where the blue factor p(O_t = 1 \mid s_t, a_t) \propto \exp(r(s_t, a_t)) is the optimality likelihood and the green factor p(s_{t+1} \mid s_t, a_t) is the transition dynamics.

The backward message \beta_t(s_t) tells us how likely it is to achieve high rewards from time t onward, given the current state s_t. Combined with forward messages (prior dynamics), this allows the agent to infer the best action a_t at any time step t.
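To make the recursion concrete, here is a minimal tabular sketch in numpy (one possible implementation, not the only one). The small MDP, its reward table, and its transition tensor are made up; following the note above, the action prior is taken to be uniform, and p(O_t = 1 \mid s_t, a_t) is taken proportional to \exp(r(s_t, a_t)).

```python
import numpy as np

# Tabular backward-message recursion for the control-as-inference PGM.
# The MDP below (2 states, 2 actions, horizon 3) is made up for illustration.
S, A, T = 2, 2, 3
rng = np.random.default_rng(0)
reward = rng.normal(size=(S, A))                  # r(s, a)
P = rng.random(size=(S, A, S))                    # p(s' | s, a)
P /= P.sum(axis=2, keepdims=True)

action_prior = np.full(A, 1.0 / A)                # uniform p(a_t | s_t), as assumed above
optimality = np.exp(reward)                       # p(O_t = 1 | s_t, a_t) ∝ exp(r(s_t, a_t))

beta_sa = np.zeros((T, S, A))                     # beta_t(s_t, a_t) = p(O_{t:T} | s_t, a_t)
beta_s = np.zeros((T, S))                         # beta_t(s_t)      = p(O_{t:T} | s_t)

# Base case at t = T: only the final optimality factor remains.
beta_sa[T - 1] = optimality
beta_s[T - 1] = beta_sa[T - 1] @ action_prior     # marginalize out a_T

# Recurse backward from t = T-1 down to t = 1.
for t in range(T - 2, -1, -1):
    # beta_t(s, a) = p(O_t = 1 | s, a) * sum_{s'} p(s' | s, a) beta_{t+1}(s')
    beta_sa[t] = optimality * (P @ beta_s[t + 1])
    # beta_t(s) = sum_a beta_t(s, a) p(a | s)
    beta_s[t] = beta_sa[t] @ action_prior

print(beta_sa[0], beta_s[0])
```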

The optimal action conditional (finally!)

From these backward messages, we can then derive the optimal action conditional p(a_t \mid s_t, O_{1:T}). First, note that O_{1:(t-1)} is conditionally independent of a_t given s_t, which means that:

p(a_t \mid s_t, O_{1:T}) = p(a_t \mid s_t, O_{t:T}),

and we can disregard the past when considering the current action distribution. This makes sense because in a Markovian system, the optimal action does not depend on the past.

We can recover the optimal action distribution using the two backward messages:

\begin{aligned} p(a_t \mid s_t, O_{t:T}) &= \frac{p(s_t, a_t \mid O_{t:T})}{p(s_t \mid O_{t:T})} \\ &= \frac{p(O_{t:T} \mid s_t, a_t) \, p(a_t \mid s_t) \, p(s_t)}{p(O_{t:T} \mid s_t) \, p(s_t)} &\quad \textcolor{gray}{\scriptsize \text{(by Bayes' rule)}} \\ &\propto \frac{p(O_{t:T} \mid s_t, a_t)}{p(O_{t:T} \mid s_t)} &\quad \textcolor{gray}{\scriptsize \text{(assume uniform } p(a_t \mid s_t) = \frac{1}{|A|} \text{, so } p(a_t \mid s_t) \text{ cancels out)}} \\ &= \fcolorbox{red}{white}{$\frac{\beta_t(s_t, a_t)}{\beta_t(s_t)}$} \tag{1} \end{aligned}
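Given backward messages like those from the sketch above, recovering the action conditional is just a ratio. The message values below are placeholders, and the helper optimal_action_conditional is a hypothetical name; it keeps the action prior explicit, i.e. it implements the Bayes' rule step just before the proportionality in (1).

```python
import numpy as np

def optimal_action_conditional(beta_sa_t, beta_s_t, action_prior):
    """p(a_t | s_t, O_{t:T}) = beta_t(s_t, a_t) p(a_t | s_t) / beta_t(s_t); cf. (1)."""
    return beta_sa_t * action_prior[None, :] / beta_s_t[:, None]

# Placeholder messages for 2 states and 2 actions (made up; e.g. taken at some
# step t of a backward recursion like the sketch above).
beta_sa_t = np.array([[0.9, 0.3],
                      [0.2, 0.4]])
action_prior = np.array([0.5, 0.5])        # uniform p(a_t | s_t)
beta_s_t = beta_sa_t @ action_prior        # beta_t(s) = sum_a beta_t(s, a) p(a | s)

policy_t = optimal_action_conditional(beta_sa_t, beta_s_t, action_prior)
print(policy_t)                            # each row is a valid distribution over actions
assert np.allclose(policy_t.sum(axis=1), 1.0)
```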

Backward message ↔ Value function, State-action value function

We can relate the backward messages in our PGM to the Q value (state-action value) and value functions in the context of RL. \beta_t(s_t, a_t) = p(O_{t:T} \mid s_t, a_t) represents the unnormalized probability of achieving optimality from t onward, given (s_t, a_t), and similarly \beta_t(s_t) summarizes the unnormalized probability of achieving optimality from t onward, given s_t (with the base case \beta_T(s_T, a_T) = p(O_T \mid s_T, a_T) \propto \exp(r(s_T, a_T))). The logarithm of the backward messages maps to the state-action value function Q(s_t, a_t), defined as the expected cumulative reward starting at (s_t, a_t), and the value function V(s_t), the expected cumulative reward starting at s_t, marginalizing over all actions:

Q(s_t, a_t) = \log \beta_t(s_t, a_t), \quad V(s_t) = \log \beta_t(s_t).

Substituting the definitions of Q(s_t, a_t) and V(s_t) into the optimal action conditional (1):

p(a_t \mid s_t, O_{t:T}) = \frac{\beta_t(s_t, a_t)}{\beta_t(s_t)} = \frac{\exp(Q(s_t, a_t))}{\exp(V(s_t))} = \fcolorbox{red}{white}{$\exp(Q(s_t, a_t) - V(s_t))$}. \tag{2}
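In log space, (2) is simply a softmax over Q values. A small sketch with made-up Q values for a single state:

```python
import numpy as np

# Made-up Q values for a single state and three actions.
Q = np.array([1.0, 0.5, -0.2])                   # Q(s_t, a) = log beta_t(s_t, a)

# Soft value: V(s_t) = log sum_a exp(Q(s_t, a)), computed with a max shift for stability.
# (This equals log beta_t(s_t) up to the constant uniform action prior, which only
#  shifts V and drops out once the policy is normalized.)
V = Q.max() + np.log(np.exp(Q - Q.max()).sum())

policy = np.exp(Q - V)                           # equation (2): exp(Q(s_t, a) - V(s_t))
print(policy, policy.sum())                      # a softmax over Q; sums to 1
```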

Tying back to RL

Exponent of the advantage

The value \exp(Q(s_t, a_t) - V(s_t)) in (2) is the exponent of the advantage of action a_t over the baseline value V(s_t). In the context of RL, the advantage is a measure of how much better or worse an action is compared to the average performance of the policy from a given state. It is used to evaluate the relative value of an action within a specific state. The advantage function is defined as:

A(s_t, a_t) = Q(s_t, a_t) - V(s_t).

By directly comparing Q(s_t, a_t) to V(s_t), the algorithm prioritizes actions that lead to higher rewards relative to the average, making it more efficient at learning an optimal policy. The advantage provides a balance between exploring actions with potential (positive advantage) and avoiding suboptimal actions (negative advantage).

The exponent of the advantage \exp(Q(s_t, a_t) - V(s_t)) converts the advantage into a positive, weighted score where actions with higher advantages contribute exponentially more to the total score. The term \exp(Q(s_t, a_t)) alone might grow very large, but subtracting V(s_t) normalizes the scale relative to the value of the state.

When normalized, the exponential term becomes a probability distribution: the softmax over the Q(s_t, a_t) values of all actions a_t, which assigns higher probabilities to actions with higher Q(s_t, a_t):

\pi(a_t \mid s_t) = \frac{\exp(Q(s_t, a_t))}{\sum_{a'} \exp(Q(s_t, a'))} = \exp(Q(s_t, a_t) - V(s_t)), \quad \text{where } V(s_t) = \log \sum_{a'} \exp(Q(s_t, a')).

Connecting to Maximum Entropy Reinforcement Learning

To jump to the conclusion, the optimal action conditional from our PGM equates to the optimal policy \pi^\star(a_t \mid s_t) in the maximum entropy reinforcement learning framework [3]:

\pi^\star(a_t \mid s_t) = \exp(Q^\star(s_t, a_t) - V^\star(s_t)),

where Q^\star(s_t, a_t) and V^\star(s_t) are the state-action value function and value function of the optimal policy.

In maximum entropy reinforcement learning, the goal is to maximize both the expected reward and the entropy of the policy:

\max_\pi \mathbb{E} \left[ \sum_{t=1}^T r(s_t, a_t) + \alpha H(\pi(\cdot \mid s_t)) \right],

where H(\pi(\cdot \mid s_t)) is the entropy of the policy at state s_t and \alpha is a temperature parameter that trades off reward maximization against entropy.
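As a toy illustration of this trade-off, consider a single-step, single-state problem (a bandit), where the entropy-regularized optimum is a softmax over rewards with temperature \alpha. The rewards and temperatures below are made up:

```python
import numpy as np

# Toy single-step, single-state example (a bandit) with made-up rewards.
r = np.array([1.0, 0.8, 0.1])

def entropy(p):
    return -np.sum(p * np.log(p))

def objective(p, alpha):
    """Expected reward plus alpha-weighted policy entropy (the objective above, with T = 1)."""
    return np.dot(p, r) + alpha * entropy(p)

for alpha in [0.1, 1.0, 10.0]:
    # For a one-step problem the entropy-regularized optimum is a softmax with temperature alpha.
    pi = np.exp(r / alpha)
    pi /= pi.sum()
    print(alpha, np.round(pi, 3), round(objective(pi, alpha), 3))
# Small alpha -> nearly greedy on reward; large alpha -> nearly uniform (high entropy).
```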

Deriving the entropy maximizing policy

[4] shows how soft policy iteration, which involves the exponent of the advantage, reaches a policy that maximizes expected reward and entropy. [5] also presents an optimization-based approximate inference approach, variational inference, to approximate p(a_t \mid s_t, O_{t:T}) from our PGM, since exact inference of the optimal action conditional is intractable:

p(a_t \mid s_t, O_{t:T}) \propto \textcolor{red}{\frac{p(O_{t:T} \mid s_t, a_t)}{p(O_{t:T} \mid s_t)}}.

Both the numerator and the denominator on the right-hand side involve rolling out all possible sequences of states and actions, i.e., integrating over all possible trajectories, which grows exponentially with the time horizon T and is computationally infeasible.

To apply variational inference to this problem, the goal is to fit a parameterized policy \pi(a_t \mid s_t) such that the trajectory distribution

\hat{p}(\tau) \propto \mathbb{1}[p(\tau) \neq 0] \prod_{t=1}^T \pi(a_t \mid s_t)

matches the distribution (restricted to deterministic dynamics):

p(\tau \mid O_{1:T}) \propto \mathbb{1}[p(\tau) \neq 0] \exp \left( \sum_{t=1}^T r(s_t, a_t) \right).

We can therefore view the inference process as minimizing the KL divergence between these two trajectory distributions, which is given by:

D_{\text{KL}}(\hat{p}(\tau) \parallel p(\tau)) = -\mathbb{E}_{\tau \sim \hat{p}(\tau)} \left[ \log p(\tau) - \log \hat{p}(\tau) \right].

Negating both sides and substituting in the expressions for p(\tau) and \hat{p}(\tau), we get

\begin{aligned} -D_{\text{KL}}(\hat{p}(\tau) \parallel p(\tau)) &= \mathbb{E}_{\tau \sim \hat{p}(\tau)} \Bigg[ \log p(s_1) + \sum_{t=1}^T \big( \log p(s_{t+1} \mid s_t, a_t) + r(s_t, a_t) \big) \\ &\quad \quad - \log p(s_1) - \sum_{t=1}^T \big( \log p(s_{t+1} \mid s_t, a_t) + \log \pi(a_t \mid s_t) \big) \Bigg] \\ &= \mathbb{E}_{\tau \sim \hat{p}(\tau)} \left[ \sum_{t=1}^T \big( r(s_t, a_t) - \log \pi(a_t \mid s_t) \big) \right] \\ &= \sum_{t=1}^T \mathbb{E}_{(s_t, a_t) \sim \hat{p}(s_t, a_t)} \big[ r(s_t, a_t) - \log \pi(a_t \mid s_t) \big] \\ &= \sum_{t=1}^T \mathbb{E}_{(s_t, a_t) \sim \hat{p}(s_t, a_t)} \fcolorbox{red}{white}{$\big[ r(s_t, a_t) \big]$} + \mathbb{E}_{s_t \sim \hat{p}(s_t)} \fcolorbox{blue}{white}{$\big[ H(\pi(a_t \mid s_t)) \big]$}. \end{aligned}

This shows that minimizing the KL divergence corresponds to maximizing the expected reward and the expected conditional entropy. As we hinted at early on, this differs from the standard RL objective, where we only maximize reward, and it is therefore separately referred to as maximum entropy reinforcement learning.
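As a sanity check of the last equality, the sketch below enumerates trajectories of a made-up two-step deterministic MDP and confirms that the expected reward-minus-log-policy term equals the per-step expected reward plus expected policy entropy. Like the derivation above, it treats p(\tau) only up to its normalization constant.

```python
import numpy as np
from itertools import product

# Tiny deterministic MDP (2 states, 2 actions, horizon 2); all numbers are made up.
S, A, T = 2, 2, 2
next_state = np.array([[0, 1],            # next_state[s, a] = s'
                       [1, 0]])
reward = np.array([[1.0, 0.2],
                   [0.5, 0.3]])
pi = np.array([[0.7, 0.3],                # an arbitrary stochastic policy pi(a | s)
               [0.4, 0.6]])
s1 = 0

# Left-hand side: E_{tau ~ p_hat}[ sum_t r(s_t, a_t) - log pi(a_t | s_t) ],
# i.e. the negative KL to the unnormalized p(tau) ∝ exp(sum_t r), as in the derivation.
lhs = 0.0
for actions in product(range(A), repeat=T):   # enumerate all trajectories
    s, prob, ret, logq = s1, 1.0, 0.0, 0.0
    for a in actions:
        prob *= pi[s, a]
        logq += np.log(pi[s, a])
        ret += reward[s, a]
        s = next_state[s, a]
    lhs += prob * (ret - logq)

# Right-hand side: sum_t E[r(s_t, a_t)] + E[H(pi(. | s_t))].
state_dist = np.zeros(S)
state_dist[s1] = 1.0
rhs = 0.0
for t in range(T):
    rhs += state_dist @ (pi * reward).sum(axis=1)          # expected reward at step t
    rhs += state_dist @ (-(pi * np.log(pi)).sum(axis=1))   # expected policy entropy at step t
    # Push the state distribution one step forward through the deterministic dynamics.
    new_dist = np.zeros(S)
    for s in range(S):
        for a in range(A):
            new_dist[next_state[s, a]] += state_dist[s] * pi[s, a]
    state_dist = new_dist

print(lhs, rhs)                                            # the two sides agree
assert np.isclose(lhs, rhs)
```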

Conclusion

We have so far walked through formulating an RL problem as one that can be solved with probabilistic inference, arriving at a special variation of RL: maximum entropy reinforcement learning. In the bigger picture, the extensibility and compositionality of graphical models can likely be leveraged to produce more sophisticated reinforcement learning methods, and the framework of probabilistic inference can offer a powerful toolkit for deriving effective and convergent learning algorithms for the corresponding models.

A particularly exciting recent development is the intersection of maximum entropy reinforcement learning and latent variable models, where the graphical model for control as inference is augmented with additional variables for modeling time-correlated stochasticity for exploration [6], [7] or higher-level control through learned latent action spaces [8], [9].

References

  1. “Hidden Markov Model,” Wikipedia, 2024.
  2. M. I. Jordan, R. Diankov, and X. Liu, “Sum-Product, Max A Posteriori, Bayesians and Frequentists (9/16/04).”
  3. B. D. Ziebart, A. Maas, J. A. Bagnell, and A. K. Dey, “Maximum Entropy Inverse Reinforcement Learning.”
  4. T. Haarnoja, A. Zhou, P. Abbeel, and S. Levine, “Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor,” 2018, doi: 10.48550/arXiv.1801.01290.
  5. S. Levine, “Reinforcement Learning and Control as Probabilistic Inference: Tutorial and Review,” 2018.
  6. K. Hausman, J. T. Springenberg, Z. Wang, N. Heess, and M. Riedmiller, “Learning an Embedding Space for Transferable Robot Skills,” in International Conference on Learning Representations, 2018.
  7. A. Gupta, R. Mendonca, Y. X. Liu, P. Abbeel, and S. Levine, “Meta-Reinforcement Learning of Structured Exploration Strategies,” 2018, doi: 10.48550/arXiv.1802.07245.
  8. T. Haarnoja, H. Tang, P. Abbeel, and S. Levine, “Reinforcement Learning with Deep Energy-Based Policies,” in Proceedings of the 34th International Conference on Machine Learning, PMLR, 2017, pp. 1352–1361.
  9. T. Haarnoja, K. Hartikainen, P. Abbeel, and S. Levine, “Latent Space Policies for Hierarchical Reinforcement Learning,” 2018, doi: 10.48550/arXiv.1804.02808.