JStarCraft AI


License Total lines Codacy Badge

希望路过的同学,顺手给JStarCraft框架点个Star,算是对作者的一种鼓励吧!


JStarCraft AI是一个机器学习的轻量级框架.遵循Apache 2.0协议.

在学术界,绝大多数研究人员使用的编程语言是Python.

在工业界,绝大多数开发人员使用的编程语言是Java.

JStarCraft AI是一个基于Java语言的机器学习工具包,由一系列的数据结构,算法和模型组成.

目标是作为在学术界与工业界从事机器学习研发的相关人员之间的桥梁.普及机器学习在Java领域的应用.

作者 洪钊桦
E-mail [email protected], [email protected]

JStarCraft AI架构

JStarCraft AI框架各个模块之间的关系: ai


JStarCraft AI特性


JStarCraft AI教程

Maven依赖

<dependency>
    <groupId>com.jstarcraft</groupId>
    <artifactId>ai</artifactId>
    <version>1.0</version>
</dependency>

Gradle依赖

compile group: 'com.jstarcraft', name: 'ai', version: '1.0'

设置CPU环境

<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-native-platform</artifactId>
    <version>1.0.0-beta3</version>
</dependency>

设置GPU环境

<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-cuda-9.0-platform</artifactId>
    <version>1.0.0-beta3</version>
</dependency>
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-cuda-9.1-platform</artifactId>
    <version>1.0.0-beta3</version>
</dependency>
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-cuda-9.2-platform</artifactId>
    <version>1.0.0-beta3</version>
</dependency>
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-cuda-10.0-platform</artifactId>
    <version>1.0.0-beta3</version>
</dependency>
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-cuda-10.1-platform</artifactId>
    <version>1.0.0-beta3</version>
</dependency>

使用环境上下文

// 获取默认环境上下文
EnvironmentContext context = EnvironmentContext.getContext();
// 在环境上下文中执行任务
Future<?> task = context.doTask(() - > {
    int dimension = 10;
    MathMatrix leftMatrix = getRandomMatrix(dimension);
    MathMatrix rightMatrix = getRandomMatrix(dimension);
    MathMatrix dataMatrix = getZeroMatrix(dimension);
    dataMatrix.dotProduct(leftMatrix, false, rightMatrix, true, MathCalculator.PARALLEL);
});

数据表示

用户(User) 旧手机类型(Item) 新手机类型(Item) 评分(Score)
Google Fan Android Android 3
Google Fan Android IOS 1
Google Fan IOS Android 5
Apple Fan IOS IOS 3
Apple Fan Android IOS 5
Apple Fan IOS Android 1
定性(User) 定性(Item) 定性(Item) 定量(Score)
0 0 0 3
0 0 1 1
0 1 0 5
1 1 1 3
1 0 1 5
1 1 0 1

数据转换

数据转换器(DataConverter)负责各种各样的格式转换为JStarCraft AI框架能够处理的数据模块(DataModule).

JStarCraft AI框架各个转换器与其它系统之间的关系:

converter

// 定性属性
Map<String, Class<?>> qualityDifinitions = new HashMap<>();
qualityDifinitions.put("user", String.class);
qualityDifinitions.put("item", String.class);

// 定量属性
Map<String, Class<?>> quantityDifinitions = new HashMap<>();
quantityDifinitions.put("score", float.class);
DataSpace space = new DataSpace(qualityDifinitions, quantityDifinitions);
TreeMap<Integer, String> configuration = new TreeMap<>();
configuration.put(1, "user");
configuration.put(3, "item");
configuration.put(4, "score");
DataModule module = space.makeDenseModule("module", configuration, 1000);

JStarCraft AI框架兼容的格式

// ARFF转换器
ArffConverter converter = new ArffConverter(space.getQualityAttributes(), space.getQuantityAttributes());

// 获取流
File file = new File(this.getClass().getResource("module.arff").toURI());
InputStream stream = new FileInputStream(file);

// 转换数据
int count = converter.convert(module, stream, null, null, null);
// CSV转换器
CsvConverter converter = new CsvConverter(',', space.getQualityAttributes(), space.getQuantityAttributes());

// 获取流
File file = new File(this.getClass().getResource("module.csv").toURI());
InputStream stream = new FileInputStream(file);

// 转换数据
int count = converter.convert(module, stream, null, null, null);
// JSON转换器
JsonConverter converter = new JsonConverter(space.getQualityAttributes(), space.getQuantityAttributes());

// 获取流
File file = new File(this.getClass().getResource("module.json").toURI());
InputStream stream = new FileInputStream(file);

// 转换数据
int count = converter.convert(module, stream, null, null, null);
// HQL转换器
QueryConverter converter = new QueryConverter(space.getQualityAttributes(), space.getQuantityAttributes());

// 获取游标
String selectDataHql = "select data.user, data.leftItem, data.rightItem, data.score from MockData data";
Session session = sessionFactory.openSession();
Query query = session.createQuery(selectDataHql);
ScrollableResults iterator = query.scroll();

// 转换数据
int count = converter.convert(module, iterator, null, null, null);
session.close();
// SQL转换器
QueryConverter converter = new QueryConverter(space.getQualityAttributes(), space.getQuantityAttributes());

// 获取游标
String selectDataSql = "select user, leftItem, rightItem, score from MockData";
Session session = sessionFactory.openSession();
Query query = session.createQuery(selectDataSql);
ScrollableResults iterator = query.scroll();

// 转换数据
int count = converter.convert(module, iterator, null, null, null);
session.close();

数据处理


评估指标

排序指标

评分指标