Merge "K-Means color clustering" into oc-dr1-dev

This commit is contained in:
Lucas Dupin
2017-06-17 17:20:59 +00:00
committed by Android (Google) Code Review
10 changed files with 689 additions and 14 deletions

View File

@@ -27,6 +27,7 @@ import android.os.Parcelable;
import android.util.Size;
import com.android.internal.graphics.palette.Palette;
import com.android.internal.graphics.palette.VariationalKMeansQuantizer;
import java.util.ArrayList;
import java.util.Collections;
@@ -142,6 +143,8 @@ public final class WallpaperColors implements Parcelable {
final Palette palette = Palette
.from(bitmap)
.setQuantizer(new VariationalKMeansQuantizer())
.maximumColorCount(5)
.clearFilters()
.resizeBitmapArea(MAX_WALLPAPER_EXTRACTION_AREA)
.generate();

View File

@@ -61,7 +61,7 @@ import com.android.internal.graphics.palette.Palette.Swatch;
* This means that the color space is divided into distinct colors, rather than representative
* colors.
*/
final class ColorCutQuantizer {
final class ColorCutQuantizer implements Quantizer {
private static final String LOG_TAG = "ColorCutQuantizer";
private static final boolean LOG_TIMINGS = false;
@@ -73,22 +73,22 @@ final class ColorCutQuantizer {
private static final int QUANTIZE_WORD_WIDTH = 5;
private static final int QUANTIZE_WORD_MASK = (1 << QUANTIZE_WORD_WIDTH) - 1;
final int[] mColors;
final int[] mHistogram;
final List<Swatch> mQuantizedColors;
final TimingLogger mTimingLogger;
final Palette.Filter[] mFilters;
int[] mColors;
int[] mHistogram;
List<Swatch> mQuantizedColors;
TimingLogger mTimingLogger;
Palette.Filter[] mFilters;
private final float[] mTempHsl = new float[3];
/**
* Constructor.
* Execute color quantization.
*
* @param pixels histogram representing an image's pixel data
* @param maxColors The maximum number of colors that should be in the result palette.
* @param filters Set of filters to use in the quantization stage
*/
ColorCutQuantizer(final int[] pixels, final int maxColors, final Palette.Filter[] filters) {
public void quantize(final int[] pixels, final int maxColors, final Palette.Filter[] filters) {
mTimingLogger = LOG_TIMINGS ? new TimingLogger(LOG_TAG, "Creation") : null;
mFilters = filters;
@@ -160,7 +160,7 @@ final class ColorCutQuantizer {
/**
* @return the list of quantized colors
*/
List<Swatch> getQuantizedColors() {
public List<Swatch> getQuantizedColors() {
return mQuantizedColors;
}

View File

@@ -613,6 +613,8 @@ public final class Palette {
private final List<Palette.Filter> mFilters = new ArrayList<>();
private Rect mRegion;
private Quantizer mQuantizer;
/**
* Construct a new {@link Palette.Builder} using a source {@link Bitmap}
*/
@@ -725,6 +727,18 @@ public final class Palette {
return this;
}
/**
* Set a specific quantization algorithm. {@link ColorCutQuantizer} will
* be used if unspecified.
*
* @param quantizer Quantizer implementation.
*/
@NonNull
public Palette.Builder setQuantizer(Quantizer quantizer) {
mQuantizer = quantizer;
return this;
}
/**
* Set a region of the bitmap to be used exclusively when calculating the palette.
* <p>This only works when the original input is a {@link Bitmap}.</p>
@@ -818,17 +832,19 @@ public final class Palette {
}
// Now generate a quantizer from the Bitmap
final ColorCutQuantizer quantizer = new ColorCutQuantizer(
getPixelsFromBitmap(bitmap),
mMaxColors,
mFilters.isEmpty() ? null : mFilters.toArray(new Palette.Filter[mFilters.size()]));
if (mQuantizer == null) {
mQuantizer = new ColorCutQuantizer();
}
mQuantizer.quantize(getPixelsFromBitmap(bitmap),
mMaxColors, mFilters.isEmpty() ? null :
mFilters.toArray(new Palette.Filter[mFilters.size()]));
// If created a new bitmap, recycle it
if (bitmap != mBitmap) {
bitmap.recycle();
}
swatches = quantizer.getQuantizedColors();
swatches = mQuantizer.getQuantizedColors();
if (logger != null) {
logger.addSplit("Color quantization completed");

View File

@@ -0,0 +1,27 @@
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License
*/
package com.android.internal.graphics.palette;
import java.util.List;
/**
* Definition of an algorithm that receives pixels and outputs a list of colors.
*/
public interface Quantizer {
void quantize(final int[] pixels, final int maxColors, final Palette.Filter[] filters);
List<Palette.Swatch> getQuantizedColors();
}

View File

@@ -0,0 +1,154 @@
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License
*/
package com.android.internal.graphics.palette;
import android.util.Log;
import com.android.internal.graphics.ColorUtils;
import com.android.internal.ml.clustering.KMeans;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
/**
* A quantizer that uses k-means
*/
public class VariationalKMeansQuantizer implements Quantizer {
private static final String TAG = "KMeansQuantizer";
private static final boolean DEBUG = false;
/**
* Clusters closer than this value will me merged.
*/
private final float mMinClusterSqDistance;
/**
* K-means can get stuck in local optima, this can be avoided by
* repeating it and getting the "best" execution.
*/
private final int mInitializations;
/**
* Initialize KMeans with a fixed random state to have
* consistent results across multiple runs.
*/
private final KMeans mKMeans = new KMeans(new Random(0), 30, 0);
private List<Palette.Swatch> mQuantizedColors;
public VariationalKMeansQuantizer() {
this(0.25f /* cluster distance */);
}
public VariationalKMeansQuantizer(float minClusterDistance) {
this(minClusterDistance, 1 /* initializations */);
}
public VariationalKMeansQuantizer(float minClusterDistance, int initializations) {
mMinClusterSqDistance = minClusterDistance * minClusterDistance;
mInitializations = initializations;
}
/**
* K-Means quantizer.
*
* @param pixels Pixels to quantize.
* @param maxColors Maximum number of clusters to extract.
* @param filters Colors that should be ignored
*/
@Override
public void quantize(int[] pixels, int maxColors, Palette.Filter[] filters) {
// Start by converting all colors to HSL.
// HLS is way more meaningful for clustering than RGB.
final float[] hsl = {0, 0, 0};
final float[][] hslPixels = new float[pixels.length][3];
for (int i = 0; i < pixels.length; i++) {
ColorUtils.colorToHSL(pixels[i], hsl);
// Normalize hue so all values go from 0 to 1.
hslPixels[i][0] = hsl[0] / 360f;
hslPixels[i][1] = hsl[1];
hslPixels[i][2] = hsl[2];
}
final List<KMeans.Mean> optimalMeans = getOptimalKMeans(maxColors, hslPixels);
// Ideally we should run k-means again to merge clusters but it would be too expensive,
// instead we just merge all clusters that are closer than a threshold.
for (int i = 0; i < optimalMeans.size(); i++) {
KMeans.Mean current = optimalMeans.get(i);
float[] currentCentroid = current.getCentroid();
for (int j = i + 1; j < optimalMeans.size(); j++) {
KMeans.Mean compareTo = optimalMeans.get(j);
float[] compareToCentroid = compareTo.getCentroid();
float sqDistance = KMeans.sqDistance(currentCentroid, compareToCentroid);
// Merge them
if (sqDistance < mMinClusterSqDistance) {
optimalMeans.remove(compareTo);
current.getItems().addAll(compareTo.getItems());
for (int k = 0; k < currentCentroid.length; k++) {
currentCentroid[k] += (compareToCentroid[k] - currentCentroid[k]) / 2.0;
}
j--;
}
}
}
// Convert data to final format, de-normalizing the hue.
mQuantizedColors = new ArrayList<>();
for (KMeans.Mean mean : optimalMeans) {
if (mean.getItems().size() == 0) {
continue;
}
float[] centroid = mean.getCentroid();
mQuantizedColors.add(new Palette.Swatch(new float[]{
centroid[0] * 360f,
centroid[1],
centroid[2]
}, mean.getItems().size()));
}
}
private List<KMeans.Mean> getOptimalKMeans(int k, float[][] inputData) {
List<KMeans.Mean> optimal = null;
double optimalScore = -Double.MAX_VALUE;
int runs = mInitializations;
while (runs > 0) {
if (DEBUG) {
Log.d(TAG, "k-means run: " + runs);
}
List<KMeans.Mean> means = mKMeans.predict(k, inputData);
double score = KMeans.score(means);
if (optimal == null || score > optimalScore) {
if (DEBUG) {
Log.d(TAG, "\tnew optimal score: " + score);
}
optimalScore = score;
optimal = means;
}
runs--;
}
return optimal;
}
@Override
public List<Palette.Swatch> getQuantizedColors() {
return mQuantizedColors;
}
}

View File

@@ -0,0 +1,243 @@
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License
*/
package com.android.internal.ml.clustering;
import android.annotation.NonNull;
import android.util.Log;
import com.android.internal.annotations.VisibleForTesting;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
/**
* Simple K-Means implementation
*/
public class KMeans {
private static final boolean DEBUG = false;
private static final String TAG = "KMeans";
private final Random mRandomState;
private final int mMaxIterations;
private float mSqConvergenceEpsilon;
public KMeans() {
this(new Random());
}
public KMeans(Random random) {
this(random, 30 /* maxIterations */, 0.005f /* convergenceEpsilon */);
}
public KMeans(Random random, int maxIterations, float convergenceEpsilon) {
mRandomState = random;
mMaxIterations = maxIterations;
mSqConvergenceEpsilon = convergenceEpsilon * convergenceEpsilon;
}
/**
* Runs k-means on the input data (X) trying to find k means.
*
* K-Means is known for getting stuck into local optima, so you might
* want to run it multiple time and argmax on {@link KMeans#score(List)}
*
* @param k The number of points to return.
* @param inputData Input data.
* @return An array of k Means, each representing a centroid and data points that belong to it.
*/
public List<Mean> predict(final int k, final float[][] inputData) {
checkDataSetSanity(inputData);
int dimension = inputData[0].length;
final ArrayList<Mean> means = new ArrayList<>();
for (int i = 0; i < k; i++) {
Mean m = new Mean(dimension);
for (int j = 0; j < dimension; j++) {
m.mCentroid[j] = mRandomState.nextFloat();
}
means.add(m);
}
// Iterate until we converge or run out of iterations
boolean converged = false;
for (int i = 0; i < mMaxIterations; i++) {
converged = step(means, inputData);
if (converged) {
if (DEBUG) Log.d(TAG, "Converged at iteration: " + i);
break;
}
}
if (!converged && DEBUG) Log.d(TAG, "Did not converge");
return means;
}
/**
* Score calculates the inertia between means.
* This can be considered as an E step of an EM algorithm.
*
* @param means Means to use when calculating score.
* @return The score
*/
public static double score(@NonNull List<Mean> means) {
double score = 0;
final int meansSize = means.size();
for (int i = 0; i < meansSize; i++) {
Mean mean = means.get(i);
for (int j = 0; j < meansSize; j++) {
Mean compareTo = means.get(j);
if (mean == compareTo) {
continue;
}
double distance = Math.sqrt(sqDistance(mean.mCentroid, compareTo.mCentroid));
score += distance;
}
}
return score;
}
@VisibleForTesting
public void checkDataSetSanity(float[][] inputData) {
if (inputData == null) {
throw new IllegalArgumentException("Data set is null.");
} else if (inputData.length == 0) {
throw new IllegalArgumentException("Data set is empty.");
} else if (inputData[0] == null) {
throw new IllegalArgumentException("Bad data set format.");
}
final int dimension = inputData[0].length;
final int length = inputData.length;
for (int i = 1; i < length; i++) {
if (inputData[i] == null || inputData[i].length != dimension) {
throw new IllegalArgumentException("Bad data set format.");
}
}
}
/**
* K-Means iteration.
*
* @param means Current means
* @param inputData Input data
* @return True if data set converged
*/
private boolean step(final ArrayList<Mean> means, final float[][] inputData) {
// Clean up the previous state because we need to compute
// which point belongs to each mean again.
for (int i = means.size() - 1; i >= 0; i--) {
final Mean mean = means.get(i);
mean.mClosestItems.clear();
}
for (int i = inputData.length - 1; i >= 0; i--) {
final float[] current = inputData[i];
final Mean nearest = nearestMean(current, means);
nearest.mClosestItems.add(current);
}
boolean converged = true;
// Move each mean towards the nearest data set points
for (int i = means.size() - 1; i >= 0; i--) {
final Mean mean = means.get(i);
if (mean.mClosestItems.size() == 0) {
continue;
}
// Compute the new mean centroid:
// 1. Sum all all points
// 2. Average them
final float[] oldCentroid = mean.mCentroid;
mean.mCentroid = new float[oldCentroid.length];
for (int j = 0; j < mean.mClosestItems.size(); j++) {
// Update each centroid component
for (int p = 0; p < mean.mCentroid.length; p++) {
mean.mCentroid[p] += mean.mClosestItems.get(j)[p];
}
}
for (int j = 0; j < mean.mCentroid.length; j++) {
mean.mCentroid[j] /= mean.mClosestItems.size();
}
// We converged if the centroid didn't move for any of the means.
if (sqDistance(oldCentroid, mean.mCentroid) > mSqConvergenceEpsilon) {
converged = false;
}
}
return converged;
}
@VisibleForTesting
public static Mean nearestMean(float[] point, List<Mean> means) {
Mean nearest = null;
float nearestDistance = Float.MAX_VALUE;
final int meanCount = means.size();
for (int i = 0; i < meanCount; i++) {
Mean next = means.get(i);
// We don't need the sqrt when comparing distances in euclidean space
// because they exist on both sides of the equation and cancel each other out.
float nextDistance = sqDistance(point, next.mCentroid);
if (nextDistance < nearestDistance) {
nearest = next;
nearestDistance = nextDistance;
}
}
return nearest;
}
@VisibleForTesting
public static float sqDistance(float[] a, float[] b) {
float dist = 0;
final int length = a.length;
for (int i = 0; i < length; i++) {
dist += (a[i] - b[i]) * (a[i] - b[i]);
}
return dist;
}
/**
* Definition of a mean, contains a centroid and points on its cluster.
*/
public static class Mean {
float[] mCentroid;
final ArrayList<float[]> mClosestItems = new ArrayList<>();
public Mean(int dimension) {
mCentroid = new float[dimension];
}
public Mean(float ...centroid) {
mCentroid = centroid;
}
public float[] getCentroid() {
return mCentroid;
}
public List<float[]> getItems() {
return mClosestItems;
}
@Override
public String toString() {
return "Mean(centroid: " + Arrays.toString(mCentroid) + ", size: "
+ mClosestItems.size() + ")";
}
}
}

20
tests/Internal/Android.mk Normal file
View File

@@ -0,0 +1,20 @@
LOCAL_PATH:= $(call my-dir)
include $(CLEAR_VARS)
LOCAL_USE_AAPT2 := true
LOCAL_MODULE_TAGS := tests
LOCAL_PROTOC_OPTIMIZE_TYPE := nano
# Include some source files directly to be able to access package members
LOCAL_SRC_FILES := $(call all-java-files-under, src)
LOCAL_JAVA_LIBRARIES := android.test.runner
LOCAL_STATIC_JAVA_LIBRARIES := junit legacy-android-test android-support-test
LOCAL_CERTIFICATE := platform
LOCAL_PACKAGE_NAME := InternalTests
LOCAL_COMPATIBILITY_SUITE := device-tests
include $(BUILD_PACKAGE)

View File

@@ -0,0 +1,28 @@
<?xml version="1.0" encoding="utf-8"?>
<!--
~ Copyright (C) 2017 The Android Open Source Project
~
~ Licensed under the Apache License, Version 2.0 (the "License");
~ you may not use this file except in compliance with the License.
~ You may obtain a copy of the License at
~
~ http://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
~ See the License for the specific language governing permissions and
~ limitations under the License
-->
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.android.internal.tests">
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
<application>
<uses-library android:name="android.test.runner" />
</application>
<instrumentation android:name="android.support.test.runner.AndroidJUnitRunner"
android:targetPackage="com.android.internal.tests"
android:label="Internal Tests" />
</manifest>

View File

@@ -0,0 +1,29 @@
<?xml version="1.0" encoding="utf-8"?>
<!--
~ Copyright (C) 2017 The Android Open Source Project
~
~ Licensed under the Apache License, Version 2.0 (the "License");
~ you may not use this file except in compliance with the License.
~ You may obtain a copy of the License at
~
~ http://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
~ See the License for the specific language governing permissions and
~ limitations under the License
-->
<configuration description="Runs tests for internal classes/utilities.">
<target_preparer class="com.android.tradefed.targetprep.TestAppInstallSetup">
<option name="test-file-name" value="InternalTests.apk" />
</target_preparer>
<option name="test-suite-tag" value="apct" />
<option name="test-suite-tag" value="framework-base-presubmit" />
<option name="test-tag" value="InternalTests" />
<test class="com.android.tradefed.testtype.AndroidJUnitTest" >
<option name="package" value="com.android.internal.tests" />
<option name="runner" value="android.support.test.runner.AndroidJUnitRunner" />
</test>
</configuration>

View File

@@ -0,0 +1,155 @@
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.android.internal.ml.clustering;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import android.annotation.SuppressLint;
import android.support.test.filters.SmallTest;
import android.support.test.runner.AndroidJUnit4;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
@SmallTest
@RunWith(AndroidJUnit4.class)
public class KMeansTest {
// Error tolerance (epsilon)
private static final double EPS = 0.01;
private KMeans mKMeans;
@Before
public void setUp() {
// Setup with a random seed to have predictable results
mKMeans = new KMeans(new Random(0), 30, 0);
}
@Test
public void getCheckDataSanityTest() {
try {
mKMeans.checkDataSetSanity(new float[][] {
{0, 1, 2},
{1, 2, 3}
});
} catch (IllegalArgumentException e) {
Assert.fail("Valid data didn't pass sanity check");
}
try {
mKMeans.checkDataSetSanity(new float[][] {
null,
{1, 2, 3}
});
Assert.fail("Data has null items and passed");
} catch (IllegalArgumentException e) {}
try {
mKMeans.checkDataSetSanity(new float[][] {
{0, 1, 2, 4},
{1, 2, 3}
});
Assert.fail("Data has invalid shape and passed");
} catch (IllegalArgumentException e) {}
try {
mKMeans.checkDataSetSanity(null);
Assert.fail("Null data should throw exception");
} catch (IllegalArgumentException e) {}
}
@Test
public void sqDistanceTest() {
float a[] = {4, 10};
float b[] = {5, 2};
float sqDist = (float) (Math.pow(a[0] - b[0], 2) + Math.pow(a[1] - b[1], 2));
assertEquals("Squared distance not valid", mKMeans.sqDistance(a, b), sqDist, EPS);
}
@Test
public void nearestMeanTest() {
KMeans.Mean meanA = new KMeans.Mean(0, 1);
KMeans.Mean meanB = new KMeans.Mean(1, 1);
List<KMeans.Mean> means = Arrays.asList(meanA, meanB);
KMeans.Mean nearest = mKMeans.nearestMean(new float[] {1, 1}, means);
assertEquals("Unexpected nearest mean for point {1, 1}", nearest, meanB);
}
@SuppressLint("DefaultLocale")
@Test
public void scoreTest() {
List<KMeans.Mean> closeMeans = Arrays.asList(new KMeans.Mean(0, 0.1f, 0.1f),
new KMeans.Mean(0, 0.1f, 0.15f),
new KMeans.Mean(0.1f, 0.2f, 0.1f));
List<KMeans.Mean> farMeans = Arrays.asList(new KMeans.Mean(0, 0, 0),
new KMeans.Mean(0, 0.5f, 0.5f),
new KMeans.Mean(1, 0.9f, 0.9f));
double closeScore = KMeans.score(closeMeans);
double farScore = KMeans.score(farMeans);
assertTrue(String.format("Score of well distributed means should be greater than "
+ "close means but got: %f, %f", farScore, closeScore), farScore > closeScore);
}
@Test
public void predictTest() {
float[] expectedCentroid1 = {1, 1, 1};
float[] expectedCentroid2 = {0, 0, 0};
float[][] X = new float[][] {
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{0, 0, 0},
{0, 0, 0},
{0, 0, 0},
};
final int numClusters = 2;
// Here we assume that we won't get stuck into a local optima.
// It's fine because we're seeding a random, we won't ever have
// unstable results but in real life we need multiple initialization
// and score comparison
List<KMeans.Mean> means = mKMeans.predict(numClusters, X);
assertEquals("Expected number of clusters is invalid", numClusters, means.size());
boolean exists1 = false, exists2 = false;
for (KMeans.Mean mean : means) {
if (Arrays.equals(mean.getCentroid(), expectedCentroid1)) {
exists1 = true;
} else if (Arrays.equals(mean.getCentroid(), expectedCentroid2)) {
exists2 = true;
} else {
throw new AssertionError("Unexpected mean: " + mean);
}
}
assertTrue("Expected means were not predicted, got: " + means,
exists1 && exists2);
}
}