
OneFlow: From Job to Plan

Preface

The previous post analyzed how individual Ops are assembled into a Job; this one analyzes how those Jobs are turned into a Plan.

Plan

Let's start with what we are aiming for: a Plan that can actually be executed physically. OneFlow's execution is built on the Actor mechanism, where every node of the computation graph is executed by an Actor. So how does the Plan abstract over the Actor mechanism? An Actor consists of computation and storage: the computation is determined by the operator and its kernel, while the storage locations are described by Regsts. The Plan, sitting one level above the Actor mechanism, therefore has to capture exactly these two aspects.

Let's look at the Plan data structure.

  • The first field is a repeated TaskProto. The relationship between Plan and TaskProto is like that between a computation graph and its nodes, so the concrete configuration of each Actor lives in a TaskProto.
  • JobConfs is really job_id2job_conf, a map from job id to the concrete Job configuration.
  • The remaining fields relate to memory management and SBP.
message Plan {
  repeated TaskProto task = 1;
  required MemBlockAndChunkList block_chunk_list = 2;
  required JobConfs job_confs = 4;
  required CollectiveBoxingPlan collective_boxing_plan = 5;
  required CtrlRegstDescInfo ctrl_regst_desc_info = 6;
  map<int64, OpAttributeRefTable> job_id2op_attribute_ref_table = 7;
}

message JobConfs {
  map<int64, JobConfigProto> job_id2job_conf = 1;
}
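
To make the graph/node analogy concrete, here is a small sketch (not part of OneFlow) that walks a merged Plan and counts how many TaskProtos each Job contributed. It uses only the accessors that protoc generates for the messages above; the include path and the oneflow namespace are assumptions.

#include <iostream>
#include <map>

#include "oneflow/core/job/plan.pb.h"  // assumed path of the header generated from plan.proto

// Sketch: group the tasks of a merged Plan by job_id and report a count per Job.
void SummarizePlan(const oneflow::Plan& plan) {
  std::map<int64_t, int64_t> job_id2task_num;
  for (const oneflow::TaskProto& task : plan.task()) { ++job_id2task_num[task.job_id()]; }
  for (const auto& pair : plan.job_confs().job_id2job_conf()) {
    std::cout << "job_id " << pair.first << ": " << job_id2task_num[pair.first] << " tasks\n";
  }
}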

TaskProto

We saw above that TaskProto is a repeated field of Plan, playing the same role as a node in the computation graph, so again the focus is on computation and storage. Note that the connections between Actors are not declared explicitly; they are hidden in the Actor mechanism itself: an Actor runs its computation as soon as its inputs and outputs are ready. There is therefore no need to spell out edges between Actors; they are linked implicitly through their inputs and outputs.

message TaskProto {
  // common
  required TaskType task_type = 1;
  required int64 machine_id = 2;
  required int64 thrd_id = 3;
  required int64 task_id = 4;
  required int64 job_id = 5;
  required TaskSetInfo task_set_info = 6;
  required ExecSequence exec_sequence = 7;
  map<string, RegstDescProto> produced_regst_desc = 8;
  map<string, RegstDescIdSet> consumed_regst_desc_id = 9;
  optional bool all_register_num_eq_one_hint = 10 [default = false];
  // compute task
  optional ParallelContext parallel_ctx = 1000; // CompTask
};
  • Computation: the relevant field is ExecSequence. Following the nesting ExecSequence -> ExecNodeProto -> KernelConf -> OpAttribute -> OperatorConf, we arrive at the Op that this Actor executes (a lookup sketch follows the proto below).
  • Storage: TaskProto carries produced_regst_desc and consumed_regst_desc_id, which describe the Regsts this Actor produces and consumes, i.e. its outputs and inputs. ExecNodeProto additionally holds a map from a blob name inside the op to the corresponding regst desc id.
message ExecNodeProto {
  required KernelConf kernel_conf = 1;
  map<string, int64> bn_in_op2regst_desc_id = 2;
}

message ExecSequence {
  repeated ExecNodeProto exec_node = 1;
}
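
As a concrete sketch of that chain, the snippet below walks every task in a Plan and prints the name of the Op each exec node runs. The lookup via PlanUtil::GetOpAttribute is taken verbatim from how LinkMainPlan uses it later in this post; the include paths are assumptions.

#include <iostream>

#include "oneflow/core/job/plan.pb.h"    // assumed generated-header path
#include "oneflow/core/job/plan_util.h"  // assumed header declaring PlanUtil

// Sketch: ExecSequence -> ExecNodeProto -> KernelConf -> OpAttribute -> OperatorConf.
void PrintOpNamePerTask(oneflow::Plan* plan) {
  for (const oneflow::TaskProto& task : plan->task()) {
    for (const oneflow::ExecNodeProto& node : task.exec_sequence().exec_node()) {
      // Same lookup that LinkMainPlan performs when it needs the OperatorConf of a task.
      const auto& op_conf =
          oneflow::PlanUtil::GetOpAttribute(plan, task.job_id(), node.kernel_conf()).op_conf();
      std::cout << "task " << task.task_id() << " runs op " << op_conf.name() << "\n";
    }
  }
}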

One place worth noting: OperatorConf has an op_type field, which is a oneof, and one of its alternatives is UserOpConf, the configuration of a user-defined op: its name, inputs, outputs, and attributes.

message UserOpConf {
  message ListString {
    repeated string s = 1;
  }
  required string op_type_name = 1;
  map<string, ListString> input = 2;
  map<string, ListString> output = 3;
  map<string, AttrValue> attr = 4;
}
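
A hedged sketch of reading such a user op back out of an OperatorConf: it assumes the oneof member holding the UserOpConf is named user_conf (so protoc generates has_user_conf()/user_conf()), and the include path is an assumption as well.

#include <iostream>

#include "oneflow/core/operator/op_conf.pb.h"  // assumed generated-header path for OperatorConf

// Sketch: print the op type name and the input logical blob names of a user-defined op.
void PrintUserOp(const oneflow::OperatorConf& op_conf) {
  if (!op_conf.has_user_conf()) { return; }  // not a user-defined op
  const oneflow::UserOpConf& user_conf = op_conf.user_conf();
  std::cout << op_conf.name() << " has op_type_name " << user_conf.op_type_name() << "\n";
  for (const auto& pair : user_conf.input()) {
    for (const std::string& lbn : pair.second.s()) {
      std::cout << "  input " << pair.first << " <- " << lbn << "\n";
    }
  }
}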

Overall flow

The overall flow was already covered in the previous post on starting the Session, without going into details; here is a brief recap.

Call flow

  • StartLazyGlobalSession is called from Python
  • JUST(Global<Oneflow>::Get()->Init(job_set)) kicks off the initialization of the global Oneflow object
  • CompileJobsAndPushMergedPlan compiles the Jobs and pushes the merged Plan
  • CompileJobsAndMergePlans compiles the individual Jobs into one Plan

The main work of CompileJobsAndMergePlans is as follows:

  • Add the Model IO Jobs
  • Add the Push Jobs and Pull Jobs
  • CompileCurJobOnMaster compiles each Job in turn; MergeSubPlan merges the resulting sub-plans
  • Memory reuse and memory sharing across Jobs
  • FinishGlobalCriticalSectionDesc partitions the critical sections
  • Generation, compilation and linking of the MainJob

The rest of this post focuses on compiling a single Job and on the generation, compilation and linking of the MainJob.

Compiling a single Job

  • CompileCurJobOnMaster does exactly what its name says: it compiles the current Job on the master. Non-master nodes do not compile anything; they just wait for the master to push the Plan over.
// oneflow/core/job/oneflow.cpp: 203
Maybe<void> CompileCurJobOnMaster(Job* job, Plan* plan, bool need_job_complete) {
  const JobDesc& job_desc = GlobalJobDesc();
  if (GlobalProcessCtx::IsThisProcessMaster()) {
    double start = GetCurTime();
    Compiler().Compile(job, plan, need_job_complete);
    PlanUtil::GenMemBlockAndChunk4Plan(plan);

    LOG(INFO) << "\njob_id: " << job_desc.job_id() << " , job_name: " << job_desc.job_name()
              << " , compile time: " << (GetCurTime() - start) / 1000000000.0 << " seconds.\n";
    if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
      TeePersistentLogStream::Create(StrCat("subplan_job_", job_desc.job_id()))->Write(*plan);
    }
  }
  PlanUtil::GenCollectiveBoxingPlan(job, plan);
  PlanUtil::GenRegisterHint(plan);
  return Maybe<void>::Ok();
}

Inside the Compile method, the comments show that compilation proceeds in five steps:

  • 確定 job 是 completed。
  • 建立全局 OpGraph 對象
  • 建構 TaskGraph,一開始困惑于為什麼不需要傳參,原來 TaskGraph 在構造函數裡面,通過擷取全局的 OpGraph 對象來實作初始化!
  • 将 TaskGraph 放入 Plan
  • plan 後處理和清理全局對象 OpGraph
void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const {
  // Step1: ensure job is completed.
  if (need_job_complete) { CHECK_JUST(JobCompleter().Complete(job)); }

  // Step2: new Global<OpGraph> and set log configs.
  Global<OpGraph>::New(*job);
  const JobDesc& job_desc = GlobalJobDesc();
  if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()
      || Global<ResourceDesc, ForSession>::Get()->enable_dry_run()) {
    TeePersistentLogStream::Create(StrCat("optimized_job", job_desc.job_id()))->Write(*job);
    Global<OpGraph>::Get()->ToDotWithFilePath("optimized_dlnet_" + std::to_string(job_desc.job_id())
                                              + "_op_graph.dot");
  }

  // Step3: build task_gph.
  // TODO(levi): we can rewrite this part of code in visitor pattern.
  auto task_gph = std::make_unique<TaskGraph>();
  using std::placeholders::_1;
  task_gph->ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, _1));
  task_gph->ForEachNode(std::bind(&TaskNode::ConsumeAllRegsts, _1));
  task_gph->ForEachNode(std::bind(&TaskNode::PinConsumedRegst, _1));
  task_gph->TopoForEachNode(&TaskNode::Build);
  task_gph->RemoveEmptyRegsts();
  task_gph->MergeChainAndAddOrderingCtrlEdgeInSameChain();
  auto IsReachable = Global<OpGraph>::Get()->MakePredicatorIsOpNameDataOrCtrlReachable();
  if (job_desc.enable_inplace()) { task_gph->EnableInplaceMemSharing(IsReachable); }
  task_gph->TopoForEachNode(&TaskNode::InferTimeShapeIfMeaningful);
  task_gph->ForEachEdge([&](TaskEdge* task_edge) { task_edge->CheckRegstLbiValid(); });

  // Step4: put infomation from task_gph into plan.
  const int64_t node_num = task_gph->node_num();
  const int64_t cpu_num = std::thread::hardware_concurrency();
  const int64_t thread_pool_size = std::min(node_num, cpu_num);
  BlockingCounter counter(node_num);
  std::mutex mtx;
  ThreadPool thread_pool(thread_pool_size);
  task_gph->ForEachNode([&](TaskNode* task_node) {
    thread_pool.AddWork([task_node, plan, &job_desc, &counter, &mtx]() {
      if (!task_node->IsMeaningLess()) {
        TaskProto task_proto;
        task_node->ToProto(&task_proto);
        {
          std::unique_lock<std::mutex> guard(mtx);
          if (task_node->GetTaskType() == kNormalForward || task_node->GetTaskType() == kRepeat
              || task_node->GetTaskType() == kAcc) {
            CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto);
          }
          plan->mutable_task()->Add(std::move(task_proto));
        }  // guard(mtx)
      }
      counter.Decrease();
    } /* thread_pool.AddWork */);
  } /* task_gph->ForEachNode */);
  counter.WaitUntilCntEqualZero();
  // NOTE(levi): release task_gph here to decrise memory peak.
  task_gph.reset();

  // Step5: post-process for plan and delete Global<OpGraph>.
  auto* job_id2job_conf = plan->mutable_job_confs()->mutable_job_id2job_conf();
  (*job_id2job_conf)[GlobalJobDesc().job_id()] = GlobalJobDesc().job_conf();
  // NOTE(chengcheng): infer mem blob id & set inplace & add ctrl
  IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(plan, IsReachable);
  PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(plan);
  Global<OpGraph>::Delete();
}

MainJob

Image source: https://zhuanlan.zhihu.com/p/337851255

What is the MainJob for?

  • It is in charge of talking to Python: Python sends a Job ID over, and the MainJob launches the corresponding Job. WaitAndSendIds waits for Python to send an id and forwards it to the ReentrantLock. The reentrant lock does the concurrency control: if the Job behind the id does not conflict with the Jobs currently running, the id is passed on to Case, which dispatches it and the Job starts; if it does conflict, the id waits inside the ReentrantLock. When a Job finishes, its id flows back to the ReentrantLock through the Esac node, and the ReentrantLock then releases whichever waiting Jobs have become runnable (see the sketch below).
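
Below is a purely conceptual sketch of that dispatch behaviour, not OneFlow's actual reentrant_lock kernel: the class name and its Enqueue/Finish methods are made up for illustration, and id2conflicts_ stands in for the lock_id2intersecting_lock_ids table that MakeMainJobComponent dumps into reentrant_lock_conf further below.

#include <cstdint>
#include <deque>
#include <map>
#include <set>
#include <utility>
#include <vector>

// Conceptual sketch only: ids arrive from WaitAndSendIds, released ids go on to Case,
// and ids of finished Jobs come back through Esac.  The real semantics live in the
// reentrant_lock kernel; this just mirrors the behaviour described in the text.
class ReentrantLockSketch {
 public:
  explicit ReentrantLockSketch(std::map<int64_t, std::set<int64_t>> id2conflicts)
      : id2conflicts_(std::move(id2conflicts)) {}

  // A new id arrives from WaitAndSendIds; returns every id that may start running now.
  std::vector<int64_t> Enqueue(int64_t id) {
    waiting_.push_back(id);
    return TryRelease();
  }

  // An id comes back through Esac after its Job finished.
  std::vector<int64_t> Finish(int64_t id) {
    auto it = running_.find(id);
    if (it != running_.end()) { running_.erase(it); }
    return TryRelease();
  }

 private:
  bool ConflictsWithRunning(int64_t id) const {
    auto it = id2conflicts_.find(id);
    if (it == id2conflicts_.end()) { return false; }
    for (int64_t r : running_) {
      if (it->second.count(r) > 0) { return true; }
    }
    return false;
  }

  std::vector<int64_t> TryRelease() {
    std::vector<int64_t> released;
    std::deque<int64_t> still_waiting;
    for (int64_t id : waiting_) {
      if (!ConflictsWithRunning(id)) {
        running_.insert(id);
        released.push_back(id);  // would be forwarded to the Case op
      } else {
        still_waiting.push_back(id);
      }
    }
    waiting_.swap(still_waiting);
    return released;
  }

  std::map<int64_t, std::set<int64_t>> id2conflicts_;
  std::multiset<int64_t> running_;
  std::deque<int64_t> waiting_;
};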

Generation, compilation, linking

Where does the MainJob come from? There are three main steps: generation, compilation, and linking.

Generation

  • Set up the ParallelConf that every Op added afterwards will use
  • ReentrantLock Op
  • Case Op
  • For each critical section: a source tick
  • For each critical section, on every machine: an identity tick, a callback sink tick, and a sink tick
  • Esac Op
// oneflow/core/job/oneflow.cpp: 457
Maybe<ReentrantLockBackEdge> MakeMainJobComponent(
    const std::string& wait_and_send_ids_lbn, const Range& machine_id_range,
    JobBuilder* job_builder, std::vector<std::map<int64_t, std::string>>* identity_tick_op_names,
    std::vector<std::map<int64_t, std::string>>* cb_sink_tick_op_names) {
  ParallelConf parallel_conf;
  parallel_conf.set_device_tag("cpu");
  parallel_conf.add_device_name(std::string("@") + std::to_string(machine_id_range.begin()) + ":0");
  auto lock_back_edge = std::make_shared<ReentrantLockBackEdge>();
  OperatorConf reentrant_lock_op_conf;
  {
    lock_back_edge->reentrant_lock_op_name =
        std::string("System-Main-ReentrantLock_") + NewUniqueId();
    reentrant_lock_op_conf.set_name(lock_back_edge->reentrant_lock_op_name);
    auto* reentrant_lock_conf = reentrant_lock_op_conf.mutable_reentrant_lock_conf();
    reentrant_lock_conf->set_start(wait_and_send_ids_lbn);
    // ibn "end" is set after plan generated because we don't like cycle in job
    reentrant_lock_conf->set_out("out");
    Global<CriticalSectionDesc>::Get()->DumpCriticalSectionId2IntersectinIds(
        reentrant_lock_conf->mutable_lock_id2intersecting_lock_ids());
    JUST(job_builder->AddOp(parallel_conf, reentrant_lock_op_conf));
  }
  // critical section case op conf
  OperatorConf cs_case_op_conf;
  {
    cs_case_op_conf.set_name(std::string("System-Main-Case_") + NewUniqueId());
    auto* cs_case_conf = cs_case_op_conf.mutable_case_conf();
    cs_case_conf->set_in(reentrant_lock_op_conf.name() + "/out");
    FOR_RANGE(int64_t, i, 0, Global<CriticalSectionDesc>::Get()->CriticalSectionNum()) {
      cs_case_conf->add_out(GenRepeatedBn("out", i));
    }
    JUST(job_builder->AddOp(parallel_conf, cs_case_op_conf));
  }
  const int64_t num_critial_sections = Global<CriticalSectionDesc>::Get()->CriticalSectionNum();
  std::vector<std::string> snk_tick_op_names;
  FOR_RANGE(int64_t, i, 0, num_critial_sections) {
    // source tick
    OperatorConf src_tick_op_conf;
    {
      std::string name_prefix = "System-Main-SourceTick_CriticalSection_";
      src_tick_op_conf.set_name(name_prefix + std::to_string(i) + "_" + NewUniqueId());
      auto* src_tick_conf = src_tick_op_conf.mutable_tick_conf();
      src_tick_conf->add_tick(cs_case_op_conf.name() + "/" + GenRepeatedBn("out", i));
      src_tick_conf->set_out("out");
      JUST(job_builder->AddOp(parallel_conf, src_tick_op_conf));
    }

    auto* cur_cb_sink_tick_op_names = &cb_sink_tick_op_names->at(i);
    for (int64_t machine_id = machine_id_range.begin(); machine_id < machine_id_range.end();
         ++machine_id) {
      // identity tick
      OperatorConf identity_tick_op_conf;
      {
        std::string name_prefix = "System-Main-Tick_CriticalSection_";
        identity_tick_op_conf.set_name(name_prefix + std::to_string(i) + "_" + NewUniqueId());
        auto* identity_tick_conf = identity_tick_op_conf.mutable_tick_conf();
        identity_tick_conf->add_tick(src_tick_op_conf.name() + "/out");
        identity_tick_conf->set_out("out");
        JUST(job_builder->AddOp(parallel_conf, identity_tick_op_conf));
        auto* cur_id_tick_op_names = &identity_tick_op_names->at(i);
        CHECK_OR_RETURN(
            cur_id_tick_op_names->emplace(machine_id, identity_tick_op_conf.name()).second);
      }
      // callback
      {
        OperatorConf cb_sink_tick_op_conf;
        std::string name_prefix = "System-Main-CallbackSinkTick_";
        cb_sink_tick_op_conf.set_name(name_prefix + std::to_string(i) + NewUniqueId());
        auto* cb_sink_tick_conf = cb_sink_tick_op_conf.mutable_sink_tick_conf();
        cb_sink_tick_conf->add_tick(identity_tick_op_conf.name() + "/out");
        cb_sink_tick_conf->set_out("out");
        JUST(job_builder->AddOp(parallel_conf, cb_sink_tick_op_conf));
        CHECK_OR_RETURN(
            cur_cb_sink_tick_op_names->emplace(machine_id, cb_sink_tick_op_conf.name()).second);
      }
      // sink tick
      {
        OperatorConf snk_tick_op_conf;
        std::string name_prefix = "System-Main-SinkTick_CriticalSection_";
        snk_tick_op_conf.set_name(name_prefix + std::to_string(i) + NewUniqueId());
        auto* snk_tick_conf = snk_tick_op_conf.mutable_sink_tick_conf();
        snk_tick_conf->add_tick(identity_tick_op_conf.name() + "/out");
        snk_tick_conf->set_out("out");
        JUST(job_builder->AddOp(parallel_conf, snk_tick_op_conf));
        snk_tick_op_names.push_back(snk_tick_op_conf.name());
      }
    }
  }
  // critical section esac op conf
  OperatorConf cs_esac_op_conf;
  {
    cs_esac_op_conf.set_name(std::string("System-Main-Esac_") + NewUniqueId());
    // cs_esac_op_conf.set_pass_tag("main");
    auto* cs_esac_conf = cs_esac_op_conf.mutable_esac_conf();
    for (const auto& snk_tick_op_name : snk_tick_op_names) {
      cs_esac_conf->add_in(snk_tick_op_name + "/out");
    }
    cs_esac_conf->set_out("out");
    cs_esac_conf->set_data_type(DataType::kInt32);
    JUST(job_builder->AddOp(parallel_conf, cs_esac_op_conf));
  }
  lock_back_edge->critical_section_sink_lbi.set_op_name(cs_esac_op_conf.name());
  lock_back_edge->critical_section_sink_lbi.set_blob_name("out");
  return lock_back_edge;
}

Compilation

Compiling the MainJob reuses CompileCurJobOnMaster (with need_job_complete set to false) and then, for every back edge recorded during generation, connects the critical-section sink back to the reentrant lock's "end" input:

// oneflow/core/job/oneflow.cpp: 732
Maybe<void> CompileMainJob(Job* main_job, const std::vector<ReentrantLockBackEdge>& lock_back_edges,
                           int64_t job_id, Plan* main_plan) {
  CHECK_OR_RETURN(GlobalProcessCtx::IsThisProcessMaster());
  {
    auto scope = std::make_unique<GlobalJobDescScope>(main_job->job_conf(), job_id);
    JUST(CompileCurJobOnMaster(main_job, main_plan, false));
  }
  for (const auto& lock_back_edge : lock_back_edges) {
    JUST(ConnectCriticalSectionEndToReentrantLockEnd(main_plan, lock_back_edge));
  }
  return Maybe<void>::Ok();
}

Linking

LinkMainPlan merges main_plan into the overall plan, then, for every critical section and every machine, links the corresponding identity tick in the main plan to that critical section's source tick and sink tick tasks, and finally erases the source tick task protos, which are no longer needed:

// oneflow/core/job/oneflow.cpp: 306
void LinkMainPlan(Plan* plan, Plan&& main_plan,
                  const std::vector<std::map<int64_t, std::string>>& identity_tick_op_names) {
  std::function<bool(const TaskProto*)> IsInterfaceTickTockTask;
  {
    auto task_ids = std::make_shared<HashSet<int64_t>>();
    for (const auto& task : main_plan.task()) {
      if (task.task_type() == TaskType::kTick) { CHECK(task_ids->emplace(task.task_id()).second); }
    }
    IsInterfaceTickTockTask = [task_ids, plan](const TaskProto* task) {
      if (task_ids->find(task->task_id()) != task_ids->end()) { return true; }
      if (task->exec_sequence().exec_node_size() != 1) { return false; }
      const auto& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf();
      OperatorConf::OpTypeCase op_type_case =
          PlanUtil::GetOpAttribute(plan, task->job_id(), kernel_conf).op_conf().op_type_case();
      return op_type_case == OperatorConf::kSourceTickConf
             || op_type_case == OperatorConf::kSinkTickConf;
    };
  }
  MergePlan(plan, std::move(main_plan));
  HashMap<std::string, TaskProto*> sole_tick_op_name2sole_task;
  FOR_RANGE(int64_t, i, 0, plan->task_size()) {
    TaskProto* task = plan->mutable_task(i);
    if (IsInterfaceTickTockTask(task) == false) { continue; }
    const auto& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf();
    const auto& op_name =
        PlanUtil::GetOpAttribute(plan, task->job_id(), kernel_conf).op_conf().name();
    CHECK(sole_tick_op_name2sole_task.emplace(op_name, task).second);
  }
  auto TaskProto4TaskId = PlanUtil::MakeGetterTaskProto4TaskId(*plan);
  const auto& process_ranks = Global<ResourceDesc, ForSession>::Get()->process_ranks();
  FOR_RANGE(int32_t, i, 0, Global<CriticalSectionDesc>::Get()->CriticalSectionNum()) {
    const CriticalSection& cs = Global<CriticalSectionDesc>::Get()->GetCriticalSection(i);
    for (int64_t machine_id : process_ranks) {
      TaskProto* identity_tick =
          sole_tick_op_name2sole_task.at(identity_tick_op_names.at(i).at(machine_id));
      LinkTickTaskProto(
          plan, identity_tick,
          sole_tick_op_name2sole_task.at(cs.machine_id2source_tick_op_name().at(machine_id)),
          sole_tick_op_name2sole_task.at(cs.machine_id2sink_tick_op_name().at(machine_id)));
    }
  }
  {
    // erase source_tick task_proto
    HashSet<std::string> source_tick_op_names;
    FOR_RANGE(int32_t, i, 0, Global<CriticalSectionDesc>::Get()->CriticalSectionNum()) {
      const CriticalSection& cs = Global<CriticalSectionDesc>::Get()->GetCriticalSection(i);
      for (int64_t machine_id : process_ranks) {
        const auto& src_tick_op_name = cs.machine_id2source_tick_op_name().at(machine_id);
        CHECK(source_tick_op_names.emplace(src_tick_op_name).second);
      }
    }
    Erase<PbRpf<TaskProto>>(*plan->mutable_task(), [&](const TaskProto& task) {
      if (task.task_type() == TaskType::kSourceTick) {
        CHECK(task.exec_sequence().exec_node_size() == 1);
        const auto& kernel_conf = task.exec_sequence().exec_node(0).kernel_conf();
        const auto& op_conf = PlanUtil::GetOpAttribute(plan, task.job_id(), kernel_conf).op_conf();
        CHECK(op_conf.has_source_tick_conf());
        CHECK(source_tick_op_names.find(op_conf.name()) != source_tick_op_names.end());
        return true;
      } else {
        return false;
      }
    });
  }
}