飙血推荐
  • HTML教程
  • MySQL教程
  • JavaScript基础教程
  • php入门教程
  • JavaScript正则表达式运用
  • Excel函数教程
  • AngularJS教程
  • UEditor使用文档
  • ThinkPHP5.0教程

[源码解析] PyTorch 分布式 Autograd (6) ---- 引擎(下)

时间:2021-12-07  作者:rossiXYZ  
上文我们介绍了引擎如何获得后向计算图的依赖,本文我们就接着看看引擎如何依据这些依赖进行后向传播。

[源码解析] PyTtorch 分布式 Autograd (6) ---- 引擎(下)

目录
  • [源码解析] PyTtorch 分布式 Autograd (6) ---- 引擎(下)
    • 0x00 摘要
    • 0x01 回顾
    • 0x02 执行GraphTask
      • 2.1 runEngineAndAccumulateGradients
      • 2.2 execute_graph_task_until_ready_queue_empty
      • 2.3 evaluate_function
      • 2.4 globalCpuThread
      • 2.5 小结
    • 0x03 RPC调用
      • 3.1 RecvRpcBackward
        • 3.1.1 定义
        • 3.1.2 构建
        • 3.1.3 apply
      • 3.2 PropagateGradientsReq
        • 3.2.1 定义
      • 3.3 接受方
        • 3.3.1 接受消息
        • 3.3.2 processBackwardAutogradReq
        • 3.3.3 executeSendFunctionAsync
    • 0x04 DistAccumulateGradCaptureHook
      • 4.1 定义
      • 4.2 生成
      • 4.3 使用
      • 4.4 累积梯度
        • 4.4.1 上下文累积
        • 4.4.2 算子累积
    • 0x05 等待完成
    • 0xFF 参考

0x00 摘要

上文我们介绍了引擎如何获得后向计算图的依赖,本文我们就接着看看引擎如何依据这些依赖进行后向传播。通过本文的学习,大家可以:

  • 了解 RecvRpcBackward 如何给对应的下游节点发送 RPC 消息,可以再次梳理一下worker之间后向传播的交互流程。
  • 了解 AccumulateGrad 如何在上下文累积梯度。

PyTorch分布式其他文章如下:

深度学习利器之自动微分(1)

深度学习利器之自动微分(2)

[源码解析]深度学习利器之自动微分(3) --- 示例解读

[源码解析]PyTorch如何实现前向传播(1) --- 基础类(上)

[源码解析]PyTorch如何实现前向传播(2) --- 基础类(下)

[源码解析] PyTorch如何实现前向传播(3) --- 具体实现

[源码解析] Pytorch 如何实现后向传播 (1)---- 调用引擎

[源码解析] Pytorch 如何实现后向传播 (2)---- 引擎静态结构

[源码解析] Pytorch 如何实现后向传播 (3)---- 引擎动态逻辑

[源码解析] PyTorch 如何实现后向传播 (4)---- 具体算法

[源码解析] PyTorch 分布式(1)------历史和概述

[源码解析] PyTorch 分布式(2) ----- DataParallel(上)

[源码解析] PyTorch 分布式(3) ----- DataParallel(下)

[源码解析] PyTorch 分布式(4)------分布式应用基础概念

[源码解析] PyTorch分布式(5) ------ DistributedDataParallel 总述&如何使用

[源码解析] PyTorch分布式(6) ---DistributedDataParallel -- 初始化&store

[源码解析] PyTorch 分布式(7) ----- DistributedDataParallel 之进程组

[源码解析] PyTorch 分布式(8) -------- DistributedDataParallel之论文篇

[源码解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化

[源码解析] PyTorch 分布式(10)------DistributedDataParallel 之 Reducer静态架构

[源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer和Join操作

[源码解析] PyTorch 分布式(12) ----- DistributedDataParallel 之 前向传播

[源码解析] PyTorch 分布式(13) ----- DistributedDataParallel 之 反向传播

[源码解析] PyTorch 分布式 Autograd (1) ---- 设计

[源码解析] PyTorch 分布式 Autograd (2) ---- RPC基础

[源码解析] PyTorch 分布式 Autograd (3) ---- 上下文相关

[源码解析] PyTorch 分布式 Autograd (4) ---- 如何切入引擎

[源码解析] PyTorch 分布式 Autograd (5) ---- 引擎(上)

为了更好的说明,本文代码会依据具体情况来进行相应精简。

0x01 回顾

我们首先回顾FAST模式算法算法如下,本文需要讨论后面若干部分。

  1. 我们从具有反向传播根的worker开始(所有根都必须是本地的)。
  2. 查找当前Distributed Autograd Context 的所有send函数 。
  3. 从提供的根和我们检索到的所有send函数开始,我们在本地计算依赖项 。
  4. 计算依赖项后,使用提供的根来启动本地 autograd 引擎。
  5. 当 autograd 引擎执行该recv函数时,该recv 函数通过 RPC 将输入梯度发送到适当的worker。每个recv函数都知道目标 worker id,因为它被记录为前向传播的一部分。通过autograd_context_idautograd_message_idrecv函数被发送到远程主机。
  6. 当远程主机收到这个请求时,我们使用 autograd_context_idautograd_message_id来查找适当的send函数。
  7. 如果这是worker第一次收到对给定 autograd_context_id的请求,它将按照上面的第 1-3 点所述在本地计算依赖项。
  8. 然后将在第6点接受到的send方法插入队列,以便在该worker的本地 autograd 引擎上执行。
  9. 最后,我们不是在 Tensor的.grad之上累积梯度,而是在每个Distributed Autograd Context之上分别累积梯度 。梯度存储在Dict[Tensor, Tensor]之中 ,Dict[Tensor, Tensor]基本上是从 Tensor 到其关联梯度的映射,并且可以使用 get_gradients() API检索该映射 。

其次,我们看看总体执行代码,总体执行是在 DistEngine::execute 之中完成,具体分为如下步骤:

  • 使用 contextId 得到前向的上下文。
  • 使用 validateRootsAndRetrieveEdges 进行验证。
  • 构造一个GraphRoot,用它来驱动后向传播,可以认为是一个虚拟根。
  • 使用 computeDependencies 计算依赖。
  • 使用 runEngineAndAccumulateGradients 进行反向传播计算。
  • 使用 clearAndWaitForOutstandingRpcsAsync 等待 RPC 完成。
void DistEngine::execute(
    int64_t contextId,
    const variable_list& roots,
    bool retainGraph) {
  // Retrieve the context for the given context_id. This will throw if the
  // context_id is invalid.
  auto autogradContext =
      DistAutogradContainer::getInstance().retrieveContext(contextId);

  // Perform initial pre-processing.
  edge_list rootEdges;
  variable_list grads;
  validateRootsAndRetrieveEdges(roots, rootEdges, grads); 

  // 构造一个GraphRoot,用它来驱动后向传播,可以认为是一个虚拟根
  std::shared_ptr<Node> graphRoot =
      std::make_shared<GraphRoot>(rootEdges, grads);
  edge_list outputEdges;
  // Compute dependencies locally, starting from all roots and all \'send\'
  // functions.
  {
    std::lock_guard<std::mutex> guard(initializedContextIdsLock_);
    // Context should not have been initialized already.
    TORCH_INTERNAL_ASSERT(
        域名(autogradContext->contextId()) ==
        域名());

    // 计算依赖
    computeDependencies(
        autogradContext, rootEdges, grads, graphRoot, outputEdges, retainGraph);

    // Mark the autograd context id as initialized.
    域名rt(autogradContext->contextId());
  }

  BackwardPassCleanupGuard guard(autogradContext);

  // This needs to be blocking and as a result we wait for the future to
  // complete.
  runEngineAndAccumulateGradients(autogradContext, graphRoot, outputEdges)
      ->waitAndThrow(); // 反向传播计算

  // Wait for all of the outstanding rpcs to complete.
  autogradContext->clearAndWaitForOutstandingRpcsAsync()->waitAndThrow();
}

再次,从前文我们知道,依赖项已经在 computeDependencies 之中处理完毕,所有需要计算的函数信息都位于 域名_info_ 之上。我们接下来就看看如何计算,就是 runEngineAndAccumulateGradients 和 clearAndWaitForOutstandingRpcsAsync 这两个方法。

0x02 执行GraphTask

我们首先看看如何使用 runEngineAndAccumulateGradients 进行反向传播计算,累积梯度。

2.1 runEngineAndAccumulateGradients

引擎之中,首先调用了 runEngineAndAccumulateGradients。主要是封装了一个 NodeTask,然后以此调用 execute_graph_task_until_ready_queue_empty。其中使用 at::launch 来启动线程。

c10::intrusive_ptr<c10::ivalue::Future> DistEngine::
    runEngineAndAccumulateGradients(
        const ContextPtr& autogradContext,
        const std::shared_ptr<Node>& graphRoot,
        const edge_list& outputEdges,
        bool incrementOutstandingTasks) {
  // Cleanup previous state for outstanding RPCs. Outstanding RPCs could be
  // lingering if we\'re running backward multiple times and some of the
  // passes ran into errors.
  autogradContext->clearOutstandingRpcs();
    
  // 得到GraphTask
  auto graphTask = autogradContext->retrieveGraphTask();
  
  // 启动了一个线程来运行 execute_graph_task_until_ready_queue_empty
  at::launch([this, graphTask, graphRoot, incrementOutstandingTasks]() {
    execute_graph_task_until_ready_queue_empty(
        /*node_task*/ NodeTask(graphTask, graphRoot, InputBuffer(0)),
        /*incrementOutstandingTasks*/ incrementOutstandingTasks);
  });
    
  // Use a reference here to avoid refcount bump on futureGrads.
  // 处理结果
  auto& futureGrads = graphTask->future_result_;

  // Build a future that waits for the callbacks to execute (since callbacks
  // execute after the original future is completed). This ensures we return a
  // future that waits for all gradient accumulation to finish.
  auto accumulateGradFuture =
      c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());

  futureGrads->addCallback(
      [autogradContext, outputEdges, accumulateGradFuture](c10::ivalue::Future& futureGrads) {
        if (域名rror()) {
		  // 省略错误处理部分
          return;
        }

        try {
          const variable_list& grads =
              域名tValue().toTensorVector();
           // 标识已经结束 
          accumulateGradFuture->markCompleted(c10::IValue());
        } catch (std::exception& e) {
          accumulateGradFuture->setErrorIfNeeded(std::current_exception());
        }
      });

  return accumulateGradFuture;
}

at::launch 位于 aten/src/ATen/域名,这里会在线程之中调用传入的 func。

void launch(std::function<void()> func) {

  internal::launch_no_thread_state(std::bind([](
    std::function<void()> f, ThreadLocalState thread_locals) {
      ThreadLocalStateGuard guard(std::move(thread_locals));
      f();
    },
    std::move(func),
    ThreadLocalState()
  ));
}

namespace internal {
    void launch_no_thread_state(std::function<void()> fn) {
    #if AT_EXPERIMENTAL_SINGLE_THREAD_POOL
      intraop_launch(std::move(fn));
    #else
      get_pool().run(std::move(fn));
    #endif
    }
} 

我们接下来一一看看内部这几个方法如何执行。

2.2 execute_graph_task_until_ready_queue_empty

此函数类似 Engine::thread_main,通过一个 NodeTask 来完成本 GraphTask的执行,其中 evaluate_function 会不停的向 cpu_ready_queue 插入新的 NodeTask。域名uate_function 方法会:

  • 首先,初始化原生引擎线程。
  • 其次,每个调用建立一个 cpu_ready_queue,用来从root_to_execute开始遍历graph_task,这允许用不同的线程来对GraphTask并行执行,这是一个CPU相关的queue。
  • 把传入的 node_task 插入到 cpu_ready_queue。
  • 沿着反向计算图从根部开始,一直计算到叶子节点。
    • 这里叶子节点都是 AccumulateGrad 或者 RecvRpcBackward。

    • 如果是中间节点,则正常计算。

    • 如果是 RecvRpcBackward 则会给对应的下游节点发送 RPC 消息

    • 如果是 AccumulateGrad,则在上下文累积梯度

具体代码如下:

void DistEngine::execute_graph_task_until_ready_queue_empty(
    NodeTask&& node_task,
    bool incrementOutstandingTasks) {
  
  // 初始化原生引擎线程
  域名ialize_device_threads_pool();
  
  // Create a ready queue per call to traverse the graph_task from
  // root_to_execute This allow concurrent execution of the same GraphTask from
  // different threads
  // 每个调用建立一个 ready queue,用来从root_to_execute开始遍历graph_task,这允许用不同的线程来对GraphTask并行执行,这是一个CPU相关的queue
  std::shared_ptr<ReadyQueue> cpu_ready_queue = std::make_shared<ReadyQueue>();
  auto graph_task = 域名();
  if (graph_task == nullptr) {
    LOG(ERROR) << "GraphTask has expired for NodeTask: "
               << 域名->name() << ", skipping execution.";
    return;
  }

  cpu_ready_queue->push(std::move(node_task), incrementOutstandingTasks);

  torch::autograd::set_device(torch::autograd::CPU_DEVICE);
  graph_task->owner_ = torch::autograd::CPU_DEVICE;
  while (!cpu_ready_queue->empty()) {
    std::shared_ptr<GraphTask> local_graph_task;
    {
      // Scope this block of execution since NodeTask is not needed after this
      // block and can be deallocated (release any references to grad tensors
      // as part of inputs_)
      NodeTask task = cpu_ready_queue->pop(); // 取出一个NodeTask
      if (!(local_graph_task = 域名())) {
        continue;
      }
      if (域名 && !local_graph_task->域名()) {
        AutoGradMode grad_mode(local_graph_task->grad_mode_);
        try {
          GraphTaskGuard guard(local_graph_task);
          域名uate_function( // 这里会调用具体Node对应的函数
              local_graph_task, 域名(), 域名ts_, cpu_ready_queue);
        } catch (std::exception& e) {
          域名ad_on_exception(local_graph_task, 域名, e);
          // break the loop in error so that we immediately stop the execution
          // of this GraphTask, mark it completed if necessary and return the
          // future with proper ErrorMessage
          break;
        }
      }
    }
    // Decrement the outstanding task.
    --local_graph_task->outstanding_tasks_; // 处理了一个NodeTask
  }
  // Check if we\'ve completed execution.
  if (graph_task->completed()) {
    // We don\'t need to explicitly notify the owner thread, since
    // \'mark_as_completed_and_run_post_processing\' would mark the Future as
    // completed and this would notify the owner thread that the task has been
    // completed.
    graph_task->mark_as_completed_and_run_post_processing();
  }
}

另外,一共有三个地方调用 execute_graph_task_until_ready_queue_empty。

  1. runEngineAndAccumulateGradients 会调用,这里就是用户主动调用 backward 的情形,就是本节介绍的。
  2. executeSendFunctionAsync 会调用,这里对应了某节点从反向传播上一节点接受到梯度之后的操作,我们会在下一节介绍。
  3. globalCpuThread 会调用,这是CPU工作专用线程,我们马上会介绍。
  4. 在 域名uate_function 之中,会针对 AccumulateGrad 来累积梯度。
  5. 在 域名uate_function 之中,会调用 RecvRpcBackward 来向反向传播下游发送消息。

我们总结一下几个计算梯度的流程,分别对应下面三个数字。

 User Training Script             RPC BACKWARD_AUTOGRAD_REQ
     +                                         +
     |                                         |
     | 1                                       | 2
     v                                         v
 backward                         域名essRpc
     +                                         +
     |                                         |
     |                                         |
     v                                         v
 域名ute               域名essBackwardAutogradReq
     +                                         +
     |                                         |
     |                                         |
     |                                         v
     |              +----------+  域名uteSendFunctionAsync
     |              |                               +
     |              |                               |
     v              v                               |
域名uteDependencies                      |
     |                                              |
     |                                              |
     v                                              |
 域名ngineAndAccumulateGradients         |     域名alCpuThread
     +                                              |                   +
     |                           +------------------+                   |
     |                           |                                      | 3
     |                           |             +------------------------+
     |                           |             |
     |                           |             |
     v                           v             v
 域名ute_graph_task_until_ready_queue_empty
     +
     |
     |
     v
 域名uate_function
     +
     |
     +--------------------------------------------------------------+
     |                                                              |
     |  4 AccumulateGrad                                            | 5  RecvRpcBackward
     v                                                              v

(*hook)(captured_grad)                            call_function(graph_task, func, inputs)

2.3 evaluate_function

上面代码之中,实际上会调用原生引擎的 evaluate_function 来完成操作。

我们看看如何使用 exec_info_,如果没有设置为需要执行,则就不处理。在此处,我们可以看到 上文提到的recvBackwardEdges 如何与 exec_info_ 交互。

遍历 recvBackwardEdges,对于每个 recvBackward,在 域名_info_ 之中对应项之上设止为需要执行。

具体代码如下,这里会:

  • 针对 AccumulateGrad 来累积梯度。
  • 调用 RecvRpcBackward 来向反向传播下游发送消息。
void Engine::evaluate_function(
    std::shared_ptr<GraphTask>& graph_task,
    Node* func,
    InputBuffer& inputs,
    const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
  // If exec_info_ is not empty, we have to instrument the execution
  auto& exec_info_ = graph_task->exec_info_;
  if (!域名y()) {
    auto& fn_info = 域名(func);
    if (auto* capture_vec = 域名()) {
      // Lock mutex for writing to graph_task->captured_vars_.
      std::lock_guard<std::mutex> lock(graph_task->mutex_);
      for (const auto& capture : *capture_vec) {
        auto& captured_grad = graph_task->captured_vars_[域名ut_idx_];
        captured_grad = inputs[域名t_idx_];
        for (auto& hook : 域名s_) {
          captured_grad = (*hook)(captured_grad); //这里调用 hook,就是 DistAccumulateGradCaptureHook 的 operator(),captured_grad 就是累积的梯度
        }
      }
    }
    if (!域名ed_) { 
      // Skip execution if we don\'t need to execute the function.
      return; // 如果没有设置需要执行,则直接返回。recvBackward 会设置需要执行
    }
  }
  
  // 这里就是调用 recvBackward
  auto outputs = call_function(graph_task, func, inputs);
    
  // 后续代码省略  

2.4 globalCpuThread

globalCpuThread 可以参见上文的 [GPU to CPU continuations] 一节,globalCpuThread是工作线程,其就是从 ready queue 里面弹出 NodeTask,然后执行。

对于globalCpuThread,其参数 ready_queue 是 global_cpu_ready_queue_

void DistEngine::globalCpuThread(
    const std::shared_ptr<ReadyQueue>& ready_queue) {
  while (true) {
    NodeTask task = ready_queue->pop();
    if (域名utdownTask_) {
      // Need to shutdown this thread.
      break;
    }

    auto graphTask = 域名();
    if (graphTask == nullptr) {
      // GraphTask has expired, ignore and continue processing.
      continue;
    }

    // Launch the execution on a JIT thread.
    at::launch([this,
                graphTask,
                graphRoot = 域名,
                variables =
                    InputBuffer::variables(std::move(域名ts_))]() mutable {
      InputBuffer inputs(域名());
      for (size_t i = 0; i < 域名(); i++) {
        域名(i, std::move(variables[i]), c10::nullopt, c10::nullopt);
      }
      execute_graph_task_until_ready_queue_empty( // 这里会调用
          /*node_task*/ NodeTask(graphTask, graphRoot, std::move(inputs)),
          /*incrementOutstandingTasks*/ false);
    });
  }
}

对于普通引擎也会设置一个 cpu 专用 queue。

auto graph_task = std::make_shared<GraphTask>(
    /* keep_graph */ keep_graph,
    /* create_graph */ create_graph,
    /* depth */ not_reentrant_backward_call ? 0 : total_depth + 1,
    /* cpu_ready_queue */ local_ready_queue);

2.5 小结

对于分布式引擎,与普通引擎在计算部分主要不同之处为:

  • 如果是 RecvRpcBackward 则会给对应的下游节点发送 RPC 消息

  • 如果是 AccumulateGrad,则在上下文累积梯度

所以我们接下来看看具体这两部分如何处理。

0x03 RPC调用

在之前文章中,我们看到了接受方如何处理反向传播 RPC 调用,我们接下来看看引擎如何发起反向传播 RPC 调用,就是如何调用 recv 方法。

这里就适用于下面worker 0 调用 recv ,执行来到 worker 1 这种情况,对应设计文档中如下。

当 autograd 引擎执行该recv函数时,该recv 函数通过 RPC 将输入梯度发送到适当的worker。每个recv函数都知道目标 worker id,因为它被记录为前向传播的一部分。通过autograd_context_idautograd_message_idrecv函数被发送到远程主机。

img

我们就看看如何执行 recv 函数。

具体结合到分布式引擎,就是当引擎发现某一个 Node 是 RecvRpcBackward,就调用其 apply 函数

void Engine::evaluate_function(
    std::shared_ptr<GraphTask>& graph_task,
    Node* func,
    InputBuffer& inputs,
    const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
  // If exec_info_ is not empty, we have to instrument the execution
  auto& exec_info_ = graph_task->exec_info_;
  if (!域名y()) {
    // 省略了梯度累积部分代码,具体可以参见上面章节 
    if (!域名ed_) { 
      // Skip execution if we don\'t need to execute the function.
      return; // 如果没有设置需要执行,则直接返回。recvBackward 会设置需要执行
    }
  }
  
  // 这里就是调用 域名y 函数
  auto outputs = call_function(graph_task, func, inputs);
    
  // 后续代码省略  

3.1 RecvRpcBackward

3.1.1 定义

RecvRpcBackward 定义如下,

class TORCH_API RecvRpcBackward : public torch::autograd::Node {
 public:
  explicit RecvRpcBackward(
      const AutogradMetadata& autogradMetadata,
      std::shared_ptr<DistAutogradContext> autogradContext,
      rpc::worker_id_t fromWorkerId,
      std::unordered_map<c10::Device, c10::Device> deviceMap);

  torch::autograd::variable_list apply(
      torch::autograd::variable_list&& grads) override;

 private:
  const AutogradMetadata autogradMetadata_;

  // Hold a weak reference to the autograd context to avoid circular
  // dependencies with the context (since it holds a reference to
  // RecvRpcBackward).
  std::weak_ptr<DistAutogradContext> autogradContext_;

  // The worker id from which the RPC was received. During the backward pass,
  // we need to propagate the gradients to this workerId.
  rpc::worker_id_t fromWorkerId_;

  // Device mapping for tensors sent over RPC.
  const std::unordered_map<c10::Device, c10::Device> deviceMap_;
};

3.1.2 构建

构造函数如下。

RecvRpcBackward::RecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    ContextPtr autogradContext,
    rpc::worker_id_t fromWorkerId,
    std::unordered_map<c10::Device, c10::Device> deviceMap)
    : autogradMetadata_(autogradMetadata),
      autogradContext_(std::move(autogradContext)),
      fromWorkerId_(fromWorkerId),
      deviceMap_(std::move(deviceMap)) {}

3.1.3 apply

torch/csrc/distributed/autograd/functions/域名 定义了其 apply 函数,其作用就是:

  • 把传入的梯度 grads 放入outputGrads,因为要输出给下一环节。
  • 构建 PropagateGradientsReq,这就是 BACKWARD_AUTOGRAD_REQ。
  • 发送 RPC 给下一环节。
variable_list RecvRpcBackward::apply(variable_list&& grads) {
  std::vector<Variable> outputGrads;
  for (size_t i = 0; i < 域名(); i++) { // 下面就是把传入的梯度 grads 放入outputGrads
    const auto& grad = grads[i];
    if (域名ned()) {
      域名ace_back(grad);
    } else {
      // Put in zeros for a tensor with no grad.
      域名ace_back(input_metadata(i).zeros_like());
    }
  }
 
  auto sharedContext = 域名();
  // Send the gradients over the wire and record the future in the autograd
  // context.
  PropagateGradientsReq gradCall( // 构建 PropagateGradientsReq
      autogradMetadata_,
      outputGrads,
      sharedContext->retrieveGraphTask()->keep_graph_);

  // Send the gradients over to the appropriate node.
  auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent();
  auto jitFuture = rpcAgent->send( // 发送 RPC
      rpcAgent->getWorkerInfo(fromWorkerId_),
      std::move(gradCall).toMessage(), // 调用了toMessageImpl
      rpc::kUnsetRpcTimeout,
      deviceMap_);

  // Record the future in the context.
  sharedContext->addOutstandingRpc(jitFuture);

  // \'recv\' function sends the gradients over the wire using RPC, it doesn\'t
  // need to return anything for any downstream autograd function.
  return variable_list();
}

因为这里发送了 PropagateGradientsReq,所以我们接着看。

3.2 PropagateGradientsReq

3.2.1 定义

PropagateGradientsReq 扩展了 RpcCommandBase。

// Used to propagate gradients from one node to another during a distributed
// backwards pass. This RPC call is invoked when we hit a `recv` autograd
// function during backward pass execution.
class TORCH_API PropagateGradientsReq : public rpc::RpcCommandBase {
 public:
  PropagateGradientsReq(
      const AutogradMetadata& autogradMetadata,
      std::vector<torch::autograd::Variable> grads,
      bool retainGraph = false);

  const AutogradMetadata& getAutogradMetadata();

  const std::vector<torch::autograd::Variable>& getGrads();

  // Serialization and deserialization methods.
  rpc::Message toMessageImpl() && override;
  static std::unique_ptr<PropagateGradientsReq> fromMessage(
      const rpc::Message& message);

  // Whether or not to retain the autograd graph.
  bool retainGraph();

 private:
  AutogradMetadata autogradMetadata_;
  std::vector<torch::autograd::Variable> grads_;
  bool retainGraph_;
};

其 toMessageImpl 指明了本消息是 BACKWARD_AUTOGRAD_REQ。

Message PropagateGradientsReq::toMessageImpl() && {
  std::vector<at::IValue> ivalues;
  // Add all the grad tensors.
  for (const auto& grad : grads_) {
    域名ace_back(grad);
  }

  // Now add autograd metadata.
  域名ace_back(域名gradContextId);
  域名ace_back(域名gradMessageId);

  // Add retain graph.
  域名ace_back(retainGraph_);

  // Now pickle using JIT pickler.
  std::vector<torch::Tensor> tensorTable;
  std::vector<char> payload =
      jit::pickle(c10::ivalue::Tuple::create(std::move(ivalues)), &tensorTable);

  return Message(
      std::move(payload),
      std::move(tensorTable),
      MessageType::BACKWARD_AUTOGRAD_REQ); // 这里指明了消息类型。
}

3.3 接受方

为了论述完整,我们接下来看看接收方如何处理反向传播。

3.3.1 接受消息

在生成 TensorPipeAgent 时候,把 RequestCallbackImpl 配置为回调函数。这是 agent 的统一响应函数。前面关于代理接收逻辑时候,我们也提到了,会进入 RequestCallbackNoPython::processRpc 函数。其中可以看到有对 BACKWARD_AUTOGRAD_REQ 的处理逻辑。

这种是 RPC 的正常流程。

void RequestCallbackNoPython::processRpc(
    RpcCommandBase& rpc,
    const MessageType& messageType,
    const int64_t messageId,
    const c10::intrusive_ptr<JitFuture>& responseFuture,
    std::shared_ptr<LazyStreamContext> ctx) const {

  switch (messageType) {

    case MessageType::BACKWARD_AUTOGRAD_REQ: { 
      processBackwardAutogradReq(rpc, messageId, responseFuture); // 这里调用
      return;
    };

3.3.2 processBackwardAutogradReq

在 processBackwardAutogradReq 之中会:

  • 获取 DistAutogradContainer。
  • 获取 上下文。
  • 调用 executeSendFunctionAsync 进行引擎处理。

由此,我们可以看到有两个途径进入引擎:

  • 一个是示例代码显式主动调用 backward,进而调用到 DistEngine::getInstance().execute,就是 worker 0。
  • 一个是被动调用 DistEngine::getInstance().executeSendFunctionAsync,就是 worker 1。
void RequestCallbackNoPython::processBackwardAutogradReq(
    RpcCommandBase& rpc,
    const int64_t messageId,
    const c10::intrusive_ptr<JitFuture>& responseFuture) const {
  auto& gradientsCall = static_cast<PropagateGradientsReq&>(rpc);
  const auto& autogradMetadata = 域名utogradMetadata();

  // Retrieve the appropriate autograd context.
  auto autogradContext = DistAutogradContainer::getInstance().retrieveContext(
      域名gradContextId); // 得到发送者的context id

  // Lookup the appropriate \'send\' function to enqueue.
  std::shared_ptr<SendRpcBackward> sendFunction = // 依据发送者context id和消息id得到sendFunction
      autogradContext->retrieveSendFunction(域名gradMessageId);

  // Attach the gradients to the send function.
  sendFunction->setGrads(域名rads()); // 设置梯度

  // Now execute the autograd graph using the "distributed engine."
  auto execFuture = DistEngine::getInstance().executeSendFunctionAsync( // 调用引擎
      autogradContext, sendFunction, 域名inGraph());

  // Our response is satisfied when the rpcs come back.
  execFuture->addCallback([responseFuture, messageId](JitFuture& execFuture) {
    if (!域名rror()) {
      Message m = std::move(PropagateGradientsResp()).toMessage();
      域名d(messageId);
      responseFuture->markCompleted(
          IValue(c10::make_intrusive<Message>(std::move(m))));
    } else {
      responseFuture->setError(域名ption_ptr());
    }
  });
}

3.3.3 executeSendFunctionAsync

executeSendFunctionAsync 这里开始进入了引擎,注意,这里是接收方也进入了引擎,在接收方上进行计算。executeSendFunctionAsync 会直接调用 execute_graph_task_until_ready_queue_empty,也可能先计算依赖然后继续执行。此处可以参考设计之中的:

  • 6)当远程主机收到这个请求时,我们使用 autograd_context_idautograd_message_id来查找适当的send函数。
  • 7)如果这是worker第一次收到对给定 autograd_context_id的请求,它将按照上面的第 1-3 点所述在本地计算依赖项。
  • 8)然后将在第6点接受到的send方法插入队列,以便在该worker的本地 autograd 引擎上执行。

具体代码如下:

c10::intrusive_ptr<c10::ivalue::Future> DistEngine::executeSendFunctionAsync(
    const ContextPtr& autogradContext,
    const std::shared_ptr<SendRpcBackward>& sendFunction,
    bool retainGraph) {

  // Typically the local autograd engine ensures stream synchronizations between
  // nodes in the graph. However, for distributed autograd the sendFunction
  // inputs might have been retrieved over the wire on a separate stream and the
  // sendFunction itself runs on a different stream. As a result, we need to
  // manually synchronize those two streams here.
  const auto& send_backward_stream = sendFunction->stream(c10::DeviceType::CUDA);
  if (send_backward_stream) { // 拿到本次执行对应的Stream
    for (const auto& grad : sendFunction->getGrads()) {
        const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};
        const auto default_stream = 域名tream(域名ce());
        if (send_backward_stream != default_stream) {
          auto event = c10::Event{c10::DeviceType::CUDA};
          域名rd(default_stream);
          send_backward_stream->wait(event); // 需要同步,保证当前操作完成
        }
    }
  }

  std::unique_lock<std::mutex> lock(initializedContextIdsLock_);
  if (域名(autogradContext->contextId()) ==
      域名()) { // 遍历,查找sendFunction对应的上下文是否在本节点之中已经记录
    // 没有找到上下文,需要计算依赖
    edge_list outputEdges;
    // Pass in a dummy graphRoot since all send functions are the roots.
    auto dummyRoot = std::make_shared<GraphRoot>(edge_list(), variable_list());
    computeDependencies( // 计算依赖
        autogradContext, {}, {}, dummyRoot, outputEdges, retainGraph);

    // Mark the autograd context id as initialized and unlock.
    域名rt(autogradContext->contextId());
    域名ck();

    // Enqueue the current send function.
    auto graphTask = autogradContext->retrieveGraphTask();
    // Run the autograd engine.
    auto accumulateGradFuture = runEngineAndAccumulateGradients( // 计算梯度
        autogradContext,
        sendFunction,
        outputEdges,
        /*incrementOutstandingTasks=*/false);

    // Build the \'uber\' future that waits for everything.
    auto callbackFuture =
        c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
    // 注册回调
    accumulateGradFuture->addCallback([autogradContext,
                                       callbackFuture](c10::ivalue::Future& accumulateGradFuture) {
      try {
        if (域名rror()) {
          // Perform cleanup at the end of the backward pass (before we mark
          // the future as completed).
          DistEngine::getInstance().cleanupBackwardPass(autogradContext);

          // Skip any further processing on errors.
          callbackFuture->setError(域名ption_ptr());
          return;
        }

        // Wait for all RPCs after the autograd engine is done.
        auto rpcFuture = autogradContext->clearAndWaitForOutstandingRpcsAsync();
        rpcFuture->addCallback([callbackFuture, autogradContext](c10::ivalue::Future& rpcFuture) {
          try {
            // Perform cleanup at the end of the backward pass (before
            // we mark the future as completed).
            DistEngine::getInstance().cleanupBackwardPass(autogradContext);
          } catch (std::exception& e) {
            callbackFuture->setErrorIfNeeded(std::current_exception());
            return;
          }

          // Finally mark the \'uber\' future as completed.
          if (!域名rror()) {
            callbackFuture->markCompleted(c10::IValue());
          } else {
            callbackFuture->setError(域名ption_ptr());
          }
        });
      } catch (std::exception& e) {
        callbackFuture->setErrorIfNeeded(std::current_exception());
      }
    });

    // Return the future which waits for all async processing to be done.
    return callbackFuture;
  } else { // 可以在当前Node找到上下文
    域名ck();
    auto graphTask = autogradContext->retrieveGraphTask();
    at::launch([this, graphTask, sendFunction]() {
      execute_graph_task_until_ready_queue_empty(
          /*node_task*/ NodeTask(graphTask, sendFunction, InputBuffer(0)),
          /*incrementOutstandingTasks*/ false);
    });
    auto fut = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
    fut->markCompleted(c10::IValue());
    return fut;
  }
}

具体如下图:

                                                                  +
                                                         worker 0 | worker 1
                                                                  |
  Engine            RecvRpcBackward              RpcAgent         |     RequestCallbackNoPython             DistEngine
    +                    +                          +             |              +                              +
    |                    |                          |             |              |                              |
    |                    |                          |             |              |                              |
evaluate_function        |                          |             |              |                              |
    +                    |                          |             |              |                              |
    |                    |                          |             |              |                              |
    +                    |                          |             |              |                              |
  call_function          |                          |             |              |                              |
    +                    |                          |             |              |                              |
    |      grads         v                          |             |              |                              |
    +----------------> apply                        |             |              |                              |
    |                    +                          |             |              |                              |
    |                    |                          |             |              |                              |
    |                    +                          |             |              |                              |
    |                 gradCall                      |             |              |                              |
    |                    +                          |             |              |                              |
    |                    |  PropagateGradientsReq   |             |              |                              |
    |                    +------------------------> |             |              |                              |
    |                    |                          |             +              |                              |
    |                    |                          +   BACKWARD_AUTOGRAD_REQ    |                              |
    |                    |                        send  +---------+--------->    |                              |
    |                    |                          +             |              |                              |
    |                    |                          |             |              +                              |
    |                    |                          |             |     processBackwardAutogradReq              |
    |                    |                          |             |              +                              |
    |                    |                          |             |              |                              +
    |                    |                          |             |              +------------> executeSendFunctionAsync
    |                    |                          |             |              |                              +
    |                    |                          |             |              |                              |
    |                    |                          |             |              |                              |
    v                    v                          v             +              v                              v


手机如下:

0x04 DistAccumulateGradCaptureHook

目前看起来总体逻辑已经完成了,但是实际上缺了一块,对应了设计文档中的:

最后,我们不是在 Tensor的.grad之上累积梯度,而是在每个Distributed Autograd Context之上分别累积梯度 。梯度存储在Dict[Tensor, Tensor]之中 ,Dict[Tensor, Tensor]基本上是从 Tensor 到其关联梯度的映射,并且可以使用 get_gradients() API检索该映射 。

就是把异地/本地的梯度累积到本地上下文之中,所以我们再分析一下 DistAccumulateGradCaptureHook。

4.1 定义

DistAccumulateGradCaptureHook 有三个作用:

  1. 调用原始AccumulateGrad的 pre hooks 来修改输入梯度。

  2. 将 grad 累积到RPC上下文。

  3. 调用原始AccumulateGrad的 post hooks。

其定义如下:

// This hook does 3 things:
//   1. Call pre hooks of the original AccumulateGrad to modify the input grad.
//   2. Accumuate the gard to RPC context.
//   3. Call post hooks of the original AccumulateGrad.
class DistAccumulateGradCaptureHook
    : public GraphTask::ExecInfo::Capture::GradCaptureHook {
 public:
  DistAccumulateGradCaptureHook(
      std::shared_ptr<AccumulateGrad> accumulateGrad,
      ContextPtr autogradContext)
      : accumulateGrad_(std::move(accumulateGrad)),
        autogradContext_(std::move(autogradContext)) {}

  at::Tensor operator()(const at::Tensor& grad) override {
    ThreadLocalDistAutogradContext contextGuard{ContextPtr(autogradContext_)};
    variable_list inputGrads = {grad};
    // It\'s intended that pre/post hooks are still called even if the grad is
    // undenfined here.
    for (const auto& hook : accumulateGrad_->pre_hooks()) {
      inputGrads = (*hook)(inputGrads); // 调用 pre-hooks
    }

    // It is possible that the grad is not defined since a separate
    // invocation of the autograd engine on the same node might actually
    // compute this gradient.
    if (inputGrads[0].defined()) {
      // There are 3 internal references to \'inputGrads[0]\' at this moment:
      //   1. \'inputGrads[0]\' in this function.
      //   2. \'graph_task->captured_vars_\' on the callsite in the local engine.
      //   3. \'InputBuffer& inputs\' on the callsite as the inputs of the
      //   function node.
      autogradContext_->accumulateGrad( // 累积梯度
          accumulateGrad_->variable, inputGrads[0], 3 /* num_expected_refs */);
    }
    const variable_list kEmptyOuput;
    for (const auto& hook : accumulateGrad_->post_hooks()) {
      (*hook)(kEmptyOuput, inputGrads); // 调用 post-hooks
    }
    return inputGrads[0];
  }

 private:
  std::shared_ptr<AccumulateGrad> accumulateGrad_; // 这就是需要累积的目标向量,后续操作在其之上
  ContextPtr autogradContext_;
};

4.2 生成

如何生成 DistAccumulateGradCaptureHook?计算依赖时候生成 DistAccumulateGradCaptureHook,但是记录在 域名_back 之中。

这里是为了处理 AccumulateGrad。

  • AccumulateGrad 一定是叶子节点,不需执行,而需要在其上积累梯度,但是RecvRpcBackward需要执行。

  • AccumulateGrad 就保存在 DistAccumulateGradCaptureHook 之中。

void DistEngine::computeDependencies(
    const ContextPtr& autogradContext,
    const edge_list& rootEdges,
    const variable_list& grads,
    const std::shared_ptr<Node>& graphRoot,
    edge_list& outputEdges,
    bool retainGraph) {
  
  if (!域名y()) {
    // Compute \'needed execution\' starting from all \'send\' functions and the
    // original graphRoot.
    edge_list edges;
    // Create some dummy edges (input_nr not important for init_to_execute).
    for (const auto& mapEntry : sendFunctions) {
      域名ace_back(域名nd, 0);
    }

    // Add the original graphRoot as an edge.
    域名ace_back(graphRoot, 0);

    // Create a dummy GraphRoot and run init_to_execute with it.
    GraphRoot dummyRoot(edges, {});
    graphTask->init_to_execute(dummyRoot, outputEdges, /*accumulate_grad=*/false, /*min_topo_nr=*/0);
    for (auto& mapEntry : graphTask->exec_info_) {
      auto& execInfo = 域名nd;
      if (!域名ures_) {
        continue;
      }
      auto fn = 域名t;
      // There may be nodes other than \'AccumulateGrad\', e.g. RecvRPCBackward,
      // to be captured.
      if (auto accumulateGradFn = dynamic_cast<AccumulateGrad*>(fn)) {
        for (auto& capture : *域名ures_) {
          域名_back( // 这里会生成
              std::make_unique<DistAccumulateGradCaptureHook>(
                  std::dynamic_pointer_cast<AccumulateGrad>( // 会保存 AccumulateGrad
                      accumulateGradFn->shared_from_this()),
                  autogradContext));
        }
      }
    }

    // Mark all \'RecvRPCBackward\' as needing execution.
    for (const auto& recvBackwardEdge : recvBackwardEdges) {
      graphTask->exec_info_[域名()].needed_ = true;
    }
  }  
}

4.3 使用

代码是缩减版。

首先,execute_graph_task_until_ready_queue_empty 会调用到原始引擎 域名uate_function。

void DistEngine::execute_graph_task_until_ready_queue_empty(
    NodeTask&& node_task,
    bool incrementOutstandingTasks) {

  while (!cpu_ready_queue->empty()) {
    std::shared_ptr<GraphTask> local_graph_task;
    {
      NodeTask task = cpu_ready_queue->pop();

      if (域名 && !local_graph_task->域名()) {
        AutoGradMode grad_mode(local_graph_task->grad_mode_);
        GraphTaskGuard guard(local_graph_task);
        域名uate_function( // 调用原始引擎
              local_graph_task, 域名(), 域名ts_, cpu_ready_queue);
      }
    }
    // Decrement the outstanding task.
    --local_graph_task->outstanding_tasks_;
  }

}

其次,原始引擎代码之中,会调用hooks。

void Engine::evaluate_function(
    std::shared_ptr<GraphTask>& graph_task,
    Node* func,
    InputBuffer& inputs,
    const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
  // If exec_info_ is not empty, we have to instrument the execution
  auto& exec_info_ = graph_task->exec_info_;
  if (!域名y()) {
    auto& fn_info = 域名(func);
    if (auto* capture_vec = 域名()) {
      // Lock mutex for writing to graph_task->captured_vars_.
      std::lock_guard<std::mutex> lock(graph_task->mutex_);
      for (const auto& capture : *capture_vec) {
        auto& captured_grad = graph_task->captured_vars_[域名ut_idx_];
        captured_grad = inputs[域名t_idx_];
        for (auto& hook : 域名s_) {
          captured_grad = (*hook)(captured_grad); // 这里调用 hook,就是 DistAccumulateGradCaptureHook 的 operator(),captured_grad 就是累积的梯度
        }
      }
    }
  }
  
  // 后续省略

DistAccumulateGradCaptureHook 的 operator() 方法之中,会调用下面来累积梯度。

  autogradContext_->accumulateGrad(
      accumulateGrad_->variable, inputGrads[0], 3 /* num_expected_refs */);

4.4 累积梯度

4.4.1 上下文累积

void DistAutogradContext::accumulateGrad(
    const torch::autograd::Variable& variable, // variable就是目标变量
    const torch::Tensor& grad, // grad就是梯度,需要累积到variable之上
    size_t num_expected_refs) {

  std::lock_guard<std::mutex> guard(lock_);
  auto it = 域名(variable);
  at::Tensor old_grad;
  if (it != 域名()) {
    // Accumulate multiple grads on the same variable.
    old_grad = it->value();
  }

  // Gradients are computed using the forward streams. Local autograd
  // engine uses AccumulateGrad function to retrieve and apply forward
  // stream during the backward computation. In distributed autograd,
  // we directly call AccumulateGrad::accumulateGrad, and skip the
  // CUDA stream restoration from autograd function. Hence, we manually
  // call it here to get the streams correct.
  auto forward_stream =
      torch::autograd::impl::grad_accumulator(variable)->stream(
          域名ce().type());
  c10::OptionalStreamGuard stream_guard(forward_stream);

  // No higher order gradients supported in distributed autograd.
  AutoGradMode grad_mode(false);
  at::Tensor new_grad = AccumulateGrad::callHooks(variable, grad); // 计算

  AccumulateGrad::accumulateGrad( // 调用算子函数来累积梯度
      variable,
      old_grad,
      new_grad,
      // Add +1 here since we can\'t std::move(grad) when call
      // AccumulateGrad::callHooks, since it is a const ref, and that incurs a
      // refcount bump for the new_grad.
      num_expected_refs + 1,
      [this, &variable](at::Tensor&& grad_update) {
        auto device = 域名ce();
        域名rt(variable, std::move(grad_update));
        recordGradEvent(device);
      });
}

4.4.2 算子累积

代码位于 torch/csrc/autograd/functions/accumulate_grad.h。AccumulateGrad 的定义如下:

struct TORCH_API AccumulateGrad : public Node {
  explicit AccumulateGrad(Variable variable_);

  variable_list apply(variable_list&& grads) override;

  static at::Tensor callHooks(
      const Variable& variable,
      at::Tensor new_grad) {
    for (auto& hook : impl::hooks(variable)) {
      new_grad = (*hook)({new_grad})[0];
    }
    return new_grad;
  }

  // Given a variable with its current grad as variable_grad, accumulates
  // new_grad into variable_grad if in place accumulation is possible.
  // Otherwise, uses \'update_grad\' to update the grad for the variable.

  // "Gradient Layout Contract"
  //
  // AccumulateGrad tries to stash strided (non-sparse) grads with memory layout
  // (strides) such that variables and grads interact efficiently in later
  // optimizer kernels, and grads interact efficiently with c10d::域名.
  //
  // Specifically, AccumulateGrad tries to ensure the following
  // (cf torch/csrc/autograd/utils/grad_layout_contract.h):
  //   (1) if 域名on_overlapping_and_dense(), the stashed grad\'s
  //       strides match variable.
  //   (2) else, stashed grad is rowmajor contiguous.
  // If variable\'s grad does not exist (!域名ned())
  // AccumulateGrad steals new_grad if it\'s stealable and obeys the contract
  // already, otherwise it deep copies new_grad into an obedient clone.
  //
  // If variable\'s grad already exists (域名ned()), new_grad must
  // be added to variable_grad.  If we aren\'t setting up for double backward
  // (!GradMode::is_enabled()), AccumulateGrad performs "variable_grad += new_grad"
  // in-place, which keeps variable_grad\'s layout. We assume (hope) variable_grad
  // was created obeying (1) or (2) at some point in the past.
  //
  // If we are setting up for double backward, AccumulateGrad updates the grad
  // out-of-place via "variable_grad + new_grad."  TensorIterator operator+ decides
  // result\'s layout.  Typically TensorIterator matches strides of the first arg,
  // so we once again assume (hope) variable_grad was originally created obeying
  // (1) or (2).
  //
  // AccumulateGrad does not enforce the contract with 100% certainty.  Examples:
  //  - If a user manually permutes a param or its grad, then runs a fwd+bwd,
  //    variable_grad += new_grad keeps variable_grad\'s layout without rechecking
  //    the contract.
  //  - If TensorIterator changes its corner cases about operator+\'s result
  //    (for example, giving more or less priority to channels_last inputs, see
  //    https://域名/pytorch/pytorch/pull/37968) the result may not obey.
  //
  // Fortunately, if a given grad doesn\'t satisfy (1) or (2), the penalty is
  // degraded performance in 域名 or optimizer kernels, not death by
  // assert or silently bad numerics.

  // variable: the variable whose grad we\'re accumulating.
  // variable_grad: the current grad for the variable.
  // new_grad: new grad we want to acummulate for the variable.
  // num_expected_refs: the number of refs we expect to hold internally
  //                    such that it is safe to avoid cloning the grad
  //                    if use_count() of the grad is less than or equal
  //                    to this value (in addition to post_hooks).
  // update_grad: Function that is used to update grad for the variable.
  //              The argument to the function is a Tensor which
  //              is used to set a new value for the grad.
  template <typename T>
  static void accumulateGrad( // 这里会进行具体的累积梯度
      const Variable& variable,
      at::Tensor& variable_grad,
      const at::Tensor& new_grad,
      size_t num_expected_refs,
      const T& update_grad) {
    if (!域名ned()) {
      if (!GradMode::is_enabled() &&
          !域名parse() &&
          域名count() <= num_expected_refs &&
          (域名kldnn() || utils::obeys_layout_contract(new_grad, variable))) {
        // we aren\'t setting up for double-backward
        // not sparse
        // no other user-visible tensor references new_grad
        // new_grad obeys the "Gradient Layout Contract", there has a special case,
        // For MKLDNN tensor, which is a opaque tensor, assuming it obeys layout_contract.
        // Under these conditions, we can steal new_grad without a deep copy.
        update_grad(域名ch());
      } else if (
          !GradMode::is_enabled() && 域名parse() &&
          域名ices().is_contiguous() &&
          域名ues().is_contiguous() &&
          // Use count for indices and values should always be <=1 since the
          // SparseTensor should be the only one holding a reference to these.
          域名ices().use_count() <= 1 &&
          域名ues().use_count() <= 1 &&
          域名count() <= num_expected_refs) {
        // Can\'t detach sparse tensor (since metadata changes are not allowed
        // after detach), so just create a new one for the grad which is a
        // shallow copy. We need a shallow copy so that modifying the original
        // grad tensor doesn\'t modify the grad we accumulate.
        // We only skip clone if indices and values themselves are contiguous
        // for backward compatiblity reasons. Since without this optimization,
        // earlier we would clone the entire SparseTensor which cloned indices
        // and values.
        // For details see https://域名/pytorch/pytorch/issues/34375.
        update_grad(at::_sparse_coo_tensor_unsafe(
            域名ices(),
            域名ues(),
            域名s(),
            域名                
标签:编程
湘ICP备14001474号-3  投诉建议:234161800@qq.com   部分内容来源于网络,如有侵权,请联系删除。