feat(experiment): mlkit smart reply

Signed-off-by: Infi <infi@infi.sh>
This commit is contained in:
Infi 2024-10-29 02:26:44 +01:00
parent 4c2bf8703e
commit c0af196da7
12 changed files with 303 additions and 1 deletions

View File

@ -292,6 +292,9 @@ dependencies {
// Shimmer - loading animations
implementation "com.valentinilk.shimmer:compose-shimmer:1.3.1"
// MLKit - Machine Learning
implementation 'com.google.android.gms:play-services-mlkit-smart-reply:16.0.0-beta1'
}
sqldelight {

View File

@ -50,6 +50,10 @@
android:name="io.sentry.auto-init"
android:value="false" />
<meta-data
android:name="com.google.mlkit.vision.DEPENDENCIES"
android:value="smart_reply" />
<meta-data
android:name="com.google.firebase.messaging.default_notification_icon"
android:resource="@drawable/ic_notification_monochrome" />

View File

@ -27,6 +27,7 @@ class ExperimentInstance(default: Boolean) {
*/
object Experiments {
val useKotlinBasedMarkdownRenderer = ExperimentInstance(false)
val useMlKitSmartReplyInApp = ExperimentInstance(false)
suspend fun hydrateWithKv() {
val kvStorage = KVStorage(RevoltApplication.instance)
@ -40,5 +41,8 @@ object Experiments {
useKotlinBasedMarkdownRenderer.setEnabled(
kvStorage.getBoolean("exp/useKotlinBasedMarkdownRenderer") ?: false
)
useMlKitSmartReplyInApp.setEnabled(
kvStorage.getBoolean("exp/useMlKitSmartReplyInApp") ?: false
)
}
}

View File

@ -0,0 +1,66 @@
package chat.revolt.api.settings.experiments
import android.util.Log
import chat.revolt.api.RevoltAPI
import chat.revolt.api.internals.ULID
import chat.revolt.api.schemas.Message
import chat.revolt.api.schemas.RsResult
import com.google.mlkit.nl.smartreply.SmartReply
import com.google.mlkit.nl.smartreply.TextMessage
object SmartReplyImpl {
val client = SmartReply.getClient()
fun forMessages(
messages: List<Message>,
onResult: (RsResult<List<String>, Exception>) -> Unit
) {
if (messages.size > 10) {
onResult(RsResult.err(IllegalArgumentException("Too many messages")))
return
}
Log.d("SmartReplyImpl", "Creating conversation for ${messages.size} messages")
val conversation = messages.map {
if (it.author == RevoltAPI.selfId) {
if (it.id == null) {
Log.w("SmartReplyImpl", "Message ID is null in local user message, skipping")
return@map null
}
if (it.content?.isEmpty() == true || it.content == null) {
Log.w(
"SmartReplyImpl",
"Message content is null or empty in local user message, skipping"
)
return@map null
}
TextMessage.createForLocalUser(it.content, ULID.asTimestamp(it.id))
} else {
if (it.id == null || it.author == null) {
Log.w(
"SmartReplyImpl",
"Message ID or author is null in remote user message, skipping"
)
return@map null
}
if (it.content?.isEmpty() == true || it.content == null) {
Log.w(
"SmartReplyImpl",
"Message content is null or empty in remote user message, skipping"
)
return@map null
}
TextMessage.createForRemoteUser(it.content, ULID.asTimestamp(it.id), it.author)
}
}.filterNotNull().reversed()
Log.d("SmartReplyImpl", "Suggesting replies for ${conversation.size} messages")
client.suggestReplies(conversation).addOnSuccessListener {
onResult(RsResult.ok(it.suggestions.map { s -> s.text }))
}.addOnFailureListener {
onResult(RsResult.err(it))
}
}
}

View File

@ -4,6 +4,7 @@ import kotlinx.coroutines.flow.MutableSharedFlow
sealed class UiCallback {
data class ReplyToMessage(val messageId: String) : UiCallback()
data class ReplyToMessageWithContent(val messageId: String, val content: String) : UiCallback()
data class EditMessage(val messageId: String) : UiCallback()
}
@ -14,6 +15,10 @@ object UiCallbacks {
uiCallbackFlow.emit(UiCallback.ReplyToMessage(messageId))
}
suspend fun replyToMessageWithContent(messageId: String, content: String) {
uiCallbackFlow.emit(UiCallback.ReplyToMessageWithContent(messageId, content))
}
suspend fun editMessage(messageId: String) {
uiCallbackFlow.emit(UiCallback.EditMessage(messageId))
}

View File

@ -19,6 +19,7 @@ fun SheetButton(
onClick: () -> Unit,
modifier: Modifier = Modifier,
supportingContent: @Composable (() -> Unit)? = null,
trailingContent: @Composable (() -> Unit)? = null,
dangerous: Boolean = false
) {
Box(
@ -62,6 +63,19 @@ fun SheetButton(
this()
}
}
},
trailingContent = {
trailingContent?.run {
CompositionLocalProvider(
value = if (dangerous) {
LocalContentColor provides MaterialTheme.colorScheme.error
} else {
LocalContentColor provides MaterialTheme.colorScheme.onSurface
}
) {
this()
}
}
}
)
}

View File

@ -134,6 +134,7 @@ import kotlinx.coroutines.launch
import kotlinx.datetime.Instant
import java.io.File
import kotlin.math.max
import kotlin.math.min
sealed class ChannelScreenItem {
data class RegularMessage(val message: Message) : ChannelScreenItem()
@ -372,6 +373,22 @@ fun ChannelScreen(
scope.launch {
ActionChannel.send(Action.ReportMessage(messageContextSheetTarget))
}
},
lastTenMessages = {
// First we find the message in viewModel.items
val index = viewModel.items.filterIsInstance<ChannelScreenItem.RegularMessage>()
.indexOfFirst { it.message.id == messageContextSheetTarget }
Log.d("ChannelScreen", "We have index $index")
Log.d("ChannelScreen", "Items.len ${viewModel.items.size}")
// Then we take the last 10 messages before it. We take care to not go out of bounds.
val messages =
viewModel.items.filterIsInstance<ChannelScreenItem.RegularMessage>()
.subList(index, min(index + 5, (viewModel.items.size - 1)))
.map { it.message }
messages
}
)
}

View File

@ -729,6 +729,21 @@ class ChannelScreenViewModel @Inject constructor(
this@ChannelScreenViewModel.draftAttachments.clear()
draftReplyTo.clear()
}
is UiCallback.ReplyToMessageWithContent -> {
val message = items.find { m ->
m is ChannelScreenItem.RegularMessage && m.message.id == it.messageId
} as? ChannelScreenItem.RegularMessage ?: return@onEach
val shouldMention = kvStorage.getBoolean("mentionOnReply") ?: false
draftReplyTo.add(
SendMessageReply(
message.message.id ?: return@onEach,
shouldMention
)
)
putDraftContent(it.content)
}
}
}.catch {
Log.e("ChannelScreen", "Failed to receive UI callback", it)

View File

@ -30,6 +30,7 @@ class ExperimentsSettingsScreenViewModel : ViewModel() {
fun init() {
viewModelScope.launch {
useKotlinMdRendererChecked.value = Experiments.useKotlinBasedMarkdownRenderer.isEnabled
useMlKitSmartReplyInAppChecked.value = Experiments.useMlKitSmartReplyInApp.isEnabled
}
}
@ -42,6 +43,7 @@ class ExperimentsSettingsScreenViewModel : ViewModel() {
}
val useKotlinMdRendererChecked = mutableStateOf(false)
val useMlKitSmartReplyInAppChecked = mutableStateOf(false)
fun setUseKotlinMdRendererChecked(value: Boolean) {
viewModelScope.launch {
@ -50,6 +52,14 @@ class ExperimentsSettingsScreenViewModel : ViewModel() {
useKotlinMdRendererChecked.value = value
}
}
fun setUseMlKitSmartReplyInAppChecked(value: Boolean) {
viewModelScope.launch {
kv.set("exp/useMlKitSmartReplyInApp", value)
Experiments.useMlKitSmartReplyInApp.setEnabled(value)
useMlKitSmartReplyInAppChecked.value = value
}
}
}
@Composable
@ -83,6 +93,22 @@ fun ExperimentsSettingsScreen(
modifier = Modifier.clickable { viewModel.setUseKotlinMdRendererChecked(!viewModel.useKotlinMdRendererChecked.value) }
)
ListItem(
headlineContent = {
Text("Smart Reply Suggestions (In-App)")
},
supportingContent = {
Text("Use a machine learning model to suggest replies to messages from the message sheet.")
},
trailingContent = {
Switch(
checked = viewModel.useMlKitSmartReplyInAppChecked.value,
onCheckedChange = viewModel::setUseMlKitSmartReplyInAppChecked
)
},
modifier = Modifier.clickable { viewModel.setUseMlKitSmartReplyInAppChecked(!viewModel.useMlKitSmartReplyInAppChecked.value) }
)
Subcategory(
title = {
Text("Disable experiments")

View File

@ -0,0 +1,62 @@
package chat.revolt.sheets
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.verticalScroll
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.dp
import chat.revolt.R
import chat.revolt.components.generic.SheetButton
@Composable
fun MessageContentMLKitReplySelectSheet(
options: List<String>,
onOptionSelected: (String) -> Unit
) {
Column(
modifier = Modifier
.verticalScroll(rememberScrollState())
) {
Column(
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.spacedBy(16.dp),
modifier = Modifier.fillMaxWidth()
) {
Icon(
painter = painterResource(R.drawable.ic_creation_24dp),
contentDescription = null,
tint = Color(0xFF977EFF)
)
Text(
"Select a reply",
style = MaterialTheme.typography.headlineLarge,
textAlign = TextAlign.Center
)
}
Spacer(modifier = Modifier.height(16.dp))
options.forEach { option ->
SheetButton(
headlineContent = { Text(option) },
leadingContent = {
Icon(
painter = painterResource(R.drawable.ic_reply_24dp),
contentDescription = null
)
},
onClick = { onOptionSelected(option) }
)
}
}
}

View File

@ -18,11 +18,14 @@ import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.HorizontalDivider
import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.LocalContentColor
import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
@ -30,6 +33,7 @@ import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.platform.LocalClipboardManager
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.painterResource
@ -44,10 +48,16 @@ import chat.revolt.api.internals.Roles
import chat.revolt.api.internals.has
import chat.revolt.api.routes.channel.deleteMessage
import chat.revolt.api.routes.channel.react
import chat.revolt.api.schemas.Message
import chat.revolt.api.settings.Experiments
import chat.revolt.api.settings.experiments.SmartReplyImpl
import chat.revolt.callbacks.UiCallbacks
import chat.revolt.components.chat.Message
import chat.revolt.components.generic.SheetButton
import chat.revolt.internals.Platform
import com.valentinilk.shimmer.ShimmerBounds
import com.valentinilk.shimmer.rememberShimmer
import com.valentinilk.shimmer.shimmer
import kotlinx.coroutines.launch
@OptIn(ExperimentalMaterial3Api::class)
@ -55,7 +65,8 @@ import kotlinx.coroutines.launch
fun MessageContextSheet(
messageId: String,
onHideSheet: suspend () -> Unit,
onReportMessage: () -> Unit
onReportMessage: () -> Unit,
lastTenMessages: (() -> List<Message>)? = null
) {
val message = RevoltAPI.messageCache[messageId]
if (message == null) {
@ -73,6 +84,48 @@ fun MessageContextSheet(
val clipboardManager = LocalClipboardManager.current
val coroutineScope = rememberCoroutineScope()
var mlKitSmartReplies by remember { mutableStateOf<List<String>?>(null) }
var mlKitSmartRepliesLoading by remember { mutableStateOf(false) }
LaunchedEffect(Unit) {
if (Experiments.useMlKitSmartReplyInApp.isEnabled) {
mlKitSmartReplies = null
mlKitSmartRepliesLoading = true
lastTenMessages?.let { fn ->
SmartReplyImpl.forMessages(
fn.invoke()
) {
mlKitSmartReplies = if (it.ok) {
it.unwrap()
} else {
null
}
mlKitSmartRepliesLoading = false
}
}
}
}
var showMlKitSmartReplySheet by remember { mutableStateOf(false) }
if (showMlKitSmartReplySheet) {
val mlKitSmartReplySheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
ModalBottomSheet(
sheetState = mlKitSmartReplySheetState,
onDismissRequest = {
showMlKitSmartReplySheet = false
}
) {
MessageContentMLKitReplySelectSheet(
options = mlKitSmartReplies ?: emptyList(),
onOptionSelected = {
coroutineScope.launch {
UiCallbacks.replyToMessageWithContent(messageId, it)
onHideSheet()
}
},
)
}
}
var showShareSheet by remember { mutableStateOf(false) }
var showReactSheet by remember { mutableStateOf(false) }
var showDeleteMessageConfirmation by remember { mutableStateOf(false) }
@ -335,6 +388,30 @@ fun MessageContextSheet(
text = stringResource(id = R.string.message_context_sheet_actions_reply),
)
},
trailingContent = {
if (Experiments.useMlKitSmartReplyInApp.isEnabled && (mlKitSmartRepliesLoading || (mlKitSmartReplies?.isNotEmpty() == true))) {
IconButton(onClick = {
coroutineScope.launch {
showMlKitSmartReplySheet = true
}
}) {
Icon(
painter = painterResource(id = R.drawable.ic_creation_24dp),
contentDescription = null,
tint = if (!mlKitSmartRepliesLoading && mlKitSmartReplies != null) {
Color(0xFF977EFF)
} else {
LocalContentColor.current
},
modifier = if (mlKitSmartRepliesLoading) {
Modifier.shimmer(rememberShimmer(ShimmerBounds.View))
} else {
Modifier
}
)
}
}
},
onClick = {
coroutineScope.launch {
UiCallbacks.replyToMessage(messageId)

View File

@ -0,0 +1,9 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:height="24dp"
android:width="24dp"
android:viewportWidth="24"
android:viewportHeight="24">
<path
android:fillColor="#ffffff"
android:pathData="M19,1L17.74,3.75L15,5L17.74,6.26L19,9L20.25,6.26L23,5L20.25,3.75M9,4L6.5,9.5L1,12L6.5,14.5L9,20L11.5,14.5L17,12L11.5,9.5M19,15L17.74,17.74L15,19L17.74,20.25L19,23L20.25,20.25L23,19L20.25,17.74" />
</vector>