Skip to content

Commit

Permalink
Persist selected account IDs to SavedStateHandle (#8358)
Browse files Browse the repository at this point in the history
* Persist selected account IDs to `SavedStateHandle`

* Update tests

* Save list directly
  • Loading branch information
tillh-stripe committed May 6, 2024
1 parent 299f56e commit a9f7bfd
Show file tree
Hide file tree
Showing 18 changed files with 98 additions and 76 deletions.
8 changes: 8 additions & 0 deletions financial-connections/api/financial-connections.api
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,14 @@ public abstract interface class com/stripe/android/financialconnections/analytic
public abstract fun onEvent (Lcom/stripe/android/financialconnections/analytics/FinancialConnectionsEvent;)V
}

public final class com/stripe/android/financialconnections/domain/CachedPartnerAccount$Creator : android/os/Parcelable$Creator {
public fun <init> ()V
public final fun createFromParcel (Landroid/os/Parcel;)Lcom/stripe/android/financialconnections/domain/CachedPartnerAccount;
public synthetic fun createFromParcel (Landroid/os/Parcel;)Ljava/lang/Object;
public final fun newArray (I)[Lcom/stripe/android/financialconnections/domain/CachedPartnerAccount;
public synthetic fun newArray (I)[Ljava/lang/Object;
}

public final class com/stripe/android/financialconnections/exception/AppInitializationError : com/stripe/android/core/exception/StripeException {
public static final field $stable I
public fun <init> (Ljava/lang/String;)V
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.stripe.android.financialconnections.di

import android.app.Application
import androidx.lifecycle.SavedStateHandle
import com.stripe.android.core.ApiVersion
import com.stripe.android.core.Logger
import com.stripe.android.core.networking.ApiRequest
Expand Down Expand Up @@ -115,12 +116,14 @@ internal interface FinancialConnectionsSheetNativeModule {
requestExecutor: FinancialConnectionsRequestExecutor,
apiOptions: ApiRequest.Options,
apiRequestFactory: ApiRequest.Factory,
logger: Logger
logger: Logger,
savedStateHandle: SavedStateHandle,
) = FinancialConnectionsAccountsRepository(
requestExecutor = requestExecutor,
apiRequestFactory = apiRequestFactory,
apiOptions = apiOptions,
logger = logger
logger = logger,
savedStateHandle = savedStateHandle,
)

@Singleton
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package com.stripe.android.financialconnections.domain

import android.os.Parcelable
import com.stripe.android.financialconnections.FinancialConnectionsSheet
import com.stripe.android.financialconnections.model.PartnerAccount
import com.stripe.android.financialconnections.repository.FinancialConnectionsAccountsRepository
import kotlinx.parcelize.Parcelize
import javax.inject.Inject

/**
Expand All @@ -13,7 +15,17 @@ internal class GetCachedAccounts @Inject constructor(
val configuration: FinancialConnectionsSheet.Configuration
) {

suspend operator fun invoke(): List<PartnerAccount> {
suspend operator fun invoke(): List<CachedPartnerAccount> {
return requireNotNull(repository.getCachedAccounts())
}
}

@Parcelize
internal data class CachedPartnerAccount(
val id: String,
val linkedAccountId: String?,
) : Parcelable

internal fun List<PartnerAccount>.toCachedPartnerAccounts(): List<CachedPartnerAccount> {
return map { CachedPartnerAccount(id = it.id, linkedAccountId = it.linkedAccountId) }
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package com.stripe.android.financialconnections.domain
import com.stripe.android.financialconnections.FinancialConnectionsSheet
import com.stripe.android.financialconnections.R
import com.stripe.android.financialconnections.model.FinancialConnectionsSessionManifest
import com.stripe.android.financialconnections.model.PartnerAccount
import com.stripe.android.financialconnections.repository.FinancialConnectionsAccountsRepository
import com.stripe.android.financialconnections.repository.FinancialConnectionsManifestRepository
import com.stripe.android.financialconnections.repository.SuccessContentRepository
Expand All @@ -26,7 +25,7 @@ internal class SaveAccountToLink @Inject constructor(
suspend fun new(
email: String,
phoneNumber: String,
selectedAccounts: List<PartnerAccount>,
selectedAccounts: List<CachedPartnerAccount>,
country: String,
shouldPollAccountNumbers: Boolean,
): FinancialConnectionsSessionManifest {
Expand All @@ -45,7 +44,7 @@ internal class SaveAccountToLink @Inject constructor(

suspend fun existing(
consumerSessionClientSecret: String,
selectedAccounts: List<PartnerAccount>,
selectedAccounts: List<CachedPartnerAccount>,
shouldPollAccountNumbers: Boolean,
): FinancialConnectionsSessionManifest {
return ensureReadyAccounts(shouldPollAccountNumbers, selectedAccounts) { selectedAccountIds ->
Expand All @@ -63,7 +62,7 @@ internal class SaveAccountToLink @Inject constructor(

private suspend fun ensureReadyAccounts(
shouldPollAccountNumbers: Boolean,
partnerAccounts: List<PartnerAccount>,
partnerAccounts: List<CachedPartnerAccount>,
action: suspend (Set<String>) -> FinancialConnectionsSessionManifest,
): FinancialConnectionsSessionManifest {
val selectedAccountIds = partnerAccounts.map { it.id }.toSet()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@ internal class UpdateCachedAccounts @Inject constructor(
val repository: FinancialConnectionsAccountsRepository
) {

suspend operator fun invoke(
block: (List<PartnerAccount>?) -> List<PartnerAccount>?
) {
val updatedAccounts = block(repository.getCachedAccounts())
repository.updateCachedAccounts(updatedAccounts)
suspend operator fun invoke(accounts: List<PartnerAccount>?) {
repository.updateCachedAccounts(accounts)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import com.stripe.android.financialconnections.domain.NativeAuthFlowCoordinator
import com.stripe.android.financialconnections.domain.PollAuthorizationSessionAccounts
import com.stripe.android.financialconnections.domain.SaveAccountToLink
import com.stripe.android.financialconnections.domain.SelectAccounts
import com.stripe.android.financialconnections.domain.toCachedPartnerAccounts
import com.stripe.android.financialconnections.features.accountpicker.AccountPickerClickableText.DATA
import com.stripe.android.financialconnections.features.accountpicker.AccountPickerState.SelectionMode
import com.stripe.android.financialconnections.features.accountpicker.AccountPickerState.ViewEffect
Expand Down Expand Up @@ -299,7 +300,7 @@ internal class AccountPickerViewModel @AssistedInject constructor(
// it happens in the AttachPaymentScreen.
saveAccountToLink.existing(
consumerSessionClientSecret = consumerSessionClientSecret,
selectedAccounts = accountsList.data,
selectedAccounts = accountsList.data.toCachedPartnerAccounts(),
shouldPollAccountNumbers = manifest.isDataFlow,
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ internal class AttachPaymentViewModel @AssistedInject constructor(
val authSession = requireNotNull(manifest.activeAuthSession)
val activeInstitution = requireNotNull(manifest.activeInstitution)
val accounts = getCachedAccounts()
require(accounts.size == 1)
val id = accounts.first().linkedAccountId
val id = accounts.single().linkedAccountId
val (result, millis) = measureTimeMillis {
pollAttachPaymentAccount(
sync = sync,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ internal class LinkAccountPickerViewModel @AssistedInject constructor(
val payload = requireNotNull(state.payload())

val accounts = payload.selectedPartnerAccounts(state.selectedAccountIds)
updateCachedAccounts { accounts }
updateCachedAccounts(accounts)

// We assume that at this point, all selected accounts have the same next pane.
// Otherwise, the user would have been presented with an update-required bottom
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package com.stripe.android.financialconnections.repository

import androidx.lifecycle.SavedStateHandle
import com.stripe.android.core.Logger
import com.stripe.android.core.networking.ApiRequest
import com.stripe.android.financialconnections.domain.CachedPartnerAccount
import com.stripe.android.financialconnections.domain.toCachedPartnerAccounts
import com.stripe.android.financialconnections.model.FinancialConnectionsSessionManifest
import com.stripe.android.financialconnections.model.InstitutionResponse
import com.stripe.android.financialconnections.model.LinkAccountSessionPaymentAccount
Expand All @@ -16,16 +19,14 @@ import com.stripe.android.financialconnections.network.NetworkConstants.PARAMS_C
import com.stripe.android.financialconnections.network.NetworkConstants.PARAMS_ID
import com.stripe.android.financialconnections.network.NetworkConstants.PARAM_SELECTED_ACCOUNTS
import com.stripe.android.financialconnections.utils.filterNotNullValues
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock

/**
* Repository to centralize reads and writes to the [FinancialConnectionsSessionManifest]
* of the current flow.
*/
internal interface FinancialConnectionsAccountsRepository {

suspend fun getCachedAccounts(): List<PartnerAccount>?
suspend fun getCachedAccounts(): List<CachedPartnerAccount>?

suspend fun updateCachedAccounts(partnerAccountsList: List<PartnerAccount>?)

Expand Down Expand Up @@ -70,13 +71,15 @@ internal interface FinancialConnectionsAccountsRepository {
requestExecutor: FinancialConnectionsRequestExecutor,
apiRequestFactory: ApiRequest.Factory,
apiOptions: ApiRequest.Options,
logger: Logger
logger: Logger,
savedStateHandle: SavedStateHandle,
): FinancialConnectionsAccountsRepository =
FinancialConnectionsAccountsRepositoryImpl(
requestExecutor,
apiRequestFactory,
apiOptions,
logger
logger,
savedStateHandle,
)
}
}
Expand All @@ -85,21 +88,20 @@ private class FinancialConnectionsAccountsRepositoryImpl(
val requestExecutor: FinancialConnectionsRequestExecutor,
val apiRequestFactory: ApiRequest.Factory,
val apiOptions: ApiRequest.Options,
val logger: Logger
val logger: Logger,
private val savedStateHandle: SavedStateHandle,
) : FinancialConnectionsAccountsRepository {

/**
* Ensures that [cachedAccounts] accesses via [getCachedAccounts] suspend until
* current writes are running.
*/
val mutex = Mutex()
private var cachedAccounts: List<PartnerAccount>? = null

override suspend fun getCachedAccounts(): List<PartnerAccount>? =
mutex.withLock { cachedAccounts }
override suspend fun getCachedAccounts(): List<CachedPartnerAccount>? {
return savedStateHandle[CachedPartnerAccountsKey]
}

override suspend fun updateCachedAccounts(partnerAccountsList: List<PartnerAccount>?) =
mutex.withLock { cachedAccounts = partnerAccountsList }
override suspend fun updateCachedAccounts(partnerAccountsList: List<PartnerAccount>?) {
updateCachedAccounts(
source = "updateCachedAccounts",
accounts = partnerAccountsList.orEmpty(),
)
}

override suspend fun postAuthorizationSessionAccounts(
clientSecret: String,
Expand Down Expand Up @@ -230,10 +232,13 @@ private class FinancialConnectionsAccountsRepositoryImpl(
accounts: List<PartnerAccount>
) {
logger.debug("updating local partner accounts from $source")
cachedAccounts = accounts
val cachedAccounts = accounts.toCachedPartnerAccounts()
savedStateHandle[CachedPartnerAccountsKey] = cachedAccounts
}

companion object {
private const val CachedPartnerAccountsKey = "CachedPartnerAccounts"

internal const val accountsSessionUrl: String =
"${ApiRequest.API_HOST}/v1/connections/auth_sessions/accounts"

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.stripe.android.financialconnections

import com.stripe.android.financialconnections.domain.CachedPartnerAccount
import com.stripe.android.financialconnections.model.FinancialConnectionsAccount
import com.stripe.android.financialconnections.model.FinancialConnectionsAccountList
import com.stripe.android.financialconnections.model.FinancialConnectionsAuthorizationSession
Expand Down Expand Up @@ -114,6 +115,18 @@ internal object ApiKeyFixtures {
supportedPaymentMethodTypes = listOf(FinancialConnectionsAccount.SupportedPaymentMethodTypes.US_BANK_ACCOUNT)
)

fun cachedPartnerAccounts(): List<CachedPartnerAccount> {
return listOf(
CachedPartnerAccount(id = "id_1", linkedAccountId = "linked_id_1"),
CachedPartnerAccount(id = "id_2", linkedAccountId = "linked_id_2"),
)
}

fun cachedPartnerAccount() = CachedPartnerAccount(
id = "id",
linkedAccountId = "linked_id",
)

fun consumerSession() = ConsumerSession(
clientSecret = "clientSecret",
emailAddress = "[email protected]",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package com.stripe.android.financialconnections.domain
import androidx.lifecycle.SavedStateHandle
import com.google.common.truth.Truth.assertThat
import com.stripe.android.financialconnections.ApiKeyFixtures
import com.stripe.android.financialconnections.ApiKeyFixtures.partnerAccount
import com.stripe.android.financialconnections.ApiKeyFixtures.sessionManifest
import com.stripe.android.financialconnections.FinancialConnectionsSheet
import com.stripe.android.financialconnections.R
Expand All @@ -30,10 +29,7 @@ internal class SaveAccountToLinkTest {
fun `Polls account numbers if requested to do so`() = runTest(testDispatcher) {
val polledAccountIds = mutableSetOf<String>()

val partnerAccounts = listOf(
partnerAccount().copy(id = "id_1", linkedAccountId = "lid_1"),
partnerAccount().copy(id = "id_2", linkedAccountId = "lid_2"),
)
val partnerAccounts = ApiKeyFixtures.cachedPartnerAccounts()

val accountsRepository = mockAccountsRepository(
onPollAccountNumbers = polledAccountIds::addAll,
Expand All @@ -49,17 +45,14 @@ internal class SaveAccountToLinkTest {
shouldPollAccountNumbers = true,
)

assertThat(polledAccountIds).containsExactly("lid_1", "lid_2")
assertThat(polledAccountIds).containsExactly("linked_id_1", "linked_id_2")
}

@Test
fun `Skips polling account numbers if not requested to do so`() = runTest(testDispatcher) {
val polledAccountIds = mutableSetOf<String>()

val partnerAccounts = listOf(
partnerAccount().copy(id = "id_1", linkedAccountId = "lid_1"),
partnerAccount().copy(id = "id_2", linkedAccountId = "lid_2"),
)
val partnerAccounts = ApiKeyFixtures.cachedPartnerAccounts()

val accountsRepository = mockAccountsRepository(
onPollAccountNumbers = polledAccountIds::addAll,
Expand All @@ -82,10 +75,7 @@ internal class SaveAccountToLinkTest {
fun `Disables networking if polling account numbers fails`() = runTest(testDispatcher) {
var disabledNetworking = false

val partnerAccounts = listOf(
partnerAccount().copy(id = "id_1", linkedAccountId = "lid_1"),
partnerAccount().copy(id = "id_2", linkedAccountId = "lid_2"),
)
val partnerAccounts = ApiKeyFixtures.cachedPartnerAccounts()

val repository = mockManifestRepository(
onDisabledNetworking = { disabledNetworking = true },
Expand Down Expand Up @@ -115,10 +105,7 @@ internal class SaveAccountToLinkTest {

@Test
fun `Sets custom success message if polling account numbers fails`() = runTest(testDispatcher) {
val partnerAccounts = listOf(
partnerAccount().copy(id = "id_1", linkedAccountId = "lid_1"),
partnerAccount().copy(id = "id_2", linkedAccountId = "lid_2"),
)
val partnerAccounts = ApiKeyFixtures.cachedPartnerAccounts()

val accountsRepository = mockAccountsRepository(
onPollAccountNumbers = { error("This is failing") },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import com.stripe.android.financialconnections.ApiKeyFixtures.sessionManifest
import com.stripe.android.financialconnections.ApiKeyFixtures.syncResponse
import com.stripe.android.financialconnections.CoroutineTestRule
import com.stripe.android.financialconnections.TestFinancialConnectionsAnalyticsTracker
import com.stripe.android.financialconnections.domain.CachedPartnerAccount
import com.stripe.android.financialconnections.domain.GetCachedConsumerSession
import com.stripe.android.financialconnections.domain.GetOrFetchSync
import com.stripe.android.financialconnections.domain.NativeAuthFlowCoordinator
Expand Down Expand Up @@ -300,7 +301,7 @@ internal class AccountPickerViewModelTest {

verify(saveAccountToLink).existing(
consumerSessionClientSecret = consumerSession.clientSecret,
selectedAccounts = accounts.data,
selectedAccounts = accounts.data.map { CachedPartnerAccount(it.id, it.linkedAccountId) },
shouldPollAccountNumbers = true,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import com.stripe.android.financialconnections.model.Image
import com.stripe.android.financialconnections.model.InstitutionResponse
import com.stripe.android.financialconnections.model.NetworkedAccount
import com.stripe.android.financialconnections.model.NetworkedAccountsList
import com.stripe.android.financialconnections.model.PartnerAccount
import com.stripe.android.financialconnections.model.ReturningNetworkingUserAccountPicker
import com.stripe.android.financialconnections.model.TextUpdate
import com.stripe.android.financialconnections.navigation.Destination.LinkStepUpVerification
Expand All @@ -32,7 +31,6 @@ import kotlinx.coroutines.test.runTest
import org.junit.Rule
import org.junit.Test
import org.mockito.kotlin.any
import org.mockito.kotlin.argumentCaptor
import org.mockito.kotlin.mock
import org.mockito.kotlin.verify
import org.mockito.kotlin.verifyNoInteractions
Expand Down Expand Up @@ -157,10 +155,7 @@ class LinkAccountPickerViewModelTest {
viewModel.onAccountClick(selectedAccount)
viewModel.onSelectAccountsClick()

with(argumentCaptor<(List<PartnerAccount>?) -> List<PartnerAccount>?>()) {
verify(updateCachedAccounts).invoke(capture())
assertThat(firstValue(null)).isEqualTo(listOf(selectedAccount))
}
verify(updateCachedAccounts).invoke(listOf(selectedAccount))
val destination = accounts.data.first().nextPaneOnSelection!!.destination
navigationManager.assertNavigatedTo(destination, Pane.LINK_ACCOUNT_PICKER)
}
Expand Down Expand Up @@ -200,10 +195,7 @@ class LinkAccountPickerViewModelTest {
viewModel.onAccountClick(selectedAccount)
viewModel.onSelectAccountsClick()

with(argumentCaptor<(List<PartnerAccount>?) -> List<PartnerAccount>?>()) {
verify(updateCachedAccounts).invoke(capture())
assertThat(firstValue(null)).isEqualTo(listOf(selectedAccount))
}
verify(updateCachedAccounts).invoke(listOf(selectedAccount))
verifyNoInteractions(selectNetworkedAccounts)
navigationManager.assertNavigatedTo(LinkStepUpVerification, Pane.LINK_ACCOUNT_PICKER)
}
Expand Down

0 comments on commit a9f7bfd

Please sign in to comment.