import React from "react"; import { ScatterChart, Scatter, XAxis, YAxis, ZAxis, CartesianGrid, Tooltip, Legend, Label, ResponsiveContainer, } from "recharts"; import { useModelState } from "state/ModelState"; import CircularProgress from "@material-ui/core/CircularProgress"; import { invoke } from "js-ml/knn"; import { modelLinear } from "codegen/regressionCode"; import { model } from "codegen/knnCode"; import LoadingComponent from "./LoadingComponent"; const fillColors = ["#3493fa", "#f3bb43", "#e06aa6"]; const shapeTypes = ["star", "circle", "triangle"]; const createPlotData = (state) => { const { knn_column1_index, knn_column2_index, knn_test_data, knn_result_labels, } = state; const data = {}; knn_test_data.forEach((dataRow, index) => { const result_category = knn_result_labels[index]; if (!data.hasOwnProperty(result_category)) { data[result_category] = []; } data[result_category].push({ x: dataRow[knn_column1_index], y: dataRow[knn_column2_index], }); }); return data; }; export default function PlotKNN() { const { model_state } = useModelState(); const data = createPlotData(model_state); const columns = model_state.knn_columns; const columnMap = model_state.knn_columns_map; const xAxisColumn = columnMap[columns[model_state.knn_column1_index]]; const yAxisColumn = columnMap[columns[model_state.knn_column2_index]]; return ( <> {!model_state.viz_loading ? ( <ResponsiveContainer className="graph-wrapper" width="100%" height="100%" > <ScatterChart margin={{ top: 20, right: 20, bottom: 20, left: 20, }} > <CartesianGrid /> <XAxis type="number" dataKey="x" name={xAxisColumn} unit={ model_state.knn_column_units ? model_state.knn_column_units[model_state.knn_column1_index] : "" } > <Label value={xAxisColumn} position="insideBottom" offset={-12} /> </XAxis> <YAxis type="number" dataKey="y" name={yAxisColumn} unit={ model_state.knn_column_units ? model_state.knn_column_units[model_state.knn_column2_index] : "" } > <Label value={yAxisColumn} angle={-90} position="insideLeft" /> </YAxis> <Tooltip cursor={{ strokeDasharray: "3 3" }} /> <Legend verticalAlign="top" height={36} /> {model_state.knn_labels.map((label, index) => ( <Scatter name={label} data={data[index]} fill={fillColors[index]} shape={shapeTypes[index]} key={index} /> ))} </ScatterChart> </ResponsiveContainer> ) : ( <LoadingComponent /> )} </> ); }