
Detailed explanation of the usage of retain_graph in PyTorch

Usage analysis

While reading the SRGAN source code, I came across the following training snippet, in which retain_graph=True is passed when backpropagating the discriminator loss. What does it do?

    ############################
    # (1) Update D network: maximize D(x)-1-D(G(z))
    ###########################
    real_img = Variable(target)
    if torch.cuda.is_available():
      real_img = real_img.cuda()
    z = Variable(data)
    if torch.cuda.is_available():
      z = z.cuda()
    fake_img = netG(z)

    netD.zero_grad()
    real_out = netD(real_img).mean()
    fake_out = netD(fake_img).mean()
    d_loss = 1 - real_out + fake_out
    d_loss.backward(retain_graph=True) #####
    optimizerD.step()

    ############################
    # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
    ###########################
    netG.zero_grad()
    g_loss = generator_criterion(fake_out, fake_img, real_img)
    g_loss.backward()
    optimizerG.step()
    fake_img = netG(z)
    fake_out = netD(fake_img).mean()

    g_loss = generator_criterion(fake_out, fake_img, real_img)
    running_results['g_loss'] += g_loss.data[0] * batch_size
    d_loss = 1 - real_out + fake_out
    running_results['d_loss'] += d_loss.data[0] * batch_size
    running_results['d_score'] += real_out.data[0] * batch_size
    running_results['g_score'] += fake_out.data[0] * batch_size

retain_graph=True is passed to backward() when the discriminator (D) loss is backpropagated. Its purpose is to keep the computation graph, i.e. the intermediate buffers, alive after that backward pass, because the same graph (which runs through netG to produce fake_img and fake_out) has to be traversed again when the generator (G) loss is backpropagated.
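To make the shared-graph issue concrete, here is a minimal, self-contained sketch of the same pattern (my own toy stand-in modules, not the SRGAN networks): fake_out is computed through netG, so the graph that the discriminator's backward would normally free is exactly the graph the generator's backward still needs.

import torch
import torch.nn as nn

netG = nn.Linear(4, 4)   # toy stand-in for the generator
netD = nn.Linear(4, 1)   # toy stand-in for the discriminator
z, real_img = torch.randn(8, 4), torch.randn(8, 4)

fake_img = netG(z)
real_out = netD(real_img).mean()
fake_out = netD(fake_img).mean()

d_loss = 1 - real_out + fake_out
d_loss.backward(retain_graph=True)   # keep the graph: fake_out's history runs through netG

g_loss = (1 - fake_out) + (fake_img - real_img).pow(2).mean()
g_loss.backward()                    # works only because the graph was retained above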

In ordinary training we rarely need the retain_graph parameter, but in special cases such as this one, where two losses are backpropagated through a shared graph, we do.
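For reference, GAN code that does not need to traverse the graph twice often avoids retain_graph altogether by detaching the generator output before the discriminator step. This is not what the SRGAN snippet above does; it is just a common alternative, sketched here with the same toy stand-in modules:

import torch
import torch.nn as nn

netG, netD = nn.Linear(4, 4), nn.Linear(4, 1)   # toy stand-ins again
z, real_img = torch.randn(8, 4), torch.randn(8, 4)

fake_img = netG(z)
d_loss = 1 - netD(real_img).mean() + netD(fake_img.detach()).mean()
d_loss.backward()                      # no retain_graph needed: the D graph stops at detach()

g_loss = 1 - netD(fake_img).mean()     # a fresh pass through netD; netG's part of the graph is untouched
g_loss.backward()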

To see the error that retain_graph prevents, consider the following code:

import torch
from torch.autograd import Variable

x = Variable(torch.ones(2, 2), requires_grad=True)
y = x ** 2
z = y * 4
output1 = z.mean()
output2 = z.sum()
output1.backward()   # the graph is freed after this first backward
output2.backward()   # error: tries to traverse the graph a second time

The following error message is output:

---------------------------------------------------------------------------
RuntimeError               Traceback (most recent call last)
<ipython-input-19-8ad6b0658906> in <module>()
----> 1 output1.backward()
   2 output2.backward()

D:\ProgramData\Anaconda3\lib\site-packages\torch\ in backward(self, gradient, retain_graph, create_graph)
   91         products. Defaults to ``False``.
   92     """
---> 93     torch.autograd.backward(self, gradient, retain_graph, create_graph)
   94 
   95   def register_hook(self, hook):

D:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
   88   Variable._execution_engine.run_backward(
   89     tensors, grad_tensors, retain_graph, create_graph,
---> 90     allow_unreachable=True) # allow_unreachable flag
   91 
   92 

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

The first backward() frees the intermediate buffers of the computation graph, so when the second backward() tries to traverse the same graph there is nothing left to traverse. Passing retain_graph=True to the first call keeps the graph alive. Modify the code as follows to make it correct:

import torch
from torch.autograd import Variable

x = Variable(torch.ones(2, 2), requires_grad=True)
y = x ** 2
z = y * 4
output1 = z.mean()
output2 = z.sum()
output1.backward(retain_graph=True)   # keep the graph so a second backward can traverse it
output2.backward()                    # the graph is freed after this second backward
# If you have two losses, call the first backward with retain_graph=True, then the second:
# loss1.backward(retain_graph=True)
# loss2.backward()    # after this call, all intermediate buffers are freed for the next iteration
# optimizer.step()    # update the parameters
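When the two losses come from the same forward pass and the graph is not needed for anything else, an equivalent option (not shown in the original snippet) is to sum the losses and call backward once; the accumulated gradients are the same and the graph is only traversed a single time:

import torch

x = torch.ones(2, 2, requires_grad=True)
y = x ** 2
z = y * 4
loss1, loss2 = z.mean(), z.sum()
(loss1 + loss2).backward()   # one traversal of the graph, same accumulated gradients
print(x.grad)                # each entry is 10 = 2 (from the mean) + 8 (from the sum)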

Variable class source code

class Variable(_C._VariableBase):
 
  """
  Attributes:
    data: Any type of encapsulated tensor。
    grad: Save anddataGradients matching type and position,This property is difficult to allocate and cannot be reassigned。
    requires_grad: Tag whether the variable has been created by a subgraph that needs to be called to this variableboolvalue。Only modify it on leaf variables。
    volatile: Whether tag variables can be applied in inference mode(If history is not saved)ofboolvalue。Changes only on leaf variables。
    is_leaf: Whether the tag variable is a graph leaf(如由用户创建ofvariable)ofboolvalue.
    grad_fn: Gradient function graph trace.
 
  Parameters:
    data (any tensor class): 要包装of张量.
    requires_grad (bool): bool型of标记value. **Keyword only.**
    volatile (bool): bool型of标记value. **Keyword only.**
  """
 
  def backward(self, gradient=None, retain_graph=None, create_graph=None, retain_variables=None):
    """Computes the gradients with respect to the leaf Variables of the current graph; the graph is differentiated using the chain rule.
    If the Variable is a scalar (i.e. it holds a single element), you do not need to pass any arguments to backward().
    If the Variable is not a scalar (it is a vector holding several elements) and requires gradients, the function needs an extra gradient:
    you must pass a grad_output argument whose shape matches the tensor (the derivative of y projected along the specified direction of x);
    it can be a tensor of matching type and device that holds the gradients of different functions with respect to the Variable itself.
    The function accumulates gradients in the leaves; the leaves need to be zeroed before calling it.
 
    Arguments:
      grad_variables (Tensor, Variable or None):
              gradient of the Variable. If it is a tensor, it is automatically converted to a volatile Variable unless ``create_graph`` is True.
              None can be specified for scalar Variables or for values that do not require gradients. If a None value is acceptable, this argument is optional.
      retain_graph (bool, optional): if False, the graph used to compute the gradients is freed.
                      In almost all cases, setting this option to True is not necessary and can usually be worked around in a more efficient way.
                      Defaults to the value of create_graph.
      create_graph (bool, optional): if True, a graph of the derivatives is constructed, which can be used to compute higher-order derivatives.
                      Defaults to False, unless ``gradient`` is a volatile Variable.
    """
    torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
 
 
  def register_hook(self, hook):
    """Registers a backward hook.
 
    The hook is called every time a gradient with respect to the Variable is computed. The hook has the signature: hook(grad) -> Variable or None.
    The hook must not modify its argument, but it may optionally return a new gradient to be used in place of ``grad``.
 
    The function returns a handle whose ``handle.remove()`` method removes the hook from the module.
 
    Example:
      >>> v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True)
      >>> h = v.register_hook(lambda grad: grad * 2) # double the gradient
      >>> v.backward(torch.Tensor([1, 1, 1]))
      >>> v.grad.data
       2
       2
       2
      [torch.FloatTensor of size 3]
      >>> h.remove() # removes the hook
    """
    if self.volatile:
      raise RuntimeError("cannot register a hook on a volatile variable")
    if not self.requires_grad:
      raise RuntimeError("cannot register a hook on a variable that "
                "doesn't require gradient")
    if self._backward_hooks is None:
      self._backward_hooks = OrderedDict()
      if self.grad_fn is not None:
        self.grad_fn._register_hook_dict(self)
    handle = hooks.RemovableHandle(self._backward_hooks)
    self._backward_hooks[handle.id] = hook
    return handle
 
  def reinforce(self, reward):
    """Registers a reward obtained as a result of a stochastic process.
    Differentiating stochastic nodes requires providing them with a reward value. If the graph contains any stochastic operations, this function should be called on their outputs, otherwise an error will be raised.
    Parameters:
      reward (Tensor): a tensor with a per-element reward. Must match the device placement and shape of the Variable's data.
    """
    if not isinstance(self.grad_fn, StochasticFunction):
      raise RuntimeError("reinforce() can be only called on outputs "
                "of stochastic functions")
    self.grad_fn._reinforce(reward)
 
  def detach(self):
    """Returns a heart variable separated from the current graph.
     The result does not require a gradient, and if the input is volatile, the output is also volatile.
 
     .. Notice::
      The return variable uses the same data tensor as the original variable, and can see in-place modifications of either of them and may trigger an error in the correctness check.
     """
    result = NoGrad()(self) # this is needed, because it merges version counters
    result._grad_fn = None
    return result
 
  def detach_(self):
    """Separate variables from the graph where it was created and serve as a leaf for that graph"""
    self._grad_fn = None
    self.requires_grad = False
 
  def retain_grad(self):
    """Enables .grad attribute for non-leaf Variables."""
    if self.grad_fn is None: # no-op for leaves
      return
    if not self.requires_grad:
      raise RuntimeError("can't retain_grad on Variable that has requires_grad=False")
    if hasattr(self, 'retains_grad'):
      return
    weak_self = weakref.ref(self)
 
    def retain_grad_hook(grad):
      var = weak_self()
      if var is None:
        return
      if var._grad is None:
        var._grad = grad.clone()
      else:
        var._grad = var._grad + grad
 
    self.register_hook(retain_grad_hook)
    self.retains_grad = True
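As the backward() docstring above explains, a non-scalar Variable needs an explicit gradient argument, and running backward twice over the same graph requires retain_graph=True on the first call. A minimal illustration of both points, written with the current tensor API rather than the old Variable class:

import torch

x = torch.ones(3, requires_grad=True)
y = x * 2                                    # non-scalar output: backward() needs a gradient argument
y.backward(torch.tensor([1.0, 1.0, 1.0]), retain_graph=True)
print(x.grad)                                # tensor([2., 2., 2.])

x.grad.zero_()                               # gradients accumulate, so clear them between calls
y.backward(torch.tensor([0.1, 1.0, 10.0]))   # second pass works because the graph was retained
print(x.grad)                                # tensor([ 0.2000,  2.0000, 20.0000])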

This concludes the detailed explanation of the usage of retain_graph in PyTorch. I hope it serves as a useful reference, and I appreciate your continued support.