/**
 *
 * Copyright (c) 2017 ytk-learn https://github.com/yuantiku
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

package com.fenbi.ytklearn.worker;

import com.fenbi.ytklearn.dataflow.*;
import com.fenbi.ytklearn.fs.FileSystemFactory;
import com.fenbi.ytklearn.fs.IFileSystem;
import com.fenbi.ytklearn.optimizer.OptimizerFactory;
import com.fenbi.ytklearn.operation.ITrainOperation;
import com.fenbi.ytklearn.operation.TrainOperationFactory;
import com.fenbi.mp4j.exception.Mp4jException;
import com.fenbi.mp4j.comm.ThreadCommSlave;
import com.fenbi.ytklearn.utils.CheckUtils;
import com.typesafe.config.Config;
import com.typesafe.config.ConfigFactory;
import com.typesafe.config.ConfigValue;
import com.typesafe.config.ConfigValueFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.Serializable;
import java.net.URI;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/**
 * Runs the training of one model inside the current process: it parses the
 * configuration file, applies custom parameter overrides, builds the data flow
 * and then executes the train operation on {@code threadNum} threads that
 * share a single {@link ThreadCommSlave}.
 *
 * @author xialong
 */
public class TrainWorker implements Serializable {
    public static final Logger LOG = LoggerFactory.getLogger(TrainWorker.class);

    protected String modelName;
    protected String configPath;
    protected String configFile;
    protected String pyTransformScript;
    protected boolean needPyTransform;
    protected String loginName;
    protected String hostName;
    protected int hostPort;
    protected int threadNum;

    /** Overrides applied on top of the parsed configuration, keyed by config path. */
    final Map<String, Object> customParamsMap = new HashMap<>();

    /**
     * Stores the given settings and logs them.
     *
     * @param modelName         model name handed to the data flow, train operation and optimizer factories
     * @param configPath        configuration path used by {@link #getTrainDataPath()} and {@link #getURI()}
     *                          when {@code configFile} does not exist
     * @param configFile        configuration file parsed by {@link #train(List, List)}
     * @param pyTransformScript python transform script handed to the data flow factory
     * @param needPyTransform   python transform flag handed to the data flow factory
     * @param loginName         login name handed to the {@link ThreadCommSlave}
     * @param hostName          host name handed to the {@link ThreadCommSlave}
     * @param hostPort          host port handed to the {@link ThreadCommSlave}
     * @param threadNum         number of training threads to start
     * @throws Exception declared for compatibility with existing callers; nothing in this constructor throws it
     */
    public TrainWorker(
            String modelName,
            String configPath,
            String configFile,
            String pyTransformScript,
            boolean needPyTransform,
            String loginName,
            String hostName,
            int hostPort,
            int threadNum
    ) throws Exception {
        this.modelName = modelName;
        this.configPath = configPath;
        this.configFile = configFile;
        this.pyTransformScript = pyTransformScript;
        this.needPyTransform = needPyTransform;
        this.loginName = loginName;
        this.hostName = hostName;
        this.hostPort = hostPort;
        this.threadNum = threadNum;

        LOG.info("configFile:{}", configFile);
        LOG.info("configPath:{}", configPath);
        LOG.info("pyTransformScript:{}", pyTransformScript);
        LOG.info("loginName:{}", loginName);
        LOG.info("hostName:{}, hostPort:{}", hostName, hostPort);
        LOG.info("threadNum:{}", threadNum);
        LOG.info("modelName:{}", modelName);
    }

    /**
     * Returns the {@code data.train.data_path} setting of the configuration,
     * after custom parameter overrides have been applied.
     *
     * @return the configured train data path
     */
    public String getTrainDataPath() {
        return loadConfigWithFallback().getString("data.train.data_path");
    }

    /**
     * Returns the {@code fs_scheme} setting of the configuration, after custom
     * parameter overrides have been applied.
     *
     * @return the configured file system scheme
     */
    public String getURI() {
        return loadConfigWithFallback().getString("fs_scheme");
    }

    /**
     * Parses {@code configFile} if it exists, otherwise {@code configPath},
     * and applies the custom parameter overrides to the result.
     */
    private Config loadConfigWithFallback() {
        String configRealPath = new File(configFile).exists() ? configFile : configPath;
        File realFile = new File(configRealPath);
        CheckUtils.check(realFile.exists(), "config file(%s) doesn't exist!", configRealPath);
        return updateConfigWithCustom(ConfigFactory.parseFile(realFile));
    }

    /** Returns {@code config} with every entry of {@code customParamsMap} set on it. */
    private Config updateConfigWithCustom(Config config) {
        for (Map.Entry<String, Object> entry : customParamsMap.entrySet()) {
            config = config.withValue(entry.getKey(), ConfigValueFactory.fromAnyRef(entry.getValue()));
        }
        return config;
    }

    /** Removes all custom parameter overrides. */
    public void emptyCustomParams() {
        customParamsMap.clear();
    }

    /**
     * Registers an override that replaces the value at {@code key} whenever a
     * configuration is loaded by this worker.
     *
     * @param key   configuration path to override
     * @param value replacement value
     */
    public void setCustomParam(String key, Object value) {
        customParamsMap.put(key, value);
    }

    /**
     * Loads the data flow from the given iterators and runs the train
     * operation on {@code threadNum} threads.
     *
     * <p>Any exception raised on the calling thread is logged, reported to the
     * comm and turned into a {@code false} return value. An exception raised
     * inside a training thread is reported to the comm and then terminates the
     * JVM with exit code 1.
     *
     * @param trainDatas train data iterators handed to the data flow
     * @param testDatas  test data iterators handed to the data flow
     * @return {@code true} if training finished and the comm was closed without error
     */
    public boolean train(List<Iterator<String>> trainDatas, List<Iterator<String>> testDatas) {
        long start = System.currentTimeMillis();
        int errorCode = 0;
        boolean interrupted = false;
        ThreadCommSlave comm = null;
        try {
            comm = new ThreadCommSlave(loginName, threadNum, hostName, hostPort);

            File file = new File(configFile);
            CheckUtils.check(file.exists(), "config file(%s) doesn't exist!", configFile);
            Config config = updateConfigWithCustom(ConfigFactory.parseFile(file));

            comm.info("################ parameters ################");
            for (Map.Entry<String, ConfigValue> entry : config.entrySet()) {
                comm.info(entry.getKey() + "=" + entry.getValue());
            }

            String uri = config.getString("fs_scheme");
            // Parse once and reuse for both the log line and the file system.
            URI fsUri = new URI(uri);
            LOG.info("file system uri:{}, URI:{}, URI tostring:{}", uri, fsUri, fsUri.toString());
            IFileSystem fs = FileSystemFactory.createFileSystem(fsUri);

            DataFlow dataFlow = DataFlowFactory.createDataFlow(modelName, fs, config, comm,
                    threadNum, needPyTransform, pyTransformScript);
            dataFlow.init();
            dataFlow.loadFlow(trainDatas, testDatas);
            dataFlow.ready();

            long beforeTrainCost = System.currentTimeMillis() - start;

            ITrainOperation trainOperation = TrainOperationFactory.createTrainOperation(modelName);

            // start training
            runTrainThreads(dataFlow, trainOperation, comm);

            long totalCost = System.currentTimeMillis() - start;
            long trainCost = totalCost - beforeTrainCost;
            comm.info(String.format("\nTrain cost details:\n%-19s%9.2fs\n%-19s%9.2fs\n%-19s%9.2fs\n",
                    "LoadDataFlow:", beforeTrainCost / 1000.,
                    "PreprocessAndTrain:", trainCost / 1000.,
                    "Total:", totalCost / 1000.));
        } catch (Exception e) {
            errorCode = 1;
            interrupted = e instanceof InterruptedException;
            LOG.error("existed exception!", e);
            if (comm != null) {
                try {
                    comm.exception(e);
                } catch (Mp4jException e1) {
                    // Log the reporting failure itself; the original error is already logged above.
                    LOG.error("comm send exception message error!", e1);
                }
            }
        } finally {
            try {
                if (comm != null) {
                    comm.close(errorCode);
                }
            } catch (Mp4jException e) {
                errorCode = 1;
                LOG.error("comm close exception!", e);
            } finally {
                // Restore the interrupt status only after the comm has been closed,
                // so the caller can still observe the interruption.
                if (interrupted) {
                    Thread.currentThread().interrupt();
                }
            }
        }

        return errorCode == 0;
    }

    /**
     * Starts {@code threadNum} training threads and waits for all of them.
     * A thread that fails reports the exception to the comm, closes it with
     * code 1 and exits the JVM.
     */
    private void runTrainThreads(final DataFlow dataFlow,
                                 final ITrainOperation trainOperation,
                                 final ThreadCommSlave comm) throws InterruptedException {
        Thread[] threads = new Thread[threadNum];
        for (int t = 0; t < threadNum; t++) {
            final int tidx = t;
            threads[t] = new Thread(t + "") {
                @Override
                public void run() {
                    comm.setThreadId(tidx);
                    try {
                        if (!dataFlow.isReady()) {
                            throw new Exception("data flow is not ready!");
                        }
                        trainOperation.operate(
                                dataFlow,
                                OptimizerFactory.createOptimizer(modelName, dataFlow, tidx),
                                comm,
                                tidx);
                    } catch (Exception e) {
                        // Keep a local record of the failure in case reporting to the comm fails.
                        LOG.error("train thread {} failed!", tidx, e);
                        try {
                            comm.exception(e);
                            comm.close(1);
                        } catch (Mp4jException e1) {
                            LOG.error("comm send exception message error!", e1);
                        }
                        System.exit(1);
                    }
                }
            };
            threads[t].start();
        }

        for (Thread thread : threads) {
            thread.join();
        }
    }
}