feat(experiment): mlkit smart reply

Signed-off-by: Infi <infi@infi.sh>
This commit is contained in:
Infi 2024-10-29 02:26:44 +01:00
parent 4c2bf8703e
commit c0af196da7
12 changed files with 303 additions and 1 deletions

View File

@ -292,6 +292,9 @@ dependencies {
// Shimmer - loading animations // Shimmer - loading animations
implementation "com.valentinilk.shimmer:compose-shimmer:1.3.1" implementation "com.valentinilk.shimmer:compose-shimmer:1.3.1"
// MLKit - Machine Learning
implementation 'com.google.android.gms:play-services-mlkit-smart-reply:16.0.0-beta1'
} }
sqldelight { sqldelight {

View File

@ -50,6 +50,10 @@
android:name="io.sentry.auto-init" android:name="io.sentry.auto-init"
android:value="false" /> android:value="false" />
<meta-data
android:name="com.google.mlkit.vision.DEPENDENCIES"
android:value="smart_reply" />
<meta-data <meta-data
android:name="com.google.firebase.messaging.default_notification_icon" android:name="com.google.firebase.messaging.default_notification_icon"
android:resource="@drawable/ic_notification_monochrome" /> android:resource="@drawable/ic_notification_monochrome" />

View File

@ -27,6 +27,7 @@ class ExperimentInstance(default: Boolean) {
*/ */
object Experiments { object Experiments {
val useKotlinBasedMarkdownRenderer = ExperimentInstance(false) val useKotlinBasedMarkdownRenderer = ExperimentInstance(false)
val useMlKitSmartReplyInApp = ExperimentInstance(false)
suspend fun hydrateWithKv() { suspend fun hydrateWithKv() {
val kvStorage = KVStorage(RevoltApplication.instance) val kvStorage = KVStorage(RevoltApplication.instance)
@ -40,5 +41,8 @@ object Experiments {
useKotlinBasedMarkdownRenderer.setEnabled( useKotlinBasedMarkdownRenderer.setEnabled(
kvStorage.getBoolean("exp/useKotlinBasedMarkdownRenderer") ?: false kvStorage.getBoolean("exp/useKotlinBasedMarkdownRenderer") ?: false
) )
useMlKitSmartReplyInApp.setEnabled(
kvStorage.getBoolean("exp/useMlKitSmartReplyInApp") ?: false
)
} }
} }

View File

@ -0,0 +1,66 @@
package chat.revolt.api.settings.experiments
import android.util.Log
import chat.revolt.api.RevoltAPI
import chat.revolt.api.internals.ULID
import chat.revolt.api.schemas.Message
import chat.revolt.api.schemas.RsResult
import com.google.mlkit.nl.smartreply.SmartReply
import com.google.mlkit.nl.smartreply.TextMessage
object SmartReplyImpl {
val client = SmartReply.getClient()
fun forMessages(
messages: List<Message>,
onResult: (RsResult<List<String>, Exception>) -> Unit
) {
if (messages.size > 10) {
onResult(RsResult.err(IllegalArgumentException("Too many messages")))
return
}
Log.d("SmartReplyImpl", "Creating conversation for ${messages.size} messages")
val conversation = messages.map {
if (it.author == RevoltAPI.selfId) {
if (it.id == null) {
Log.w("SmartReplyImpl", "Message ID is null in local user message, skipping")
return@map null
}
if (it.content?.isEmpty() == true || it.content == null) {
Log.w(
"SmartReplyImpl",
"Message content is null or empty in local user message, skipping"
)
return@map null
}
TextMessage.createForLocalUser(it.content, ULID.asTimestamp(it.id))
} else {
if (it.id == null || it.author == null) {
Log.w(
"SmartReplyImpl",
"Message ID or author is null in remote user message, skipping"
)
return@map null
}
if (it.content?.isEmpty() == true || it.content == null) {
Log.w(
"SmartReplyImpl",
"Message content is null or empty in remote user message, skipping"
)
return@map null
}
TextMessage.createForRemoteUser(it.content, ULID.asTimestamp(it.id), it.author)
}
}.filterNotNull().reversed()
Log.d("SmartReplyImpl", "Suggesting replies for ${conversation.size} messages")
client.suggestReplies(conversation).addOnSuccessListener {
onResult(RsResult.ok(it.suggestions.map { s -> s.text }))
}.addOnFailureListener {
onResult(RsResult.err(it))
}
}
}

View File

@ -4,6 +4,7 @@ import kotlinx.coroutines.flow.MutableSharedFlow
sealed class UiCallback { sealed class UiCallback {
data class ReplyToMessage(val messageId: String) : UiCallback() data class ReplyToMessage(val messageId: String) : UiCallback()
data class ReplyToMessageWithContent(val messageId: String, val content: String) : UiCallback()
data class EditMessage(val messageId: String) : UiCallback() data class EditMessage(val messageId: String) : UiCallback()
} }
@ -14,6 +15,10 @@ object UiCallbacks {
uiCallbackFlow.emit(UiCallback.ReplyToMessage(messageId)) uiCallbackFlow.emit(UiCallback.ReplyToMessage(messageId))
} }
suspend fun replyToMessageWithContent(messageId: String, content: String) {
uiCallbackFlow.emit(UiCallback.ReplyToMessageWithContent(messageId, content))
}
suspend fun editMessage(messageId: String) { suspend fun editMessage(messageId: String) {
uiCallbackFlow.emit(UiCallback.EditMessage(messageId)) uiCallbackFlow.emit(UiCallback.EditMessage(messageId))
} }

View File

@ -19,6 +19,7 @@ fun SheetButton(
onClick: () -> Unit, onClick: () -> Unit,
modifier: Modifier = Modifier, modifier: Modifier = Modifier,
supportingContent: @Composable (() -> Unit)? = null, supportingContent: @Composable (() -> Unit)? = null,
trailingContent: @Composable (() -> Unit)? = null,
dangerous: Boolean = false dangerous: Boolean = false
) { ) {
Box( Box(
@ -62,6 +63,19 @@ fun SheetButton(
this() this()
} }
} }
},
trailingContent = {
trailingContent?.run {
CompositionLocalProvider(
value = if (dangerous) {
LocalContentColor provides MaterialTheme.colorScheme.error
} else {
LocalContentColor provides MaterialTheme.colorScheme.onSurface
}
) {
this()
}
}
} }
) )
} }

View File

@ -134,6 +134,7 @@ import kotlinx.coroutines.launch
import kotlinx.datetime.Instant import kotlinx.datetime.Instant
import java.io.File import java.io.File
import kotlin.math.max import kotlin.math.max
import kotlin.math.min
sealed class ChannelScreenItem { sealed class ChannelScreenItem {
data class RegularMessage(val message: Message) : ChannelScreenItem() data class RegularMessage(val message: Message) : ChannelScreenItem()
@ -372,6 +373,22 @@ fun ChannelScreen(
scope.launch { scope.launch {
ActionChannel.send(Action.ReportMessage(messageContextSheetTarget)) ActionChannel.send(Action.ReportMessage(messageContextSheetTarget))
} }
},
lastTenMessages = {
// First we find the message in viewModel.items
val index = viewModel.items.filterIsInstance<ChannelScreenItem.RegularMessage>()
.indexOfFirst { it.message.id == messageContextSheetTarget }
Log.d("ChannelScreen", "We have index $index")
Log.d("ChannelScreen", "Items.len ${viewModel.items.size}")
// Then we take the last 10 messages before it. We take care to not go out of bounds.
val messages =
viewModel.items.filterIsInstance<ChannelScreenItem.RegularMessage>()
.subList(index, min(index + 5, (viewModel.items.size - 1)))
.map { it.message }
messages
} }
) )
} }

View File

@ -729,6 +729,21 @@ class ChannelScreenViewModel @Inject constructor(
this@ChannelScreenViewModel.draftAttachments.clear() this@ChannelScreenViewModel.draftAttachments.clear()
draftReplyTo.clear() draftReplyTo.clear()
} }
is UiCallback.ReplyToMessageWithContent -> {
val message = items.find { m ->
m is ChannelScreenItem.RegularMessage && m.message.id == it.messageId
} as? ChannelScreenItem.RegularMessage ?: return@onEach
val shouldMention = kvStorage.getBoolean("mentionOnReply") ?: false
draftReplyTo.add(
SendMessageReply(
message.message.id ?: return@onEach,
shouldMention
)
)
putDraftContent(it.content)
}
} }
}.catch { }.catch {
Log.e("ChannelScreen", "Failed to receive UI callback", it) Log.e("ChannelScreen", "Failed to receive UI callback", it)

View File

@ -30,6 +30,7 @@ class ExperimentsSettingsScreenViewModel : ViewModel() {
fun init() { fun init() {
viewModelScope.launch { viewModelScope.launch {
useKotlinMdRendererChecked.value = Experiments.useKotlinBasedMarkdownRenderer.isEnabled useKotlinMdRendererChecked.value = Experiments.useKotlinBasedMarkdownRenderer.isEnabled
useMlKitSmartReplyInAppChecked.value = Experiments.useMlKitSmartReplyInApp.isEnabled
} }
} }
@ -42,6 +43,7 @@ class ExperimentsSettingsScreenViewModel : ViewModel() {
} }
val useKotlinMdRendererChecked = mutableStateOf(false) val useKotlinMdRendererChecked = mutableStateOf(false)
val useMlKitSmartReplyInAppChecked = mutableStateOf(false)
fun setUseKotlinMdRendererChecked(value: Boolean) { fun setUseKotlinMdRendererChecked(value: Boolean) {
viewModelScope.launch { viewModelScope.launch {
@ -50,6 +52,14 @@ class ExperimentsSettingsScreenViewModel : ViewModel() {
useKotlinMdRendererChecked.value = value useKotlinMdRendererChecked.value = value
} }
} }
fun setUseMlKitSmartReplyInAppChecked(value: Boolean) {
viewModelScope.launch {
kv.set("exp/useMlKitSmartReplyInApp", value)
Experiments.useMlKitSmartReplyInApp.setEnabled(value)
useMlKitSmartReplyInAppChecked.value = value
}
}
} }
@Composable @Composable
@ -83,6 +93,22 @@ fun ExperimentsSettingsScreen(
modifier = Modifier.clickable { viewModel.setUseKotlinMdRendererChecked(!viewModel.useKotlinMdRendererChecked.value) } modifier = Modifier.clickable { viewModel.setUseKotlinMdRendererChecked(!viewModel.useKotlinMdRendererChecked.value) }
) )
ListItem(
headlineContent = {
Text("Smart Reply Suggestions (In-App)")
},
supportingContent = {
Text("Use a machine learning model to suggest replies to messages from the message sheet.")
},
trailingContent = {
Switch(
checked = viewModel.useMlKitSmartReplyInAppChecked.value,
onCheckedChange = viewModel::setUseMlKitSmartReplyInAppChecked
)
},
modifier = Modifier.clickable { viewModel.setUseMlKitSmartReplyInAppChecked(!viewModel.useMlKitSmartReplyInAppChecked.value) }
)
Subcategory( Subcategory(
title = { title = {
Text("Disable experiments") Text("Disable experiments")

View File

@ -0,0 +1,62 @@
package chat.revolt.sheets
import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.verticalScroll
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.dp
import chat.revolt.R
import chat.revolt.components.generic.SheetButton
@Composable
fun MessageContentMLKitReplySelectSheet(
options: List<String>,
onOptionSelected: (String) -> Unit
) {
Column(
modifier = Modifier
.verticalScroll(rememberScrollState())
) {
Column(
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.spacedBy(16.dp),
modifier = Modifier.fillMaxWidth()
) {
Icon(
painter = painterResource(R.drawable.ic_creation_24dp),
contentDescription = null,
tint = Color(0xFF977EFF)
)
Text(
"Select a reply",
style = MaterialTheme.typography.headlineLarge,
textAlign = TextAlign.Center
)
}
Spacer(modifier = Modifier.height(16.dp))
options.forEach { option ->
SheetButton(
headlineContent = { Text(option) },
leadingContent = {
Icon(
painter = painterResource(R.drawable.ic_reply_24dp),
contentDescription = null
)
},
onClick = { onOptionSelected(option) }
)
}
}
}

View File

@ -18,11 +18,14 @@ import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.HorizontalDivider import androidx.compose.material3.HorizontalDivider
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton
import androidx.compose.material3.LocalContentColor
import androidx.compose.material3.ModalBottomSheet import androidx.compose.material3.ModalBottomSheet
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.material3.TextButton import androidx.compose.material3.TextButton
import androidx.compose.material3.rememberModalBottomSheetState import androidx.compose.material3.rememberModalBottomSheetState
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember import androidx.compose.runtime.remember
@ -30,6 +33,7 @@ import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.platform.LocalClipboardManager import androidx.compose.ui.platform.LocalClipboardManager
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.painterResource import androidx.compose.ui.res.painterResource
@ -44,10 +48,16 @@ import chat.revolt.api.internals.Roles
import chat.revolt.api.internals.has import chat.revolt.api.internals.has
import chat.revolt.api.routes.channel.deleteMessage import chat.revolt.api.routes.channel.deleteMessage
import chat.revolt.api.routes.channel.react import chat.revolt.api.routes.channel.react
import chat.revolt.api.schemas.Message
import chat.revolt.api.settings.Experiments
import chat.revolt.api.settings.experiments.SmartReplyImpl
import chat.revolt.callbacks.UiCallbacks import chat.revolt.callbacks.UiCallbacks
import chat.revolt.components.chat.Message import chat.revolt.components.chat.Message
import chat.revolt.components.generic.SheetButton import chat.revolt.components.generic.SheetButton
import chat.revolt.internals.Platform import chat.revolt.internals.Platform
import com.valentinilk.shimmer.ShimmerBounds
import com.valentinilk.shimmer.rememberShimmer
import com.valentinilk.shimmer.shimmer
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class)
@ -55,7 +65,8 @@ import kotlinx.coroutines.launch
fun MessageContextSheet( fun MessageContextSheet(
messageId: String, messageId: String,
onHideSheet: suspend () -> Unit, onHideSheet: suspend () -> Unit,
onReportMessage: () -> Unit onReportMessage: () -> Unit,
lastTenMessages: (() -> List<Message>)? = null
) { ) {
val message = RevoltAPI.messageCache[messageId] val message = RevoltAPI.messageCache[messageId]
if (message == null) { if (message == null) {
@ -73,6 +84,48 @@ fun MessageContextSheet(
val clipboardManager = LocalClipboardManager.current val clipboardManager = LocalClipboardManager.current
val coroutineScope = rememberCoroutineScope() val coroutineScope = rememberCoroutineScope()
var mlKitSmartReplies by remember { mutableStateOf<List<String>?>(null) }
var mlKitSmartRepliesLoading by remember { mutableStateOf(false) }
LaunchedEffect(Unit) {
if (Experiments.useMlKitSmartReplyInApp.isEnabled) {
mlKitSmartReplies = null
mlKitSmartRepliesLoading = true
lastTenMessages?.let { fn ->
SmartReplyImpl.forMessages(
fn.invoke()
) {
mlKitSmartReplies = if (it.ok) {
it.unwrap()
} else {
null
}
mlKitSmartRepliesLoading = false
}
}
}
}
var showMlKitSmartReplySheet by remember { mutableStateOf(false) }
if (showMlKitSmartReplySheet) {
val mlKitSmartReplySheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true)
ModalBottomSheet(
sheetState = mlKitSmartReplySheetState,
onDismissRequest = {
showMlKitSmartReplySheet = false
}
) {
MessageContentMLKitReplySelectSheet(
options = mlKitSmartReplies ?: emptyList(),
onOptionSelected = {
coroutineScope.launch {
UiCallbacks.replyToMessageWithContent(messageId, it)
onHideSheet()
}
},
)
}
}
var showShareSheet by remember { mutableStateOf(false) } var showShareSheet by remember { mutableStateOf(false) }
var showReactSheet by remember { mutableStateOf(false) } var showReactSheet by remember { mutableStateOf(false) }
var showDeleteMessageConfirmation by remember { mutableStateOf(false) } var showDeleteMessageConfirmation by remember { mutableStateOf(false) }
@ -335,6 +388,30 @@ fun MessageContextSheet(
text = stringResource(id = R.string.message_context_sheet_actions_reply), text = stringResource(id = R.string.message_context_sheet_actions_reply),
) )
}, },
trailingContent = {
if (Experiments.useMlKitSmartReplyInApp.isEnabled && (mlKitSmartRepliesLoading || (mlKitSmartReplies?.isNotEmpty() == true))) {
IconButton(onClick = {
coroutineScope.launch {
showMlKitSmartReplySheet = true
}
}) {
Icon(
painter = painterResource(id = R.drawable.ic_creation_24dp),
contentDescription = null,
tint = if (!mlKitSmartRepliesLoading && mlKitSmartReplies != null) {
Color(0xFF977EFF)
} else {
LocalContentColor.current
},
modifier = if (mlKitSmartRepliesLoading) {
Modifier.shimmer(rememberShimmer(ShimmerBounds.View))
} else {
Modifier
}
)
}
}
},
onClick = { onClick = {
coroutineScope.launch { coroutineScope.launch {
UiCallbacks.replyToMessage(messageId) UiCallbacks.replyToMessage(messageId)

View File

@ -0,0 +1,9 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:height="24dp"
android:width="24dp"
android:viewportWidth="24"
android:viewportHeight="24">
<path
android:fillColor="#ffffff"
android:pathData="M19,1L17.74,3.75L15,5L17.74,6.26L19,9L20.25,6.26L23,5L20.25,3.75M9,4L6.5,9.5L1,12L6.5,14.5L9,20L11.5,14.5L17,12L11.5,9.5M19,15L17.74,17.74L15,19L17.74,20.25L19,23L20.25,20.25L23,19L20.25,17.74" />
</vector>