diff --git a/core/java/android/view/textclassifier/TextClassifierImpl.java b/core/java/android/view/textclassifier/TextClassifierImpl.java index 323bf597ab550..ed6ec54986ad5 100644 --- a/core/java/android/view/textclassifier/TextClassifierImpl.java +++ b/core/java/android/view/textclassifier/TextClassifierImpl.java @@ -451,10 +451,6 @@ public final class TextClassifierImpl implements TextClassifier { Collection expectedTypes = resolveActionTypesFromRequest(request); List conversationActions = new ArrayList<>(); for (ActionsSuggestionsModel.ActionSuggestion nativeSuggestion : nativeSuggestions) { - if (request.getMaxSuggestions() >= 0 - && conversationActions.size() == request.getMaxSuggestions()) { - break; - } String actionType = nativeSuggestion.getActionType(); if (!expectedTypes.contains(actionType)) { continue; @@ -484,6 +480,10 @@ public final class TextClassifierImpl implements TextClassifier { } conversationActions = ActionsSuggestionsHelper.removeActionsWithDuplicates(conversationActions); + if (request.getMaxSuggestions() >= 0 + && conversationActions.size() > request.getMaxSuggestions()) { + conversationActions = conversationActions.subList(0, request.getMaxSuggestions()); + } String resultId = ActionsSuggestionsHelper.createResultId( mContext, request.getConversation(), diff --git a/core/tests/coretests/src/android/view/textclassifier/TextClassifierTest.java b/core/tests/coretests/src/android/view/textclassifier/TextClassifierTest.java index 433991e86212d..79512a744f52a 100644 --- a/core/tests/coretests/src/android/view/textclassifier/TextClassifierTest.java +++ b/core/tests/coretests/src/android/view/textclassifier/TextClassifierTest.java @@ -380,7 +380,7 @@ public class TextClassifierTest { } @Test - public void testSuggestConversationActions_textReplyOnly_maxThree() { + public void testSuggestConversationActions_textReplyOnly_maxOne() { if (isTextClassifierDisabled()) return; ConversationActions.Message message = new ConversationActions.Message.Builder( @@ -399,12 +399,11 @@ public class TextClassifierTest { .build(); ConversationActions conversationActions = mClassifier.suggestConversationActions(request); - assertTrue(conversationActions.getConversationActions().size() > 0); - for (ConversationAction conversationAction : - conversationActions.getConversationActions()) { - assertThat(conversationAction, - isConversationAction(ConversationAction.TYPE_TEXT_REPLY)); - } + Truth.assertThat(conversationActions.getConversationActions()).hasSize(1); + ConversationAction conversationAction = conversationActions.getConversationActions().get(0); + Truth.assertThat(conversationAction.getType()).isEqualTo( + ConversationAction.TYPE_TEXT_REPLY); + Truth.assertThat(conversationAction.getTextReply()).isNotNull(); } @Test @@ -493,6 +492,24 @@ public class TextClassifierTest { ExtrasUtils.getSerializedEntityData(conversationAction.getExtras())).isNotEmpty(); } + @Test + public void testSuggetsConversationActions_deduplicate() { + if (isTextClassifierDisabled()) return; + ConversationActions.Message message = + new ConversationActions.Message.Builder( + ConversationActions.Message.PERSON_USER_OTHERS) + .setText("a@android.com b@android.com") + .build(); + ConversationActions.Request request = + new ConversationActions.Request.Builder(Collections.singletonList(message)) + .setMaxSuggestions(3) + .build(); + + ConversationActions conversationActions = mClassifier.suggestConversationActions(request); + + Truth.assertThat(conversationActions.getConversationActions()).isEmpty(); + } + private boolean isTextClassifierDisabled() { return mClassifier == null || mClassifier == TextClassifier.NO_OP; }