Merge "K-Means color clustering" into oc-dr1-dev
This commit is contained in:
@@ -27,6 +27,7 @@ import android.os.Parcelable;
|
||||
import android.util.Size;
|
||||
|
||||
import com.android.internal.graphics.palette.Palette;
|
||||
import com.android.internal.graphics.palette.VariationalKMeansQuantizer;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
@@ -142,6 +143,8 @@ public final class WallpaperColors implements Parcelable {
|
||||
|
||||
final Palette palette = Palette
|
||||
.from(bitmap)
|
||||
.setQuantizer(new VariationalKMeansQuantizer())
|
||||
.maximumColorCount(5)
|
||||
.clearFilters()
|
||||
.resizeBitmapArea(MAX_WALLPAPER_EXTRACTION_AREA)
|
||||
.generate();
|
||||
|
||||
@@ -61,7 +61,7 @@ import com.android.internal.graphics.palette.Palette.Swatch;
|
||||
* This means that the color space is divided into distinct colors, rather than representative
|
||||
* colors.
|
||||
*/
|
||||
final class ColorCutQuantizer {
|
||||
final class ColorCutQuantizer implements Quantizer {
|
||||
|
||||
private static final String LOG_TAG = "ColorCutQuantizer";
|
||||
private static final boolean LOG_TIMINGS = false;
|
||||
@@ -73,22 +73,22 @@ final class ColorCutQuantizer {
|
||||
private static final int QUANTIZE_WORD_WIDTH = 5;
|
||||
private static final int QUANTIZE_WORD_MASK = (1 << QUANTIZE_WORD_WIDTH) - 1;
|
||||
|
||||
final int[] mColors;
|
||||
final int[] mHistogram;
|
||||
final List<Swatch> mQuantizedColors;
|
||||
final TimingLogger mTimingLogger;
|
||||
final Palette.Filter[] mFilters;
|
||||
int[] mColors;
|
||||
int[] mHistogram;
|
||||
List<Swatch> mQuantizedColors;
|
||||
TimingLogger mTimingLogger;
|
||||
Palette.Filter[] mFilters;
|
||||
|
||||
private final float[] mTempHsl = new float[3];
|
||||
|
||||
/**
|
||||
* Constructor.
|
||||
* Execute color quantization.
|
||||
*
|
||||
* @param pixels histogram representing an image's pixel data
|
||||
* @param maxColors The maximum number of colors that should be in the result palette.
|
||||
* @param filters Set of filters to use in the quantization stage
|
||||
*/
|
||||
ColorCutQuantizer(final int[] pixels, final int maxColors, final Palette.Filter[] filters) {
|
||||
public void quantize(final int[] pixels, final int maxColors, final Palette.Filter[] filters) {
|
||||
mTimingLogger = LOG_TIMINGS ? new TimingLogger(LOG_TAG, "Creation") : null;
|
||||
mFilters = filters;
|
||||
|
||||
@@ -160,7 +160,7 @@ final class ColorCutQuantizer {
|
||||
/**
|
||||
* @return the list of quantized colors
|
||||
*/
|
||||
List<Swatch> getQuantizedColors() {
|
||||
public List<Swatch> getQuantizedColors() {
|
||||
return mQuantizedColors;
|
||||
}
|
||||
|
||||
|
||||
@@ -613,6 +613,8 @@ public final class Palette {
|
||||
private final List<Palette.Filter> mFilters = new ArrayList<>();
|
||||
private Rect mRegion;
|
||||
|
||||
private Quantizer mQuantizer;
|
||||
|
||||
/**
|
||||
* Construct a new {@link Palette.Builder} using a source {@link Bitmap}
|
||||
*/
|
||||
@@ -725,6 +727,18 @@ public final class Palette {
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set a specific quantization algorithm. {@link ColorCutQuantizer} will
|
||||
* be used if unspecified.
|
||||
*
|
||||
* @param quantizer Quantizer implementation.
|
||||
*/
|
||||
@NonNull
|
||||
public Palette.Builder setQuantizer(Quantizer quantizer) {
|
||||
mQuantizer = quantizer;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set a region of the bitmap to be used exclusively when calculating the palette.
|
||||
* <p>This only works when the original input is a {@link Bitmap}.</p>
|
||||
@@ -818,17 +832,19 @@ public final class Palette {
|
||||
}
|
||||
|
||||
// Now generate a quantizer from the Bitmap
|
||||
final ColorCutQuantizer quantizer = new ColorCutQuantizer(
|
||||
getPixelsFromBitmap(bitmap),
|
||||
mMaxColors,
|
||||
mFilters.isEmpty() ? null : mFilters.toArray(new Palette.Filter[mFilters.size()]));
|
||||
if (mQuantizer == null) {
|
||||
mQuantizer = new ColorCutQuantizer();
|
||||
}
|
||||
mQuantizer.quantize(getPixelsFromBitmap(bitmap),
|
||||
mMaxColors, mFilters.isEmpty() ? null :
|
||||
mFilters.toArray(new Palette.Filter[mFilters.size()]));
|
||||
|
||||
// If created a new bitmap, recycle it
|
||||
if (bitmap != mBitmap) {
|
||||
bitmap.recycle();
|
||||
}
|
||||
|
||||
swatches = quantizer.getQuantizedColors();
|
||||
swatches = mQuantizer.getQuantizedColors();
|
||||
|
||||
if (logger != null) {
|
||||
logger.addSplit("Color quantization completed");
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
/*
|
||||
* Copyright (C) 2017 The Android Open Source Project
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License
|
||||
*/
|
||||
|
||||
package com.android.internal.graphics.palette;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Definition of an algorithm that receives pixels and outputs a list of colors.
|
||||
*/
|
||||
public interface Quantizer {
|
||||
void quantize(final int[] pixels, final int maxColors, final Palette.Filter[] filters);
|
||||
List<Palette.Swatch> getQuantizedColors();
|
||||
}
|
||||
@@ -0,0 +1,154 @@
|
||||
/*
|
||||
* Copyright (C) 2017 The Android Open Source Project
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License
|
||||
*/
|
||||
|
||||
package com.android.internal.graphics.palette;
|
||||
|
||||
import android.util.Log;
|
||||
|
||||
import com.android.internal.graphics.ColorUtils;
|
||||
import com.android.internal.ml.clustering.KMeans;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
/**
|
||||
* A quantizer that uses k-means
|
||||
*/
|
||||
public class VariationalKMeansQuantizer implements Quantizer {
|
||||
|
||||
private static final String TAG = "KMeansQuantizer";
|
||||
private static final boolean DEBUG = false;
|
||||
|
||||
/**
|
||||
* Clusters closer than this value will me merged.
|
||||
*/
|
||||
private final float mMinClusterSqDistance;
|
||||
|
||||
/**
|
||||
* K-means can get stuck in local optima, this can be avoided by
|
||||
* repeating it and getting the "best" execution.
|
||||
*/
|
||||
private final int mInitializations;
|
||||
|
||||
/**
|
||||
* Initialize KMeans with a fixed random state to have
|
||||
* consistent results across multiple runs.
|
||||
*/
|
||||
private final KMeans mKMeans = new KMeans(new Random(0), 30, 0);
|
||||
|
||||
private List<Palette.Swatch> mQuantizedColors;
|
||||
|
||||
public VariationalKMeansQuantizer() {
|
||||
this(0.25f /* cluster distance */);
|
||||
}
|
||||
|
||||
public VariationalKMeansQuantizer(float minClusterDistance) {
|
||||
this(minClusterDistance, 1 /* initializations */);
|
||||
}
|
||||
|
||||
public VariationalKMeansQuantizer(float minClusterDistance, int initializations) {
|
||||
mMinClusterSqDistance = minClusterDistance * minClusterDistance;
|
||||
mInitializations = initializations;
|
||||
}
|
||||
|
||||
/**
|
||||
* K-Means quantizer.
|
||||
*
|
||||
* @param pixels Pixels to quantize.
|
||||
* @param maxColors Maximum number of clusters to extract.
|
||||
* @param filters Colors that should be ignored
|
||||
*/
|
||||
@Override
|
||||
public void quantize(int[] pixels, int maxColors, Palette.Filter[] filters) {
|
||||
// Start by converting all colors to HSL.
|
||||
// HLS is way more meaningful for clustering than RGB.
|
||||
final float[] hsl = {0, 0, 0};
|
||||
final float[][] hslPixels = new float[pixels.length][3];
|
||||
for (int i = 0; i < pixels.length; i++) {
|
||||
ColorUtils.colorToHSL(pixels[i], hsl);
|
||||
// Normalize hue so all values go from 0 to 1.
|
||||
hslPixels[i][0] = hsl[0] / 360f;
|
||||
hslPixels[i][1] = hsl[1];
|
||||
hslPixels[i][2] = hsl[2];
|
||||
}
|
||||
|
||||
final List<KMeans.Mean> optimalMeans = getOptimalKMeans(maxColors, hslPixels);
|
||||
|
||||
// Ideally we should run k-means again to merge clusters but it would be too expensive,
|
||||
// instead we just merge all clusters that are closer than a threshold.
|
||||
for (int i = 0; i < optimalMeans.size(); i++) {
|
||||
KMeans.Mean current = optimalMeans.get(i);
|
||||
float[] currentCentroid = current.getCentroid();
|
||||
for (int j = i + 1; j < optimalMeans.size(); j++) {
|
||||
KMeans.Mean compareTo = optimalMeans.get(j);
|
||||
float[] compareToCentroid = compareTo.getCentroid();
|
||||
float sqDistance = KMeans.sqDistance(currentCentroid, compareToCentroid);
|
||||
// Merge them
|
||||
if (sqDistance < mMinClusterSqDistance) {
|
||||
optimalMeans.remove(compareTo);
|
||||
current.getItems().addAll(compareTo.getItems());
|
||||
for (int k = 0; k < currentCentroid.length; k++) {
|
||||
currentCentroid[k] += (compareToCentroid[k] - currentCentroid[k]) / 2.0;
|
||||
}
|
||||
j--;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert data to final format, de-normalizing the hue.
|
||||
mQuantizedColors = new ArrayList<>();
|
||||
for (KMeans.Mean mean : optimalMeans) {
|
||||
if (mean.getItems().size() == 0) {
|
||||
continue;
|
||||
}
|
||||
float[] centroid = mean.getCentroid();
|
||||
mQuantizedColors.add(new Palette.Swatch(new float[]{
|
||||
centroid[0] * 360f,
|
||||
centroid[1],
|
||||
centroid[2]
|
||||
}, mean.getItems().size()));
|
||||
}
|
||||
}
|
||||
|
||||
private List<KMeans.Mean> getOptimalKMeans(int k, float[][] inputData) {
|
||||
List<KMeans.Mean> optimal = null;
|
||||
double optimalScore = -Double.MAX_VALUE;
|
||||
int runs = mInitializations;
|
||||
while (runs > 0) {
|
||||
if (DEBUG) {
|
||||
Log.d(TAG, "k-means run: " + runs);
|
||||
}
|
||||
List<KMeans.Mean> means = mKMeans.predict(k, inputData);
|
||||
double score = KMeans.score(means);
|
||||
if (optimal == null || score > optimalScore) {
|
||||
if (DEBUG) {
|
||||
Log.d(TAG, "\tnew optimal score: " + score);
|
||||
}
|
||||
optimalScore = score;
|
||||
optimal = means;
|
||||
}
|
||||
runs--;
|
||||
}
|
||||
|
||||
return optimal;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Palette.Swatch> getQuantizedColors() {
|
||||
return mQuantizedColors;
|
||||
}
|
||||
}
|
||||
243
core/java/com/android/internal/ml/clustering/KMeans.java
Normal file
243
core/java/com/android/internal/ml/clustering/KMeans.java
Normal file
@@ -0,0 +1,243 @@
|
||||
/*
|
||||
* Copyright (C) 2017 The Android Open Source Project
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License
|
||||
*/
|
||||
|
||||
package com.android.internal.ml.clustering;
|
||||
|
||||
import android.annotation.NonNull;
|
||||
import android.util.Log;
|
||||
|
||||
import com.android.internal.annotations.VisibleForTesting;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
/**
|
||||
* Simple K-Means implementation
|
||||
*/
|
||||
public class KMeans {
|
||||
|
||||
private static final boolean DEBUG = false;
|
||||
private static final String TAG = "KMeans";
|
||||
private final Random mRandomState;
|
||||
private final int mMaxIterations;
|
||||
private float mSqConvergenceEpsilon;
|
||||
|
||||
public KMeans() {
|
||||
this(new Random());
|
||||
}
|
||||
|
||||
public KMeans(Random random) {
|
||||
this(random, 30 /* maxIterations */, 0.005f /* convergenceEpsilon */);
|
||||
}
|
||||
public KMeans(Random random, int maxIterations, float convergenceEpsilon) {
|
||||
mRandomState = random;
|
||||
mMaxIterations = maxIterations;
|
||||
mSqConvergenceEpsilon = convergenceEpsilon * convergenceEpsilon;
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs k-means on the input data (X) trying to find k means.
|
||||
*
|
||||
* K-Means is known for getting stuck into local optima, so you might
|
||||
* want to run it multiple time and argmax on {@link KMeans#score(List)}
|
||||
*
|
||||
* @param k The number of points to return.
|
||||
* @param inputData Input data.
|
||||
* @return An array of k Means, each representing a centroid and data points that belong to it.
|
||||
*/
|
||||
public List<Mean> predict(final int k, final float[][] inputData) {
|
||||
checkDataSetSanity(inputData);
|
||||
int dimension = inputData[0].length;
|
||||
|
||||
final ArrayList<Mean> means = new ArrayList<>();
|
||||
for (int i = 0; i < k; i++) {
|
||||
Mean m = new Mean(dimension);
|
||||
for (int j = 0; j < dimension; j++) {
|
||||
m.mCentroid[j] = mRandomState.nextFloat();
|
||||
}
|
||||
means.add(m);
|
||||
}
|
||||
|
||||
// Iterate until we converge or run out of iterations
|
||||
boolean converged = false;
|
||||
for (int i = 0; i < mMaxIterations; i++) {
|
||||
converged = step(means, inputData);
|
||||
if (converged) {
|
||||
if (DEBUG) Log.d(TAG, "Converged at iteration: " + i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!converged && DEBUG) Log.d(TAG, "Did not converge");
|
||||
|
||||
return means;
|
||||
}
|
||||
|
||||
/**
|
||||
* Score calculates the inertia between means.
|
||||
* This can be considered as an E step of an EM algorithm.
|
||||
*
|
||||
* @param means Means to use when calculating score.
|
||||
* @return The score
|
||||
*/
|
||||
public static double score(@NonNull List<Mean> means) {
|
||||
double score = 0;
|
||||
final int meansSize = means.size();
|
||||
for (int i = 0; i < meansSize; i++) {
|
||||
Mean mean = means.get(i);
|
||||
for (int j = 0; j < meansSize; j++) {
|
||||
Mean compareTo = means.get(j);
|
||||
if (mean == compareTo) {
|
||||
continue;
|
||||
}
|
||||
double distance = Math.sqrt(sqDistance(mean.mCentroid, compareTo.mCentroid));
|
||||
score += distance;
|
||||
}
|
||||
}
|
||||
return score;
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
public void checkDataSetSanity(float[][] inputData) {
|
||||
if (inputData == null) {
|
||||
throw new IllegalArgumentException("Data set is null.");
|
||||
} else if (inputData.length == 0) {
|
||||
throw new IllegalArgumentException("Data set is empty.");
|
||||
} else if (inputData[0] == null) {
|
||||
throw new IllegalArgumentException("Bad data set format.");
|
||||
}
|
||||
|
||||
final int dimension = inputData[0].length;
|
||||
final int length = inputData.length;
|
||||
for (int i = 1; i < length; i++) {
|
||||
if (inputData[i] == null || inputData[i].length != dimension) {
|
||||
throw new IllegalArgumentException("Bad data set format.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* K-Means iteration.
|
||||
*
|
||||
* @param means Current means
|
||||
* @param inputData Input data
|
||||
* @return True if data set converged
|
||||
*/
|
||||
private boolean step(final ArrayList<Mean> means, final float[][] inputData) {
|
||||
|
||||
// Clean up the previous state because we need to compute
|
||||
// which point belongs to each mean again.
|
||||
for (int i = means.size() - 1; i >= 0; i--) {
|
||||
final Mean mean = means.get(i);
|
||||
mean.mClosestItems.clear();
|
||||
}
|
||||
for (int i = inputData.length - 1; i >= 0; i--) {
|
||||
final float[] current = inputData[i];
|
||||
final Mean nearest = nearestMean(current, means);
|
||||
nearest.mClosestItems.add(current);
|
||||
}
|
||||
|
||||
boolean converged = true;
|
||||
// Move each mean towards the nearest data set points
|
||||
for (int i = means.size() - 1; i >= 0; i--) {
|
||||
final Mean mean = means.get(i);
|
||||
if (mean.mClosestItems.size() == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Compute the new mean centroid:
|
||||
// 1. Sum all all points
|
||||
// 2. Average them
|
||||
final float[] oldCentroid = mean.mCentroid;
|
||||
mean.mCentroid = new float[oldCentroid.length];
|
||||
for (int j = 0; j < mean.mClosestItems.size(); j++) {
|
||||
// Update each centroid component
|
||||
for (int p = 0; p < mean.mCentroid.length; p++) {
|
||||
mean.mCentroid[p] += mean.mClosestItems.get(j)[p];
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < mean.mCentroid.length; j++) {
|
||||
mean.mCentroid[j] /= mean.mClosestItems.size();
|
||||
}
|
||||
|
||||
// We converged if the centroid didn't move for any of the means.
|
||||
if (sqDistance(oldCentroid, mean.mCentroid) > mSqConvergenceEpsilon) {
|
||||
converged = false;
|
||||
}
|
||||
}
|
||||
return converged;
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
public static Mean nearestMean(float[] point, List<Mean> means) {
|
||||
Mean nearest = null;
|
||||
float nearestDistance = Float.MAX_VALUE;
|
||||
|
||||
final int meanCount = means.size();
|
||||
for (int i = 0; i < meanCount; i++) {
|
||||
Mean next = means.get(i);
|
||||
// We don't need the sqrt when comparing distances in euclidean space
|
||||
// because they exist on both sides of the equation and cancel each other out.
|
||||
float nextDistance = sqDistance(point, next.mCentroid);
|
||||
if (nextDistance < nearestDistance) {
|
||||
nearest = next;
|
||||
nearestDistance = nextDistance;
|
||||
}
|
||||
}
|
||||
return nearest;
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
public static float sqDistance(float[] a, float[] b) {
|
||||
float dist = 0;
|
||||
final int length = a.length;
|
||||
for (int i = 0; i < length; i++) {
|
||||
dist += (a[i] - b[i]) * (a[i] - b[i]);
|
||||
}
|
||||
return dist;
|
||||
}
|
||||
|
||||
/**
|
||||
* Definition of a mean, contains a centroid and points on its cluster.
|
||||
*/
|
||||
public static class Mean {
|
||||
float[] mCentroid;
|
||||
final ArrayList<float[]> mClosestItems = new ArrayList<>();
|
||||
|
||||
public Mean(int dimension) {
|
||||
mCentroid = new float[dimension];
|
||||
}
|
||||
|
||||
public Mean(float ...centroid) {
|
||||
mCentroid = centroid;
|
||||
}
|
||||
|
||||
public float[] getCentroid() {
|
||||
return mCentroid;
|
||||
}
|
||||
|
||||
public List<float[]> getItems() {
|
||||
return mClosestItems;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Mean(centroid: " + Arrays.toString(mCentroid) + ", size: "
|
||||
+ mClosestItems.size() + ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
20
tests/Internal/Android.mk
Normal file
20
tests/Internal/Android.mk
Normal file
@@ -0,0 +1,20 @@
|
||||
LOCAL_PATH:= $(call my-dir)
|
||||
include $(CLEAR_VARS)
|
||||
|
||||
LOCAL_USE_AAPT2 := true
|
||||
LOCAL_MODULE_TAGS := tests
|
||||
|
||||
LOCAL_PROTOC_OPTIMIZE_TYPE := nano
|
||||
|
||||
# Include some source files directly to be able to access package members
|
||||
LOCAL_SRC_FILES := $(call all-java-files-under, src)
|
||||
|
||||
LOCAL_JAVA_LIBRARIES := android.test.runner
|
||||
LOCAL_STATIC_JAVA_LIBRARIES := junit legacy-android-test android-support-test
|
||||
|
||||
LOCAL_CERTIFICATE := platform
|
||||
|
||||
LOCAL_PACKAGE_NAME := InternalTests
|
||||
LOCAL_COMPATIBILITY_SUITE := device-tests
|
||||
|
||||
include $(BUILD_PACKAGE)
|
||||
28
tests/Internal/AndroidManifest.xml
Normal file
28
tests/Internal/AndroidManifest.xml
Normal file
@@ -0,0 +1,28 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<!--
|
||||
~ Copyright (C) 2017 The Android Open Source Project
|
||||
~
|
||||
~ Licensed under the Apache License, Version 2.0 (the "License");
|
||||
~ you may not use this file except in compliance with the License.
|
||||
~ You may obtain a copy of the License at
|
||||
~
|
||||
~ http://www.apache.org/licenses/LICENSE-2.0
|
||||
~
|
||||
~ Unless required by applicable law or agreed to in writing, software
|
||||
~ distributed under the License is distributed on an "AS IS" BASIS,
|
||||
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
~ See the License for the specific language governing permissions and
|
||||
~ limitations under the License
|
||||
-->
|
||||
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.android.internal.tests">
|
||||
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
|
||||
<application>
|
||||
<uses-library android:name="android.test.runner" />
|
||||
</application>
|
||||
|
||||
<instrumentation android:name="android.support.test.runner.AndroidJUnitRunner"
|
||||
android:targetPackage="com.android.internal.tests"
|
||||
android:label="Internal Tests" />
|
||||
</manifest>
|
||||
29
tests/Internal/AndroidTest.xml
Normal file
29
tests/Internal/AndroidTest.xml
Normal file
@@ -0,0 +1,29 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<!--
|
||||
~ Copyright (C) 2017 The Android Open Source Project
|
||||
~
|
||||
~ Licensed under the Apache License, Version 2.0 (the "License");
|
||||
~ you may not use this file except in compliance with the License.
|
||||
~ You may obtain a copy of the License at
|
||||
~
|
||||
~ http://www.apache.org/licenses/LICENSE-2.0
|
||||
~
|
||||
~ Unless required by applicable law or agreed to in writing, software
|
||||
~ distributed under the License is distributed on an "AS IS" BASIS,
|
||||
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
~ See the License for the specific language governing permissions and
|
||||
~ limitations under the License
|
||||
-->
|
||||
<configuration description="Runs tests for internal classes/utilities.">
|
||||
<target_preparer class="com.android.tradefed.targetprep.TestAppInstallSetup">
|
||||
<option name="test-file-name" value="InternalTests.apk" />
|
||||
</target_preparer>
|
||||
|
||||
<option name="test-suite-tag" value="apct" />
|
||||
<option name="test-suite-tag" value="framework-base-presubmit" />
|
||||
<option name="test-tag" value="InternalTests" />
|
||||
<test class="com.android.tradefed.testtype.AndroidJUnitTest" >
|
||||
<option name="package" value="com.android.internal.tests" />
|
||||
<option name="runner" value="android.support.test.runner.AndroidJUnitRunner" />
|
||||
</test>
|
||||
</configuration>
|
||||
@@ -0,0 +1,155 @@
|
||||
/*
|
||||
* Copyright (C) 2017 The Android Open Source Project
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.android.internal.ml.clustering;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import android.annotation.SuppressLint;
|
||||
import android.support.test.filters.SmallTest;
|
||||
import android.support.test.runner.AndroidJUnit4;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
@SmallTest
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public class KMeansTest {
|
||||
|
||||
// Error tolerance (epsilon)
|
||||
private static final double EPS = 0.01;
|
||||
|
||||
private KMeans mKMeans;
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
// Setup with a random seed to have predictable results
|
||||
mKMeans = new KMeans(new Random(0), 30, 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getCheckDataSanityTest() {
|
||||
try {
|
||||
mKMeans.checkDataSetSanity(new float[][] {
|
||||
{0, 1, 2},
|
||||
{1, 2, 3}
|
||||
});
|
||||
} catch (IllegalArgumentException e) {
|
||||
Assert.fail("Valid data didn't pass sanity check");
|
||||
}
|
||||
|
||||
try {
|
||||
mKMeans.checkDataSetSanity(new float[][] {
|
||||
null,
|
||||
{1, 2, 3}
|
||||
});
|
||||
Assert.fail("Data has null items and passed");
|
||||
} catch (IllegalArgumentException e) {}
|
||||
|
||||
try {
|
||||
mKMeans.checkDataSetSanity(new float[][] {
|
||||
{0, 1, 2, 4},
|
||||
{1, 2, 3}
|
||||
});
|
||||
Assert.fail("Data has invalid shape and passed");
|
||||
} catch (IllegalArgumentException e) {}
|
||||
|
||||
try {
|
||||
mKMeans.checkDataSetSanity(null);
|
||||
Assert.fail("Null data should throw exception");
|
||||
} catch (IllegalArgumentException e) {}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void sqDistanceTest() {
|
||||
float a[] = {4, 10};
|
||||
float b[] = {5, 2};
|
||||
float sqDist = (float) (Math.pow(a[0] - b[0], 2) + Math.pow(a[1] - b[1], 2));
|
||||
|
||||
assertEquals("Squared distance not valid", mKMeans.sqDistance(a, b), sqDist, EPS);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void nearestMeanTest() {
|
||||
KMeans.Mean meanA = new KMeans.Mean(0, 1);
|
||||
KMeans.Mean meanB = new KMeans.Mean(1, 1);
|
||||
List<KMeans.Mean> means = Arrays.asList(meanA, meanB);
|
||||
|
||||
KMeans.Mean nearest = mKMeans.nearestMean(new float[] {1, 1}, means);
|
||||
|
||||
assertEquals("Unexpected nearest mean for point {1, 1}", nearest, meanB);
|
||||
}
|
||||
|
||||
@SuppressLint("DefaultLocale")
|
||||
@Test
|
||||
public void scoreTest() {
|
||||
List<KMeans.Mean> closeMeans = Arrays.asList(new KMeans.Mean(0, 0.1f, 0.1f),
|
||||
new KMeans.Mean(0, 0.1f, 0.15f),
|
||||
new KMeans.Mean(0.1f, 0.2f, 0.1f));
|
||||
List<KMeans.Mean> farMeans = Arrays.asList(new KMeans.Mean(0, 0, 0),
|
||||
new KMeans.Mean(0, 0.5f, 0.5f),
|
||||
new KMeans.Mean(1, 0.9f, 0.9f));
|
||||
|
||||
double closeScore = KMeans.score(closeMeans);
|
||||
double farScore = KMeans.score(farMeans);
|
||||
assertTrue(String.format("Score of well distributed means should be greater than "
|
||||
+ "close means but got: %f, %f", farScore, closeScore), farScore > closeScore);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void predictTest() {
|
||||
float[] expectedCentroid1 = {1, 1, 1};
|
||||
float[] expectedCentroid2 = {0, 0, 0};
|
||||
float[][] X = new float[][] {
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{0, 0, 0},
|
||||
{0, 0, 0},
|
||||
{0, 0, 0},
|
||||
};
|
||||
|
||||
final int numClusters = 2;
|
||||
|
||||
// Here we assume that we won't get stuck into a local optima.
|
||||
// It's fine because we're seeding a random, we won't ever have
|
||||
// unstable results but in real life we need multiple initialization
|
||||
// and score comparison
|
||||
List<KMeans.Mean> means = mKMeans.predict(numClusters, X);
|
||||
|
||||
assertEquals("Expected number of clusters is invalid", numClusters, means.size());
|
||||
|
||||
boolean exists1 = false, exists2 = false;
|
||||
for (KMeans.Mean mean : means) {
|
||||
if (Arrays.equals(mean.getCentroid(), expectedCentroid1)) {
|
||||
exists1 = true;
|
||||
} else if (Arrays.equals(mean.getCentroid(), expectedCentroid2)) {
|
||||
exists2 = true;
|
||||
} else {
|
||||
throw new AssertionError("Unexpected mean: " + mean);
|
||||
}
|
||||
}
|
||||
assertTrue("Expected means were not predicted, got: " + means,
|
||||
exists1 && exists2);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user