Merge "Rewrite Icons from the TCS." into rvc-dev am: 1cf3ce8fde

Change-Id: I303a573a48661c9d222def2b46501ad1d81ab23c
This commit is contained in:
Abodunrinwa Toki
2020-05-01 11:02:12 +00:00
committed by Automerger Merge Worker
8 changed files with 298 additions and 29 deletions

View File

@@ -161,4 +161,4 @@ public final class RemoteAction implements Parcelable {
return new RemoteAction[size];
}
};
}
}

View File

@@ -424,6 +424,11 @@ public abstract class TextClassifierService extends Service {
return bundle.getParcelable(KEY_RESULT);
}
/** @hide **/
public static <T extends Parcelable> void putResponse(Bundle bundle, T response) {
bundle.putParcelable(KEY_RESULT, response);
}
/**
* Callbacks for TextClassifierService results.
*

View File

@@ -206,6 +206,15 @@ public final class ConversationAction implements Parcelable {
return mExtras;
}
/** @hide */
public Builder toBuilder() {
return new Builder(mType)
.setTextReply(mTextReply)
.setAction(mAction)
.setConfidenceScore(mScore)
.setExtras(mExtras);
}
/** Builder class to construct {@link ConversationAction}. */
public static final class Builder {
@Nullable

View File

@@ -88,6 +88,10 @@ final class EntityConfidence implements Parcelable {
return 0;
}
public Map<String, Float> toMap() {
return new ArrayMap(mEntityConfidence);
}
@Override
public String toString() {
return mEntityConfidence.toString();

View File

@@ -48,6 +48,7 @@ import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
@@ -270,6 +271,20 @@ public final class TextClassification implements Parcelable {
return mExtras;
}
/** @hide */
public Builder toBuilder() {
return new Builder()
.setId(mId)
.setText(mText)
.addActions(mActions)
.setEntityConfidence(mEntityConfidence)
.setIcon(mLegacyIcon)
.setLabel(mLegacyLabel)
.setIntent(mLegacyIntent)
.setOnClickListener(mLegacyOnClickListener)
.setExtras(mExtras);
}
@Override
public String toString() {
return String.format(Locale.US,
@@ -323,7 +338,7 @@ public final class TextClassification implements Parcelable {
*/
public static final class Builder {
@NonNull private List<RemoteAction> mActions = new ArrayList<>();
@NonNull private final List<RemoteAction> mActions = new ArrayList<>();
@NonNull private final Map<String, Float> mTypeScoreMap = new ArrayMap<>();
@Nullable private String mText;
@Nullable private Drawable mLegacyIcon;
@@ -332,8 +347,6 @@ public final class TextClassification implements Parcelable {
@Nullable private OnClickListener mLegacyOnClickListener;
@Nullable private String mId;
@Nullable private Bundle mExtras;
@NonNull private final ArrayList<Intent> mActionIntents = new ArrayList<>();
@Nullable private Bundle mForeignLanguageExtra;
/**
* Sets the classified text.
@@ -361,6 +374,18 @@ public final class TextClassification implements Parcelable {
return this;
}
Builder setEntityConfidence(EntityConfidence scores) {
mTypeScoreMap.clear();
mTypeScoreMap.putAll(scores.toMap());
return this;
}
/** @hide */
public Builder clearEntityTypes() {
mTypeScoreMap.clear();
return this;
}
/**
* Adds an action that may be performed on the classified text. Actions should be added in
* order of likelihood that the user will use them, with the most likely action being added
@@ -368,19 +393,21 @@ public final class TextClassification implements Parcelable {
*/
@NonNull
public Builder addAction(@NonNull RemoteAction action) {
return addAction(action, null);
}
/**
* @param intent the intent in the remote action.
* @see #addAction(RemoteAction)
* @hide
*/
@VisibleForTesting(visibility = VisibleForTesting.Visibility.PACKAGE)
public Builder addAction(RemoteAction action, @Nullable Intent intent) {
Preconditions.checkArgument(action != null);
mActions.add(action);
mActionIntents.add(intent);
return this;
}
/** @hide */
public Builder addActions(Collection<RemoteAction> actions) {
Objects.requireNonNull(actions);
mActions.addAll(actions);
return this;
}
/** @hide */
public Builder clearActions() {
mActions.clear();
return this;
}
@@ -465,16 +492,6 @@ public final class TextClassification implements Parcelable {
return this;
}
/**
* @see #setExtras(Bundle)
* @hide
*/
@VisibleForTesting(visibility = VisibleForTesting.Visibility.PACKAGE)
public Builder setForeignLanguageExtra(@Nullable Bundle extra) {
mForeignLanguageExtra = extra;
return this;
}
/**
* Builds and returns a {@link TextClassification} object.
*/

View File

@@ -0,0 +1,73 @@
/*
* Copyright (C) 2020 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package android.view.textclassifier;
import static com.google.common.truth.Truth.assertThat;
import android.app.PendingIntent;
import android.app.RemoteAction;
import android.content.Context;
import android.content.Intent;
import android.graphics.drawable.Icon;
import android.os.Bundle;
import androidx.test.InstrumentationRegistry;
import androidx.test.filters.SmallTest;
import androidx.test.runner.AndroidJUnit4;
import org.junit.Test;
import org.junit.runner.RunWith;
@SmallTest
@RunWith(AndroidJUnit4.class)
public final class ConversationActionTest {
@Test
public void toBuilder() {
final Context context = InstrumentationRegistry.getTargetContext();
final PendingIntent intent = PendingIntent.getActivity(context, 0, new Intent(), 0);
final Icon icon = Icon.createWithData(new byte[]{0}, 0, 1);
final Bundle extras = new Bundle();
extras.putInt("key", 5);
final ConversationAction convAction =
new ConversationAction.Builder(ConversationAction.TYPE_CALL_PHONE)
.setAction(new RemoteAction(icon, "title", "descr", intent))
.setConfidenceScore(0.5f)
.setExtras(extras)
.build();
final ConversationAction fromBuilder = convAction.toBuilder().build();
assertThat(fromBuilder.getType()).isEqualTo(convAction.getType());
assertThat(fromBuilder.getAction()).isEqualTo(convAction.getAction());
assertThat(fromBuilder.getConfidenceScore()).isEqualTo(convAction.getConfidenceScore());
assertThat(fromBuilder.getExtras()).isEqualTo(convAction.getExtras());
assertThat(fromBuilder.getTextReply()).isEqualTo(convAction.getTextReply());
}
@Test
public void toBuilder_textReply() {
final ConversationAction convAction =
new ConversationAction.Builder(ConversationAction.TYPE_TEXT_REPLY)
.setTextReply(":P")
.build();
final ConversationAction fromBuilder = convAction.toBuilder().build();
assertThat(fromBuilder.getTextReply()).isEqualTo(convAction.getTextReply());
}
}

View File

@@ -57,6 +57,7 @@ public class TextClassificationTest {
static {
BUNDLE.putString(BUNDLE_KEY, BUNDLE_VALUE);
}
private static final float EPSILON = 1e-7f;
public Icon generateTestIcon(int width, int height, int colorValue) {
final int numPixels = width * height;
@@ -128,8 +129,8 @@ public class TextClassificationTest {
assertEquals(2, result.getEntityCount());
assertEquals(TextClassifier.TYPE_PHONE, result.getEntity(0));
assertEquals(TextClassifier.TYPE_ADDRESS, result.getEntity(1));
assertEquals(0.7f, result.getConfidenceScore(TextClassifier.TYPE_PHONE), 1e-7f);
assertEquals(0.3f, result.getConfidenceScore(TextClassifier.TYPE_ADDRESS), 1e-7f);
assertEquals(0.7f, result.getConfidenceScore(TextClassifier.TYPE_PHONE), EPSILON);
assertEquals(0.3f, result.getConfidenceScore(TextClassifier.TYPE_ADDRESS), EPSILON);
// Extras
assertEquals(BUNDLE_VALUE, result.getExtras().getString(BUNDLE_KEY));
@@ -226,4 +227,45 @@ public class TextClassificationTest {
assertEquals(1, resultSystemTcMetadata.getUserId());
assertFalse(resultSystemTcMetadata.useDefaultTextClassifier());
}
@Test
public void testToBuilder() {
final Context context = InstrumentationRegistry.getInstrumentation().getContext();
final Icon icon1 = generateTestIcon(5, 5, Color.RED);
final Icon icon2 = generateTestIcon(2, 10, Color.BLUE);
final TextClassification classification = new TextClassification.Builder()
.setIcon(icon1.loadDrawable(context))
.setLabel("label")
.setIntent(new Intent("action"))
.setOnClickListener(view -> { })
.addAction(new RemoteAction(icon1, "title1", "desc1",
PendingIntent.getActivity(context, 0, new Intent("action1"), 0)))
.addAction(new RemoteAction(icon1, "title2", "desc2",
PendingIntent.getActivity(context, 0, new Intent("action2"), 0)))
.setEntityType(TextClassifier.TYPE_EMAIL, 0.5f)
.setEntityType(TextClassifier.TYPE_PHONE, 0.4f)
.build();
final TextClassification fromBuilder = classification.toBuilder().build();
assertEquals(classification.getId(), fromBuilder.getId());
assertEquals(classification.getText(), fromBuilder.getText());
assertEquals(classification.getIcon(), fromBuilder.getIcon());
assertEquals(classification.getLabel(), fromBuilder.getLabel());
assertEquals(classification.getIntent(), fromBuilder.getIntent());
assertEquals(classification.getOnClickListener(), fromBuilder.getOnClickListener());
assertEquals(classification.getExtras(), fromBuilder.getExtras());
assertEquals(classification.getActions(), fromBuilder.getActions());
assertEquals(classification.getEntityCount(), fromBuilder.getEntityCount());
assertEquals(classification.getEntity(0), fromBuilder.getEntity(0));
assertEquals(classification.getEntity(1), fromBuilder.getEntity(1));
assertEquals(
classification.getConfidenceScore(TextClassifier.TYPE_EMAIL),
fromBuilder.getConfidenceScore(TextClassifier.TYPE_EMAIL),
EPSILON);
assertEquals(
classification.getConfidenceScore(TextClassifier.TYPE_PHONE),
fromBuilder.getConfidenceScore(TextClassifier.TYPE_PHONE),
EPSILON);
}
}

View File

@@ -19,14 +19,18 @@ package com.android.server.textclassifier;
import android.annotation.NonNull;
import android.annotation.Nullable;
import android.annotation.UserIdInt;
import android.app.RemoteAction;
import android.content.ComponentName;
import android.content.Context;
import android.content.Intent;
import android.content.ServiceConnection;
import android.content.pm.PackageManager;
import android.graphics.drawable.Icon;
import android.net.Uri;
import android.os.Binder;
import android.os.Bundle;
import android.os.IBinder;
import android.os.Parcelable;
import android.os.Process;
import android.os.RemoteException;
import android.os.UserHandle;
@@ -39,6 +43,7 @@ import android.text.TextUtils;
import android.util.ArrayMap;
import android.util.Slog;
import android.util.SparseArray;
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.ConversationActions;
import android.view.textclassifier.SelectionEvent;
import android.view.textclassifier.SystemTextClassifierMetadata;
@@ -69,6 +74,7 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Queue;
import java.util.stream.Collectors;
/**
* A manager for TextClassifier services.
@@ -203,7 +209,7 @@ public final class TextClassificationManagerService extends ITextClassifierServi
request.getSystemTextClassifierMetadata(),
/* verifyCallingPackage= */ true,
/* attemptToBind= */ true,
service -> service.onClassifyText(sessionId, request, callback),
service -> service.onClassifyText(sessionId, request, wrap(callback)),
"onClassifyText",
callback);
}
@@ -289,7 +295,8 @@ public final class TextClassificationManagerService extends ITextClassifierServi
request.getSystemTextClassifierMetadata(),
/* verifyCallingPackage= */ true,
/* attemptToBind= */ true,
service -> service.onSuggestConversationActions(sessionId, request, callback),
service -> service.onSuggestConversationActions(
sessionId, request, wrap(callback)),
"onSuggestConversationActions",
callback);
}
@@ -464,6 +471,10 @@ public final class TextClassificationManagerService extends ITextClassifierServi
}
}
private static ITextClassifierCallback wrap(ITextClassifierCallback orig) {
return new CallbackWrapper(orig);
}
private void onTextClassifierServicePackageOverrideChanged(String overriddenPackage) {
synchronized (mLock) {
final int size = mUserStates.size();
@@ -1004,4 +1015,112 @@ public final class TextClassificationManagerService extends ITextClassifierServi
onTextClassifierServicePackageOverrideChanged(currentServicePackageOverride);
}
}
/**
* Wraps an ITextClassifierCallback and modifies the response to it where necessary.
*/
private static final class CallbackWrapper extends ITextClassifierCallback.Stub {
private final ITextClassifierCallback mWrapped;
CallbackWrapper(ITextClassifierCallback wrapped) {
mWrapped = Objects.requireNonNull(wrapped);
}
@Override
public void onSuccess(Bundle result) {
final Parcelable parcelled = TextClassifierService.getResponse(result);
if (parcelled instanceof TextClassification) {
rewriteTextClassificationIcons(result);
} else if (parcelled instanceof ConversationActions) {
rewriteConversationActionsIcons(result);
} else {
// do nothing.
}
try {
mWrapped.onSuccess(result);
} catch (RemoteException e) {
Slog.e(LOG_TAG, "Callback error", e);
}
}
private static void rewriteTextClassificationIcons(Bundle result) {
final TextClassification classification = TextClassifierService.getResponse(result);
boolean rewrite = false;
for (RemoteAction action : classification.getActions()) {
rewrite |= shouldRewriteIcon(action);
}
if (rewrite) {
TextClassifierService.putResponse(
result,
classification.toBuilder()
.clearActions()
.addActions(classification.getActions()
.stream()
.map(action -> validAction(action))
.collect(Collectors.toList()))
.build());
}
}
private static void rewriteConversationActionsIcons(Bundle result) {
final ConversationActions convActions = TextClassifierService.getResponse(result);
boolean rewrite = false;
for (ConversationAction convAction : convActions.getConversationActions()) {
rewrite |= shouldRewriteIcon(convAction.getAction());
}
if (rewrite) {
TextClassifierService.putResponse(
result,
new ConversationActions(
convActions.getConversationActions()
.stream()
.map(convAction -> convAction.toBuilder()
.setAction(validAction(convAction.getAction()))
.build())
.collect(Collectors.toList()),
convActions.getId()));
}
}
@Nullable
private static RemoteAction validAction(@Nullable RemoteAction action) {
if (!shouldRewriteIcon(action)) {
return action;
}
final RemoteAction newAction = new RemoteAction(
changeIcon(action.getIcon()),
action.getTitle(),
action.getContentDescription(),
action.getActionIntent());
newAction.setEnabled(action.isEnabled());
newAction.setShouldShowIcon(action.shouldShowIcon());
return newAction;
}
private static boolean shouldRewriteIcon(@Nullable RemoteAction action) {
// Check whether to rewrite the icon.
// Rewrite icons to ensure that the icons do not:
// 1. Leak package names
// 2. are renderable in the client process.
return action != null && action.getIcon().getType() == Icon.TYPE_RESOURCE;
}
/** Changes icon of type=RESOURCES to icon of type=URI. */
private static Icon changeIcon(Icon icon) {
final Uri uri = IconsUriHelper.getInstance()
.getContentUri(icon.getResPackage(), icon.getResId());
return Icon.createWithContentUri(uri);
}
@Override
public void onFailure() {
try {
mWrapped.onFailure();
} catch (RemoteException e) {
Slog.e(LOG_TAG, "Callback error", e);
}
}
}
}