{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3 (ipykernel)", "language": "python" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "code", "metadata": { "id": "nDwiVvwWi2-i", "ExecuteTime": { "end_time": "2025-10-22T11:32:36.484507Z", "start_time": "2025-10-22T11:32:36.013055Z" } }, "source": [ "import pytensor\n", "import pytensor.tensor as pt\n", "\n", "from pytensor.compile.mode import get_default_mode" ], "outputs": [], "execution_count": 1 }, { "cell_type": "markdown", "source": [ "## Differentiating loops\n", "\n", "Let's write a trivial numerical loop:" ], "metadata": { "id": "pr6926Q7i31E" } }, { "cell_type": "code", "source": [ "x0 = 0.95\n", "x = x0\n", "for i in range(4):\n", " x = x ** 2\n", "x4 = x\n", "x4" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Apjcgkkpi7iU", "outputId": "a50d1c34-4b76-4513-d6a6-143befe5f3ad", "ExecuteTime": { "end_time": "2025-10-22T11:32:36.504316Z", "start_time": "2025-10-22T11:32:36.494302Z" } }, "outputs": [ { "data": { "text/plain": [ "0.44012666865176564" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 2 }, { "cell_type": "markdown", "source": [ "![image.png]()" ], "metadata": { "id": "Rr9_Zd92nvfc" } }, { "cell_type": "markdown", "source": [ "*Blue squares denote inputs, red squares outputs, and white squares intermediate quantities.*" ], "metadata": { "id": "iQ7MSCfSLq_F" } }, { "cell_type": "markdown", "source": [ "How does the final x4 change if we perturb x0 ever so slightly? 
In other words, what is the derivative of x4 with respect to x0?\n", "\n", "Analytically, this function is equivalent to `x4 = x0 ** (2 ** 4) = x0 ** 16`, so the derivative is `16 * (x0 ** 15)`." ], "metadata": { "id": "Um8OSdLtjXfg" } }, { "cell_type": "code", "source": [ "x0 ** 16" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "uUlvMd5ijP-e", "outputId": "536fab0d-78ee-4d93-b154-73a77de999f7", "ExecuteTime": { "end_time": "2025-10-22T11:32:36.550176Z", "start_time": "2025-10-22T11:32:36.545941Z" } }, "outputs": [ { "data": { "text/plain": [ "0.44012666865176536" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 3 }, { "cell_type": "code", "source": [ "16 * x0 ** 15" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "YYwYOTjOjxb1", "outputId": "a292f501-764c-433e-c70d-9b6d17c5dec2", "ExecuteTime": { "end_time": "2025-10-22T11:32:36.601961Z", "start_time": "2025-10-22T11:32:36.597746Z" } }, "outputs": [ { "data": { "text/plain": [ "7.412659682556049" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 4 }, { "cell_type": "markdown", "source": [ "But we're interested in automatic differentiation, not symbolic differentiation. We want an algorithm that works regardless of the update expression in each step of the loop.\n", "\n", "In particular, we want to perform reverse-mode automatic differentiation, which is the most commonly used mode in machine learning.\n", "\n", "To get some intuition, we can unroll the loop and differentiate each step in reverse, accumulating the partial derivatives at each step. This is a generic application of the chain rule." 
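, "\n", "\n",
"As a sketch of that chain rule for this particular loop (writing $x_t$ for the value after $t$ iterations, a notation introduced here only for illustration), each factor is the derivative of one squaring step:\n", "\n",
"$$\\frac{\\partial x_4}{\\partial x_0} = \\frac{\\partial x_4}{\\partial x_3} \\frac{\\partial x_3}{\\partial x_2} \\frac{\\partial x_2}{\\partial x_1} \\frac{\\partial x_1}{\\partial x_0} = (2 x_3)(2 x_2)(2 x_1)(2 x_0) = 16 x_0^{8 + 4 + 2 + 1} = 16 x_0^{15}$$\n", "\n",
"Reverse mode evaluates this product from left to right, starting at the output."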
], "metadata": { "id": "CII-p2DlkAI2" } }, { "cell_type": "code", "source": [ "x1 = x0 ** 2\n", "x2 = x1 ** 2\n", "x3 = x2 ** 2\n", "x4 = x3 ** 2" ], "metadata": { "id": "U8Aq8zR2j84d", "ExecuteTime": { "end_time": "2025-10-22T11:32:36.652939Z", "start_time": "2025-10-22T11:32:36.649935Z" } }, "outputs": [], "execution_count": 5 }, { "cell_type": "code", "source": [ "dout_dx4 = 1.0 # They're the same\n", "dout_dx3 = dout_dx4 * 2 * x3\n", "dout_dx2 = dout_dx3 * 2 * x2\n", "dout_dx1 = dout_dx2 * 2 * x1\n", "dout_dx0 = dout_dx1 * 2 * x0\n", "dout_dx0" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "XglvLv7akc77", "outputId": "052bb55e-2132-4aab-caac-206703e0bfc5", "ExecuteTime": { "end_time": "2025-10-22T11:32:36.703849Z", "start_time": "2025-10-22T11:32:36.699513Z" } }, "outputs": [ { "data": { "text/plain": [ "7.412659682556052" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 6 }, { "cell_type": "markdown", "source": [ "The general pattern is to differentiate the step function `x ** 2` and multiply (or dot in the case of multivariate gradients) with the gradient from the \"future\" step.\n", "\n", "Note that we reused the intermediate steps of the original loop. This is usually the case, so it may be useful if our looping constructs can store these results for the gradient calculation. 
Programmatically, our approach to differentiate this kind of loop looks something like:" ], "metadata": { "id": "CkBxu8Ebmk6L" } }, { "cell_type": "code", "source": [ "xtm1 = pt.scalar(\"xtm1\")\n", "xt = xtm1 ** 2\n", "\n", "dout_dxt = pt.scalar(\"dout_dxt\")\n", "dout_dxtm1 = pytensor.gradient.Lop(xt, wrt=xtm1, eval_points=dout_dxt)\n", "\n", "step_fn = pytensor.function([xtm1], xt)\n", "d_step_fn = pytensor.function([xtm1, dout_dxt], dout_dxtm1)" ], "metadata": { "id": "ISlEF5-wmiA6", "ExecuteTime": { "end_time": "2025-10-22T11:32:37.269882Z", "start_time": "2025-10-22T11:32:36.748970Z" } }, "outputs": [], "execution_count": 7 }, { "cell_type": "code", "source": [ "xs = [x0]\n", "for i in range(4):\n", " xs.append(float(step_fn(xs[-1])))\n", "xs" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ek8jruKGmgiq", "outputId": "c52010ab-60a2-4dd5-f2e1-baffa173ee2c", "ExecuteTime": { "end_time": "2025-10-22T11:32:37.283935Z", "start_time": "2025-10-22T11:32:37.278620Z" } }, "outputs": [ { "data": { "text/plain": [ "[0.95, 0.9025, 0.81450625, 0.6634204312890625, 0.44012666865176564]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 8 }, { "cell_type": "code", "source": [ "dout_dx4 = 1.0 # They are the same\n", "dout_dxt = dout_dx4\n", "for i in reversed(range(4)):\n", " dout_dxt = float(d_step_fn(xs[i], dout_dxt))\n", "dout_dx0 = dout_dxt\n", "dout_dx0" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "k_lzDqHIknM9", "outputId": "498aba89-866e-4084-9edf-98be87c6ac5f", "ExecuteTime": { "end_time": "2025-10-22T11:32:37.336706Z", "start_time": "2025-10-22T11:32:37.331907Z" } }, "outputs": [ { "data": { "text/plain": [ "7.412659682556052" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 9 }, { "cell_type": "markdown", "source": [ "Unsurprisingly, to differentiate a loop operation we just need another loop operation. 
If we are clever about how we represent it, we can reuse our regular automatic-differentiation machinery..." ], "metadata": { "id": "k-iRUPU9pgtH" } }, { "cell_type": "markdown", "source": [ "![image.png]()" ], "metadata": { "id": "aXvnsuFaTXV5" } }, { "cell_type": "markdown", "source": "*Blue squares denote inputs or gradient with respect to inputs, red squares outputs or output gradients, and white squares denote intermediate quantities. Washed out squares indicate variables not needed for the final computation*", "metadata": { "id": "Sv9S_IUxLvup" } }, { "cell_type": "markdown", "source": [ "This approach works, but it introduces a new loop structure. The original loop's state `x` depended only on its previous state `xtm1`. The gradient loop's state `dout_dxt` depends on both its previous state `dout_dxtm1` and an external sequence `xs[i]`.\n", "\n", "This difference becomes important for higher-order derivatives. If we differentiate this gradient loop again (for the second derivative), our differentiation machinery must handle loops that read from an external sequence.\n", "\n", "Our step function looks like `dout_dxt = 2 * xs[i] * dout_dxtm1`.\n", "\n", "The chain rule doesn't care about the semantics of the input variables, so let's rename them and focus on the essentials:\n", "- `z` for the loop state (the accumulating gradient, `dout_dxt`).\n", "- `y` for the input sequence (`xs`)." ], "metadata": { "id": "nbieh6-tTV0C" } }, { "cell_type": "markdown", "source": [ "![image.png]()" ], "metadata": { "id": "5XVhoTD8wtst" } }, { "cell_type": "markdown", "source": [ "Each step is a function of two variables, of the form `zt = 2 * y[tm1] * ztm1`\n", "\n", "Because we are applying the chain rule, we don't have to know anything about the relationship between each `y`. They are just a sequence of inputs to this specific loop. 
Accordingly, we draw no arrows between them in the diagram.\n", "\n", "To stay general, we need a way to compute the partial derivatives of the loop outcome with respect to each intermediate `y[tm1]`. Let's give it a more general name: `dout_dyt`.\n", "\n", "We can use the same strategy and reason about the unrolled loop (using our new abstract variable names)." ], "metadata": { "id": "6FoB2kgvwvRX" } }, { "cell_type": "code", "source": [ "y = xs[:-1][::-1]\n", "\n", "z0 = 1.0\n", "z1 = 2 * y[0] * z0\n", "z2 = 2 * y[1] * z1\n", "z3 = 2 * y[2] * z2\n", "z4 = 2 * y[3] * z3\n", "zT = z4\n", "zT" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "QMYf9qOKtBGs", "outputId": "ae21d887-9d84-44b8-e1bc-48f31b93996d", "ExecuteTime": { "end_time": "2025-10-22T11:32:37.391210Z", "start_time": "2025-10-22T11:32:37.386353Z" } }, "outputs": [ { "data": { "text/plain": [ "7.412659682556052" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 10 }, { "cell_type": "markdown", "source": [ "We need two quantities. First, the gradient of each intermediate step with respect to the sequence value read on that step.\n", "\n", "`zt = 2 * y[tm1] * ztm1`\n", "\n", "`dzt/dy[tm1] = 2 * ztm1`\n", "\n", "Again note that we can (and probably want to) reuse the intermediate steps `zt` from the loop we're differentiating. This is a common motif." 
], "metadata": { "id": "p1VvOfFhwRLr" } }, { "cell_type": "code", "source": [ "dz4_dy3 = 2 * z3\n", "dz3_dy2 = 2 * z2\n", "dz2_dy1 = 2 * z1\n", "dz1_dy0 = 2 * z0" ], "metadata": { "id": "5i1TO65ywiLZ", "ExecuteTime": { "end_time": "2025-10-22T11:32:37.439015Z", "start_time": "2025-10-22T11:32:37.436394Z" } }, "outputs": [], "execution_count": 11 }, { "cell_type": "markdown", "source": [ "And second, we need the gradient of each step with respect to the previous step:\n", "\n", "`dzt/dztm1 = 2 * y[tm1]`" ], "metadata": { "id": "FkGdLBflwkSp" } }, { "cell_type": "code", "source": [ "dout_dz4 = 1.0 # They're the same\n", "dz4_dz3 = 2 * y[3]\n", "dz3_dz2 = 2 * y[2]\n", "dz2_dz1 = 2 * y[1]\n", "dz1_dz0 = 2 * y[0]" ], "metadata": { "id": "PXscOKT1wPoN", "ExecuteTime": { "end_time": "2025-10-22T11:32:37.507470Z", "start_time": "2025-10-22T11:32:37.504085Z" } }, "outputs": [], "execution_count": 12 }, { "cell_type": "markdown", "source": [ "By the chain rule, the gradient of the final `z` with respect to each indexed sequence is `dout_dyt = dout_dzt * dzt_dyt`.\n", "\n", "In the snippet below we don't accumulate the gradient `dout_dzt` to emphasize how this quantity is generated recursively. This is clearly wasteful!" 
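, "\n", "\n",
"Spelled out for the step that reads `y[tm1]` (a sketch, with $T = 4$ the number of steps and subscripts standing in for the `t` / `tm1` suffixes used in the code):\n", "\n",
"$$\\frac{\\partial z_T}{\\partial y_{t-1}} = \\left(\\prod_{s=t+1}^{T} \\frac{\\partial z_s}{\\partial z_{s-1}}\\right) \\frac{\\partial z_t}{\\partial y_{t-1}} = \\left(\\prod_{s=t+1}^{T} 2 y_{s-1}\\right) 2 z_{t-1}$$\n", "\n",
"For $t = T$ the product is empty and equals 1, which plays the role of the `dout_dz4 = 1.0` seed."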
], "metadata": { "id": "s2IQmP_DxSSU" } }, { "cell_type": "code", "source": [ "dout_dy3 = dout_dz4 * dz4_dy3\n", "dout_dy2 = dout_dz4 * dz4_dz3 * dz3_dy2\n", "dout_dy1 = dout_dz4 * dz4_dz3 * dz3_dz2 * dz2_dy1\n", "dout_dy0 = dout_dz4 * dz4_dz3 * dz3_dz2 * dz2_dz1 * dz1_dy0\n", "(dout_dy3, dout_dy2, dout_dy1, dout_dy0)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "HTq8uHODsqSC", "outputId": "004a54db-b6d9-44a2-ff3e-90ef4a8ace62", "ExecuteTime": { "end_time": "2025-10-22T11:32:37.574931Z", "start_time": "2025-10-22T11:32:37.570384Z" } }, "outputs": [ { "data": { "text/plain": [ "(7.802799665848476, 8.21347333247208, 9.100801476423358, 11.1733967375)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 13 }, { "cell_type": "markdown", "source": [ "This partial derivative is less intuitive to verify analytically, so we'll compare with finite differences." ], "metadata": { "id": "23qaWquvwO32" } }, { "cell_type": "code", "source": [ "# Wrap the gradient loop in a function\n", "def foreach_loop(x0, ys):\n", " x = x0\n", " for y in ys:\n", " x = 2 * x * y\n", " return x\n", "\n", "# Confirm we get the same results\n", "ys = xs[:-1][::-1]\n", "\n", "foreach_loop(1.0, ys)" ], "metadata": { "id": "sp_apcKDwJqb", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "2f5c0a58-400f-40d9-89d0-556194502de1", "ExecuteTime": { "end_time": "2025-10-22T11:32:37.634865Z", "start_time": "2025-10-22T11:32:37.630153Z" } }, "outputs": [ { "data": { "text/plain": [ "7.412659682556052" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 14 }, { "cell_type": "code", "source": [ "# Perform finite differences on each y[t]\n", "def finite_differences(ys, t, eps=1e-5):\n", " ysp, ysm = ys.copy(), ys.copy()\n", " ysp[t] += eps\n", " ysm[t] -= eps\n", " rp = foreach_loop(1.0, ysp)\n", " rm = foreach_loop(1.0, ysm)\n", " return (rp - rm) / (2 * eps)\n", "\n", 
"[finite_differences(ys, idx) for idx in reversed(range(4))]" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "XUF0rrd9pMRk", "outputId": "71c8fccd-d057-4213-f5a7-ef20f978a129", "ExecuteTime": { "end_time": "2025-10-22T11:32:37.687418Z", "start_time": "2025-10-22T11:32:37.682248Z" } }, "outputs": [ { "data": { "text/plain": [ "[7.802799665812542, 8.213473332396859, 9.100801476336073, 11.173396737396144]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 15 }, { "cell_type": "markdown", "source": [ "The results match.\n", "\n", "With the risk of over-doing it, here is the graphical representation of our new reverse-mode autodiff loop. Note we generate two quantities as we iterate: the gradient with respect to each sequence and the previous state." ], "metadata": { "id": "KuHkEoJd1AFg" } }, { "cell_type": "markdown", "source": [ "![image.png]()" ], "metadata": { "id": "4hP6Upql05Xp" } }, { "cell_type": "markdown", "source": "Let's recall what our abstract variable names mean for our example:", "metadata": { "id": "Sbxzuz-v1de2" } }, { "cell_type": "markdown", "source": [ "![image.png]()" ], "metadata": { "id": "_4PCrGns1hKA" } }, { "cell_type": "markdown", "source": [ "This gives us the partial second derivatives with respect to each intermediate step of the original forward loop (the fourth row of boxes). We now need to propagate them back to the initial point x0.\n", "\n", "This is also something slightly new. We have differentiated the final state of a loop with respect to the starting state, and we have differentiated the final state of a loop with respect to a sequence of values **consumed** in a loop. 
Now we need to differentiate a sequence of values **produced** in a loop with respect to the initial state.\n", "\n", "Let's look back at our first gradient loop scheme:" ], "metadata": { "id": "w5yY5tAy-XYN" } }, { "cell_type": "markdown", "source": [ "![image.png]()" ], "metadata": { "id": "Vdyc2Ge1W4B_" } }, { "cell_type": "markdown", "source": "We can re-express it as follows:", "metadata": { "id": "kflY59ASW89c" } }, { "cell_type": "markdown", "source": [ "![image.png]()" ], "metadata": { "id": "j39bNhIMXYNE" } }, { "cell_type": "markdown", "source": [ "At each intermediate step we accumulate two terms: the total gradient coming from the next step (the horizontal ← arrow), which represents the effect this variable had in the subsequent iterations of the loop, and the partial gradient coming from its direct use (the vertical ↑ arrow).\n", "\n", "It just so happened that in the first derivative example, the only \"external\" gradient contribution was at the very last step (`dout_dx4 = 1.0`).\n", "\n", "Now we have the reversed scenario. All but the last step have external contributions: the partial second derivatives with respect to each intermediate `x` of the original forward loop." 
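, "\n", "\n",
"In both scenarios the update has the same shape. As a sketch, writing $\\bar{x}_t$ for the accumulated gradient at step $t$ and $e_t$ for the external contribution at that step (both symbols are introduced here only for illustration):\n", "\n",
"$$\\bar{x}_t = e_t + \\bar{x}_{t+1} \\frac{\\partial x_{t+1}}{\\partial x_t} = e_t + 2 x_t \\bar{x}_{t+1}, \\qquad \\bar{x}_4 = e_4$$\n", "\n",
"The first derivative corresponds to $e_4 = 1$ with every other $e_t = 0$. Here $e_4 = 0$ and the remaining $e_t$ are the partial derivatives with respect to the consumed sequence computed earlier."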
], "metadata": { "id": "zXBYGjxQXbZA" } }, { "cell_type": "markdown", "source": [ "![image.png]()" ], "metadata": { "id": "Njk2MvnHXKH7" } }, { "cell_type": "code", "source": [ "dout_dx = [dout_dy3, dout_dy2, dout_dy1, dout_dy0, 0]\n", "for i in reversed(range(4)):\n", " dout_dx[i] += float(d_step_fn(xs[i], dout_dx[i+1]))\n", "dout_dx0 = dout_dx[0]\n", "dout_dx0" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "4CRAB76GXItK", "outputId": "8b091b76-2929-48ac-b8ce-c87e66546177", "ExecuteTime": { "end_time": "2025-10-22T11:32:37.737068Z", "start_time": "2025-10-22T11:32:37.731833Z" } }, "outputs": [ { "data": { "text/plain": [ "117.04199498772715" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 16 }, { "cell_type": "markdown", "source": "This value we can check easily against the symbolic second derivative", "metadata": { "id": "ZbvAhvyWpvnE" } }, { "cell_type": "code", "source": [ "16 * 15 * x0 ** 14" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "V0BpI1BQW52n", "outputId": "3f55971f-bc34-4e1b-9765-50a1da664c8b", "ExecuteTime": { "end_time": "2025-10-22T11:32:37.788679Z", "start_time": "2025-10-22T11:32:37.784973Z" } }, "outputs": [ { "data": { "text/plain": [ "117.0419949877271" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 17 }, { "cell_type": "markdown", "source": [ "### General strategy: A tape to accumulate partial derivatives as we iterate" ], "metadata": { "id": "Q0_7yyzVwtKX" } }, { "cell_type": "markdown", "source": "Maintaining a read/write tape of accumulated partials per step is the core of loop autodiff in PyTensor.", "metadata": { "id": "Od4o8rf9w3r0" } }, { "cell_type": "markdown", "source": "We can represent the first gradient loop with the same general approach. 
This time let's be smart and remember to store the intermediate values.", "metadata": { "id": "POuU-J8CqndT" } }, { "cell_type": "code", "source": [ "dout_dx = [0, 0, 0, 0, 1]\n", "for i in reversed(range(4)):\n", " dout_dx[i] += float(d_step_fn(xs[i], dout_dx[i+1]))\n", "dout_dx0 = dout_dx[0]\n", "dout_dx0" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "fkIPOh12B50T", "outputId": "876dec54-2f6d-4da2-a049-14b09165d1d4", "ExecuteTime": { "end_time": "2025-10-22T11:32:37.843038Z", "start_time": "2025-10-22T11:32:37.838302Z" } }, "outputs": [ { "data": { "text/plain": [ "7.412659682556052" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 18 }, { "cell_type": "markdown", "source": [ "It can also be used for the second gradient loop. In our quick and dirty implementation above we computed 2 intermediate quantities separately `dout_dyt` and `dzt_dztm1`, and then combined them manually like this:" ], "metadata": { "id": "T8xFcmcDptiA" } }, { "cell_type": "code", "source": [ "dout_dy3 = dout_dz4 * dz4_dy3\n", "dout_dy2 = dout_dz4 * dz4_dz3 * dz3_dy2\n", "dout_dy1 = dout_dz4 * dz4_dz3 * dz3_dz2 * dz2_dy1\n", "dout_dy0 = dout_dz4 * dz4_dz3 * dz3_dz2 * dz2_dz1 * dz1_dy0\n", "(dout_dy3, dout_dy2, dout_dy1, dout_dy0)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "qlN8UQlkp3nj", "outputId": "ee790d84-c292-490d-f5d6-ac45128fa14e", "ExecuteTime": { "end_time": "2025-10-22T11:32:37.897019Z", "start_time": "2025-10-22T11:32:37.891867Z" } }, "outputs": [ { "data": { "text/plain": [ "(7.802799665848476, 8.21347333247208, 9.100801476423358, 11.1733967375)" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 19 }, { "cell_type": "markdown", "source": [ "Now we'll do it all in a single pass, using our autodiff machinery" ], "metadata": { "id": "ys0__CJMv-Tt" } }, { "cell_type": "code", "source": [ "ztm1 = pt.scalar(\"ztm1\")\n", 
"ytm1 = pt.scalar(\"ytm1\")\n", "\n", "zt = 2 * ytm1 * ztm1\n", "\n", "dout_dzt = pt.scalar(\"dout_dzt\")\n", "dout_dztm1 = pytensor.gradient.Lop(zt, wrt=ztm1, eval_points=dout_dzt)\n", "dout_dytm1 = pytensor.gradient.Lop(zt, wrt=ytm1, eval_points=dout_dzt)\n", "\n", "d_dzt_step_fn = pytensor.function([ytm1, dout_dzt], dout_dztm1)\n", "d_dyt_step_fn = pytensor.function([ztm1, dout_dzt], dout_dytm1)" ], "metadata": { "id": "Y2KYE5X8cpwc", "ExecuteTime": { "end_time": "2025-10-22T11:32:37.971500Z", "start_time": "2025-10-22T11:32:37.943474Z" } }, "outputs": [], "execution_count": 20 }, { "cell_type": "code", "source": [ "# Inputs to our loop\n", "# z: the reversed forward steps consumed by the first auto-diff loop (reversed once more)\n", "z = y[::-1]\n", "# w: the intermediate steps produced by the first auto-diff loop (already reversed,\n", "# ignoring the last entry which isn't needed)\n", "w = dout_dx[1:]\n", "# partial derivatives of each intermediate step of the second auto-diff loop,\n", "# only first variable is connected to the cost\n", "dout_dzt = [1, 0, 0, 0, 0]\n", "# partial derivatives of the second auto-diff loop with respect to each read sequence\n", "dout_dyt = []\n", "for i in range(4):\n", " # Accumulate the partial derivative with respect to the previous step\n", " dout_dzt[i+1] = float(d_dzt_step_fn(z[i], dout_dzt[i]))\n", " # Output the partial derivative with respect to the read sequence\n", " dout_dyt.append(float(d_dyt_step_fn(w[i], dout_dzt[i])))\n", "dout_dyt\n", "[7.802799665848476, 8.21347333247208, 9.100801476423358, 11.1733967375]" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "YPrlA67_qtkM", "outputId": "378daf76-17ec-4f74-bd2a-1da9e3a67e68", "ExecuteTime": { "end_time": "2025-10-22T11:32:38.008869Z", "start_time": "2025-10-22T11:32:38.002427Z" } }, "outputs": [ { "data": { "text/plain": [ "[7.802799665848476, 8.21347333247208, 9.100801476423358, 11.1733967375]" ] }, "execution_count": 21, "metadata": {}, 
"output_type": "execute_result" } ], "execution_count": 21 }, { "cell_type": "markdown", "source": [ "Although not always needed, this approach is the most general. We've seen how it can be used to accumulate gradients where only the last state of the loop is connected to the cost, or where all but the last state are connected. In fact, we can have arbitrary dense/sparse/uneven connection between intermediate steps and the final function cost.\n", "\n", "The only thing we haven't discussed is how to explicitly differentiate over the operation of updating such a tape. We won't show it here, but the same tape procedure can be used for it, meaning we have a closed auto-diff system (i.e., we can keep doing higher order differentiation using the exesting procedures)" ], "metadata": { "id": "2_Oapol0yWyH" } }, { "cell_type": "markdown", "source": [ "## PyTensor Scan\n", "\n", "PyTensor uses Scan to represent symbolic loops. For reasons that are now clear, it automatically stores a tape of initial, intermediate results, and final results as iterates. It masks the initial results before returning the sequence.\n", "\n", "When a user is only interested in the final element (or any subset), they should index the returned sequence. Differentiating through this indexing operation creates a mask of zeros for the unused values, and the gradient for used ones. This is the very tape that will be updated by the reverse Scan.\n", "\n", "Use `dprint` on the function below to identify these elements and the four `Scan` instances derived in the notebook." 
], "metadata": { "id": "ovGLkQK9woro" } }, { "cell_type": "code", "source": [ "x0 = pt.scalar(\"x0\")\n", "xs, _ = pytensor.scan(\n", " fn=lambda xtm1: xtm1 ** 2,\n", " outputs_info=[x0],\n", " n_steps=4,\n", " # Exclude the fusion rewrites to keep the printed graph readable\n", " mode=get_default_mode().excluding(\"fusion\")\n", ")\n", "out = xs[-1]\n", "g = pt.grad(out, wrt=x0)\n", "h = pt.grad(g, wrt=x0)\n", "\n", "# Exclude the scan_pushout rewrites to keep the printed graph readable\n", "mode = get_default_mode().excluding(\"scan_pushout\")\n", "h_fn = pytensor.function([x0], h, mode=mode)\n", "# Uncomment the line below to print the compiled graph\n", "# h_fn.dprint(print_shape=True)" ], "metadata": { "id": "xUim5ysDQLgp", "ExecuteTime": { "end_time": "2025-10-22T11:32:40.034337Z", "start_time": "2025-10-22T11:32:38.055710Z" } }, "outputs": [], "execution_count": 22 } ] }