RxJava to Kotlin coroutines


·1640 words·8 mins

OK, I know this is a bit of a clickbait-y title, but it’s the best I could come up with. This post is a summary of how I refactored an app which uses RxJava pretty heavily into one which uses Kotlin coroutines too. Specifically, I’ll talk about switching Single/Maybe/Completable sources over to coroutines.

The app

To start off, a bit of an introduction to how the app is architected. The majority of my business logic is built into things called ‘calls’:

interface Call<in Param, Output> {
    fun data(param: Param): Flowable<Output>
    fun refresh(param: Param): Completable
}

As you can see, each call has two main responsibilities:

  1. Its data() method, which exposes a stream of data relevant to the call. This returns a Flowable, which most of the time is just a Flowable from a Room DAO. A ViewModel then subscribes to it and passes the data to the UI, etc.
  2. Its refresh() method. Hopefully this is pretty self-explanatory: it triggers a refresh of the data. Most of the implementations fetch from the network, map the entities and then update the Room database. This currently returns a Completable, which the ViewModel subscribes to in order to start the ‘action’.

So where did I plan on fitting coroutines into this?

My goal was to make refresh() a suspending function:

interface Call<in Param, Output> {
    fun data(param: Param): Flowable<Output>
    suspend fun refresh(param: Param)
}
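For a sense of what the new signature means for an implementation, here is a minimal sketch assuming kotlinx.coroutines 1.x. To keep it runnable without RxJava or Room, data() is left out (so the interface is a reduced, hypothetical RefreshableCall), and FakeTraktApi and FakeUserDao are hypothetical in-memory stand-ins for the network API and the DAO.

```kotlin
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking

// Reduced version of the post's Call interface: data() is left out so this
// sketch runs without RxJava or Room.
interface RefreshableCall<in Param> {
    suspend fun refresh(param: Param)
}

data class TraktUser(val username: String, val name: String)

// Hypothetical stand-in for the Trakt network API.
class FakeTraktApi {
    suspend fun profile(): TraktUser {
        delay(10) // simulate network latency without blocking a thread
        return TraktUser(username = "jdoe", name = "Jane Doe")
    }
}

// Hypothetical stand-in for the Room DAO.
class FakeUserDao {
    val users = mutableListOf<TraktUser>()
    fun insert(user: TraktUser) {
        users += user
    }
}

class UserProfileCall(
    private val api: FakeTraktApi,
    private val dao: FakeUserDao
) : RefreshableCall<Unit> {
    // The caller simply suspends until the refresh has finished;
    // there is no Completable to subscribe to.
    override suspend fun refresh(param: Unit) {
        val user = api.profile()
        dao.insert(user)
    }
}

fun main() = runBlocking {
    val dao = FakeUserDao()
    UserProfileCall(FakeTraktApi(), dao).refresh(Unit)
    println(dao.users) // [TraktUser(username=jdoe, name=Jane Doe)]
}
```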

Once I made that change, all of the call implementations stopped compiling, since the function signature had changed. Luckily, the kotlinx-coroutines-rx2 extension library provides extension functions on RxJava’s single-like types (Single, Maybe and Completable) which allow us to await() their completion. So quickly pasting .await() at the end of every implementation’s Rx chain fixed the build.
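kotlinx-coroutines-rx2’s await() does this bridging for you. To show the general idea without an RxJava dependency, here is a sketch using suspendCancellableCoroutine from kotlinx.coroutines to turn a callback-based API into a suspend function. fetchGreeting() is a hypothetical callback API standing in for an Rx Single; this is the same shape as await(), not the library’s actual implementation.

```kotlin
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlin.concurrent.thread
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException

// Hypothetical callback-based API standing in for an Rx Single: it delivers
// either a value or an error on a background thread.
fun fetchGreeting(fail: Boolean, onSuccess: (String) -> Unit, onError: (Throwable) -> Unit) {
    thread {
        if (fail) onError(IllegalStateException("network down")) else onSuccess("hello")
    }
}

// Suspends the caller until a callback fires, then resumes it with the value
// or rethrows the error in the caller.
suspend fun awaitGreeting(fail: Boolean = false): String =
    suspendCancellableCoroutine { cont ->
        fetchGreeting(
            fail,
            onSuccess = { cont.resume(it) },
            onError = { cont.resumeWithException(it) }
        )
    }

fun main() = runBlocking {
    println(awaitGreeting()) // hello
    val error = runCatching { awaitGreeting(fail = true) }.exceptionOrNull()
    println(error?.message) // network down
}
```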

We now have a very crude and stupid use of coroutines! But hey, it’s a start and everything still works.

The next step was to start making all of the code underneath refresh() coroutine-aware, and to remove RxJava where it isn’t needed. At this point you might be wondering what I mean by ‘coroutine-aware’. Well, it’s all to do with threads.

Threading

With RxJava, I have different Schedulers for different kinds of tasks. These are provided by a data class which is injected anywhere that chains Rx operators together.

data class AppRxSchedulers(
    val database: Scheduler,
    val disk: Scheduler,
    val network: Scheduler,
    val main: Scheduler
)

@Singleton
@Provides
fun provideRxSchedulers() = AppRxSchedulers(
    database = Schedulers.single(),
    disk = Schedulers.io(),
    network = Schedulers.io(),
    main = AndroidSchedulers.mainThread()
)

In my opinion, the most important one is the database scheduler. This is because I want to force single-threaded database access, which preserves data integrity and avoids locking SQLite.

With coroutines I wanted to do the same, ensuring that both RxJava and coroutines use the same thread pools. This turned out to be relatively easy using the kotlinx-coroutines-rx2 extension library, which adds an asCoroutineDispatcher() extension function on Scheduler that wraps it in a CoroutineDispatcher. Using that, I convert my schedulers into dispatchers and inject them too.

data class AppCoroutineDispatchers(
    val database: CoroutineDispatcher,
    val disk: CoroutineDispatcher,
    val network: CoroutineDispatcher,
    val main: CoroutineDispatcher
)

@Singleton
@Provides
fun provideDispatchers(schedulers: AppRxSchedulers) = 
    AppCoroutineDispatchers(
        database = schedulers.database.asCoroutineDispatcher(),
        disk = schedulers.disk.asCoroutineDispatcher(),
        network = schedulers.network.asCoroutineDispatcher(),
        main = Dispatchers.Main // `UI` in the pre-1.0 experimental coroutines API
    )

As you can see, I’m currently using the RxJava scheduler as the source. In the future, I’ll probably swap that around so that schedulers are derived from the dispatchers.
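To illustrate the single-threaded database guarantee without RxJava, here is a sketch assuming kotlinx.coroutines 1.x. It builds a hypothetical databaseDispatcher from a one-thread java.util.concurrent executor rather than wrapping Schedulers.single() as the app does, then launches many writers from the multi-threaded Default pool and records which threads actually ran the "database" work.

```kotlin
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import java.util.Collections
import java.util.concurrent.Executors

// Hypothetical single-threaded "database" dispatcher. The post wraps
// Schedulers.single() instead; a one-thread executor gives the same guarantee.
val databaseDispatcher = Executors.newSingleThreadExecutor { runnable ->
    Thread(runnable, "database").apply { isDaemon = true }
}.asCoroutineDispatcher()

// Starts many writers on the multi-threaded Default pool and returns the set
// of distinct threads that actually executed the "database" work.
fun databaseThreadsUsed(writers: Int): Set<Thread> {
    val threads = Collections.synchronizedSet(mutableSetOf<Thread>())
    runBlocking {
        repeat(writers) {
            launch(Dispatchers.Default) {
                // Each write hops onto the single database thread.
                withContext(databaseDispatcher) {
                    threads += Thread.currentThread()
                }
            }
        }
    } // runBlocking returns only once every launched writer has finished
    return threads.toSet()
}

fun main() {
    println(databaseThreadsUsed(writers = 100).size) // 1
}
```

However many writers run concurrently, the database work is serialized onto one thread, which is the property the database scheduler exists for.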

Changing threads

So I have my schedulers and dispatchers sharing the same threads, but what about using them in our operations?

RxJava makes it really easy to switch threads within a chain using its subscribeOn() and observeOn() methods. Here’s an example of a refresh() method where I use my network scheduler to fetch a network response and map it to an internal entity, and then use the database scheduler to store the result.

override fun refresh(param: Unit): Completable {
    return trakt.users().profile(UserSlug.ME).toRxSingle()
            .subscribeOn(schedulers.network)
            .map {
                TraktUser(username = it.username, name = it.name)
            }
            .observeOn(schedulers.database)
            .doOnSuccess {
                dao.insert(it)
            }
            .toCompletable()
}

override fun data(param: Unit): Flowable<TraktUser> {
    return dao.getTraktUser()
            .subscribeOn(schedulers.database)
}

Hopefully you can see that this is pretty standard RxJava code. Now I need to convert this to coroutines.

First attempt

After you’ve read the coroutines guide, you’ll probably have two functions in your mind: launch() and async(). As you can probably guess, my first attempt centred on using these to chain things together:

override suspend fun refresh(param: Unit) {
    // Fetch network response on network dispatcher
    val networkResponse = async(dispatchers.network) {
        trakt.users().profile(UserSlug.ME).execute().body()
    }.await() // await the result

    // Map to our entity
    val entity = TraktUser(
        username = networkResponse.username,
        name = networkResponse.name
    )

    // Save to the database on the database dispatcher
    async(dispatchers.database) {
        dao.insert(entity)
    }.await() // Wait for the insert to finish
}

This actually works, but it is a bit wasteful. Here we’re actually starting three coroutines: 1) the network call, 2) the database call and 3) the host coroutine which invokes refresh() (in the ViewModel).

I’m sure you can imagine a more complex Rx chain which does things like flatMap-ing an Iterable or other crazy things. The number of coroutines you create then grows significantly, even though many of them aren’t needed, as in the example above. Everything we’re doing here is sequential, so there’s no need to start a new coroutine. What we need is a way to just change dispatcher, and luckily the coroutines team has provided one: withContext().

Second attempt

I stumbled upon withContext() in a small code sample in the coroutines guide. My second (and current) attempt concentrated on using withContext() as a replacement for subscribeOn() and observeOn(), since its documentation describes exactly what we want:

“This function immediately applies dispatcher from the new context, shifting execution of the block into the different thread inside the block, and back when it completes.”

With that in mind, the example becomes:

override suspend fun refresh(param: Unit) {
    // Fetch network response on network dispatcher
    val networkResponse = withContext(dispatchers.network) {
        trakt.users().profile(UserSlug.ME).execute().body()
    }

    // Map to our entity
    val entity = TraktUser(
        username = networkResponse.username,
        name = networkResponse.name
    )

    // Save to the database on the database dispatcher
    withContext(dispatchers.database) {
        dao.insert(entity)
    }
}

You can see that we’ve now removed the async calls, meaning we don’t create any new coroutines. We just move the host coroutine onto our specific dispatcher (and thread), and back again when each block completes.
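A small runnable sketch of that hop-and-return behaviour, assuming kotlinx.coroutines 1.x: it records the current thread name before, inside and after a withContext block. networkDispatcher is a hypothetical single-thread stand-in for dispatchers.network.

```kotlin
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import java.util.concurrent.Executors

// Hypothetical "network" dispatcher backed by a single named daemon thread.
val networkDispatcher = Executors.newSingleThreadExecutor { runnable ->
    Thread(runnable, "network").apply { isDaemon = true }
}.asCoroutineDispatcher()

// Returns the names of the threads seen before, inside and after withContext.
// Note: in kotlinx.coroutines debug mode, thread names get a coroutine suffix.
fun threadHops(): Triple<String, String, String> = runBlocking {
    val before = Thread.currentThread().name
    val inside = withContext(networkDispatcher) {
        Thread.currentThread().name // this block runs on the network thread
    }
    // withContext hands back the block's value directly (nothing to await),
    // and execution continues on the original thread.
    val after = Thread.currentThread().name
    Triple(before, inside, after)
}

fun main() {
    val (before, inside, after) = threadHops()
    println("before=$before inside=$inside after=$after")
}
```

Running main() should show the middle name starting with "network", while the first and last names match each other.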

But the docs say that coroutines are really lightweight. Why can’t I just async/launch?

Coroutines are very lightweight, but creating them still has a cost. Remember that on Android we’re running on a resource-constrained system, so we need to do everything possible to reduce our footprint. withContext fulfils our needs with a single function call and minimal object allocation, compared to creating a new coroutine with async or launch.

There’s also the fact that async and launch are meant for tasks which are asynchronous. Most of the time you’ll have a main task which is asynchronous, but within it you’ll be invoking sequential sub-tasks. Using async and launch for those forces you to add an extra await() or join(), which only adds complexity for the reader.

On the other hand, if your sub-tasks are unrelated, making them run concurrently with async is a valid way to go.
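For the unrelated case, here is a minimal sketch assuming kotlinx.coroutines 1.x. fetchProfileName() and fetchWatchedCount() are hypothetical suspend functions that each simulate about 200 ms of work with delay(); both are started with async inside coroutineScope and then awaited together.

```kotlin
import kotlinx.coroutines.async
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking
import kotlin.system.measureTimeMillis

// Two hypothetical, independent lookups that each take ~200 ms.
suspend fun fetchProfileName(): String {
    delay(200)
    return "Jane Doe"
}

suspend fun fetchWatchedCount(): Int {
    delay(200)
    return 42
}

// Unrelated sub-tasks: start both with async, then await both results.
// coroutineScope ensures neither child outlives this function.
suspend fun loadSummary(): String = coroutineScope {
    val name = async { fetchProfileName() }
    val count = async { fetchWatchedCount() }
    "${name.await()} has watched ${count.await()} shows"
}

fun main() = runBlocking {
    val elapsed = measureTimeMillis { println(loadSummary()) }
    // The two delays overlap, so this is roughly 200 ms rather than 400 ms.
    println("took ~$elapsed ms")
}
```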

What about more complex Rx chains?

Here’s an example chain which confuses me whenever I look at it:

override fun refresh(param: Unit): Completable {
    return trakt.users().watchedShows(UserSlug.ME).toRxSingle()
            .subscribeOn(schedulers.network)
            .toFlowable()
            .flatMapIterable { it }
            .flatMapSingle {
                showFetcher.load(it)
            }
            .toList()
            .observeOn(schedulers.database)
            .doOnSuccess { shows ->
                databaseTransactionRunner.runInTransaction {
                    dao.deleteAll()
                    shows.forEach { dao.insert(it) }
                }
            }
            .toCompletable()
}

When you break it down, it’s not actually that much more complex than the example we’ve been using above. The big difference is that it deals with a collection of entities rather than a single entity. To do that, it uses Flowable’s flatMapIterable() to fan out a map for each item, and then uses toList() to combine all the results back together into a list, which then gets persisted to the database.

I am actually using a different class (showFetcher) to provide the operator for the fan-out; in this case it returns a Single. Ignore that for now.

What we’re actually describing here is a parallel map, where each map() is run concurrently. JDK 8 provides something similar with list.parallelStream().map(/* map function */).collect(Collectors.toList()). We could use that functionality, but then we wouldn’t be using coroutines!
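For comparison, here is that JDK 8 parallel-stream approach called from Kotlin. It is a sketch with a hypothetical squaresInParallel() function; the work runs on the common ForkJoinPool rather than on coroutines, and collect(Collectors.toList()) keeps the original encounter order.

```kotlin
import java.util.stream.Collectors

// Parallel map with JDK 8 streams: items are mapped on the common
// ForkJoinPool, and the collected list preserves the input order.
fun squaresInParallel(input: List<Int>): List<Int> =
    input.parallelStream()
        .map { it * it }
        .collect(Collectors.toList())

fun main() {
    println(squaresInParallel(listOf(1, 2, 3, 4))) // [1, 4, 9, 16]
}
```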

There’s no built-in version of this in Kotlin using coroutines (that I could find), but luckily it is simple to implement:

suspend fun <A, B> Collection<A>.parallelMap(
    context: CoroutineContext = Dispatchers.Default,
    block: suspend (A) -> B
): List<B> = coroutineScope {
    map {
        // Use async to start a coroutine for each item
        async(context) {
            block(it)
        }
    }.map {
        // We now have a list of Deferred<B>, so we await() each one
        it.await()
    }
}

By using the parallelMap(), our complex RxJava chain becomes the following:

override suspend fun refresh(param: Unit) {
    val networkResponse = withContext(dispatchers.network) {
        trakt.users().watchedShows(UserSlug.ME).execute().body()
    }

    val shows = networkResponse.parallelMap {
        // showFetcher still returns a Single, so await() it for now
        showFetcher.load(it).await()
    }

    // Now save the list to the database
    withContext(dispatchers.database) {
        databaseTransactionRunner.runInTransaction {
            dao.deleteAll()
            shows.forEach { dao.insert(it) }
        }
    }
}

Hopefully you can see that this is now a lot clearer to read. The showFetcher class still needs to be converted but it can _a_wait for now. 🤦
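To see the whole converted flow run end to end, here is a sketch assuming kotlinx.coroutines 1.x. FakeShowFetcher, FakeShowDao and fetchWatchedIds() are hypothetical in-memory stand-ins for showFetcher, the Room DAO and the Trakt call, the transaction runner is left out, and showFetcher is assumed to already be a suspend function (which, as noted above, it isn’t yet). parallelMap is written with coroutineScope and awaitAll().

```kotlin
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import kotlin.coroutines.CoroutineContext

// parallelMap on top of structured concurrency: coroutineScope waits for every
// async child, and awaitAll() returns results in the original order.
suspend fun <A, B> Collection<A>.parallelMap(
    context: CoroutineContext = Dispatchers.Default,
    block: suspend (A) -> B
): List<B> = coroutineScope {
    map { async(context) { block(it) } }.awaitAll()
}

data class Show(val id: Int, val title: String)

// Hypothetical stand-in for showFetcher, already converted to suspend.
class FakeShowFetcher {
    suspend fun load(id: Int): Show {
        delay(50) // pretend to hit the network
        return Show(id, "Show #$id")
    }
}

// Hypothetical stand-in for the Room DAO (no transaction runner here).
class FakeShowDao {
    val rows = mutableListOf<Show>()
    fun deleteAll() = rows.clear()
    fun insert(show: Show) {
        rows += show
    }
}

class WatchedShowsCall(
    private val fetcher: FakeShowFetcher,
    private val dao: FakeShowDao,
    private val network: CoroutineDispatcher = Dispatchers.IO,
    private val database: CoroutineDispatcher = Dispatchers.IO
) {
    // Hypothetical network call returning the ids of watched shows.
    private suspend fun fetchWatchedIds(): List<Int> {
        delay(50)
        return listOf(3, 1, 2)
    }

    suspend fun refresh() {
        // Fetch on the network dispatcher...
        val ids = withContext(network) { fetchWatchedIds() }
        // ...fan out and load every show concurrently...
        val shows = ids.parallelMap { fetcher.load(it) }
        // ...then replace the stored rows on the database dispatcher.
        withContext(database) {
            dao.deleteAll()
            shows.forEach { dao.insert(it) }
        }
    }
}

fun main() = runBlocking {
    val dao = FakeShowDao()
    WatchedShowsCall(FakeShowFetcher(), dao).refresh()
    println(dao.rows.map { it.title }) // [Show #3, Show #1, Show #2]
}
```

Unlike flatMapSingle() followed by toList(), this parallelMap keeps the input order, so the rows come out in the same order as the fetched ids.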

If you’re interested, you can see the PR which converts the app to coroutines here:

Moving from RxJava to coroutines by chrisbanes · Pull Request #135 · chrisbanes/tivi · GitHub

Next steps

Hopefully you can see that it’s actually relatively easy to switch to coroutines from RxJava’s Single/Maybe/Completable.

For now, I’m still using RxJava for streaming observables, but I may move solely to LiveData or coroutine channels. That’s a task for another time, though.