Initial prediction implementation

Test: Not yet implemented
Relnote: Initial implementation
Change-Id: I9730c55b677ad82797e16f327dbf8fad0c20f4e1
diff --git a/docs-tip-of-tree/build.gradle b/docs-tip-of-tree/build.gradle
index 6550261..890399e 100644
--- a/docs-tip-of-tree/build.gradle
+++ b/docs-tip-of-tree/build.gradle
@@ -175,6 +175,7 @@
     samples(project(":hilt:hilt-navigation-compose-samples"))
     docs(project(":hilt:hilt-navigation-fragment"))
     docs(project(":hilt:hilt-work"))
+    docs(project(":input:input-motionprediction"))
     docs(project(":interpolator:interpolator"))
     docs(project(":javascriptengine:javascriptengine"))
     docs(project(":metrics:metrics-performance"))
diff --git a/input/OWNERS b/input/OWNERS
new file mode 100644
index 0000000..492d141
--- /dev/null
+++ b/input/OWNERS
@@ -0,0 +1,3 @@
[email protected]
[email protected]
[email protected]
\ No newline at end of file
diff --git a/input/input-motionprediction/api/current.txt b/input/input-motionprediction/api/current.txt
new file mode 100644
index 0000000..a2cdf99
--- /dev/null
+++ b/input/input-motionprediction/api/current.txt
@@ -0,0 +1,12 @@
+// Signature format: 4.0
+package androidx.input.motionprediction {
+
+  public interface MotionEventPredictor {
+    method public void dispose();
+    method public static androidx.input.motionprediction.MotionEventPredictor newInstance(android.view.View);
+    method public android.view.MotionEvent? predict();
+    method public void recordMovement(android.view.MotionEvent);
+  }
+
+}
+
diff --git a/input/input-motionprediction/api/public_plus_experimental_current.txt b/input/input-motionprediction/api/public_plus_experimental_current.txt
new file mode 100644
index 0000000..a2cdf99
--- /dev/null
+++ b/input/input-motionprediction/api/public_plus_experimental_current.txt
@@ -0,0 +1,12 @@
+// Signature format: 4.0
+package androidx.input.motionprediction {
+
+  public interface MotionEventPredictor {
+    method public void dispose();
+    method public static androidx.input.motionprediction.MotionEventPredictor newInstance(android.view.View);
+    method public android.view.MotionEvent? predict();
+    method public void recordMovement(android.view.MotionEvent);
+  }
+
+}
+
diff --git a/input/input-motionprediction/api/res-current.txt b/input/input-motionprediction/api/res-current.txt
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/input/input-motionprediction/api/res-current.txt
diff --git a/input/input-motionprediction/api/restricted_current.txt b/input/input-motionprediction/api/restricted_current.txt
new file mode 100644
index 0000000..a2cdf99
--- /dev/null
+++ b/input/input-motionprediction/api/restricted_current.txt
@@ -0,0 +1,12 @@
+// Signature format: 4.0
+package androidx.input.motionprediction {
+
+  public interface MotionEventPredictor {
+    method public void dispose();
+    method public static androidx.input.motionprediction.MotionEventPredictor newInstance(android.view.View);
+    method public android.view.MotionEvent? predict();
+    method public void recordMovement(android.view.MotionEvent);
+  }
+
+}
+
diff --git a/input/input-motionprediction/build.gradle b/input/input-motionprediction/build.gradle
new file mode 100644
index 0000000..9475cdb
--- /dev/null
+++ b/input/input-motionprediction/build.gradle
@@ -0,0 +1,48 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import androidx.build.LibraryType
+
+plugins {
+    id("AndroidXPlugin")
+    id("com.android.library")
+}
+
+dependencies {
+    api("androidx.annotation:annotation:1.2.0")
+
+    androidTestImplementation(libs.testExtJunit)
+    androidTestImplementation(libs.testCore)
+    androidTestImplementation(libs.testRunner)
+    androidTestImplementation(libs.testRules)
+    androidTestImplementation(libs.espressoCore, excludes.espresso)
+}
+
+android {
+    defaultConfig {
+        minSdkVersion 16
+    }
+    namespace "androidx.input.motionprediction"
+}
+
+androidx {
+    name = "Android Motion Prediction"
+    type = LibraryType.PUBLISHED_LIBRARY
+    mavenVersion = LibraryVersions.INPUT_MOTIONPREDICTION
+    mavenGroup = LibraryGroups.INPUT
+    inceptionYear = "2022"
+    description = "reduce the latency of input interactions by predicting future MotionEvents"
+}
diff --git a/input/input-motionprediction/src/main/AndroidManifest.xml b/input/input-motionprediction/src/main/AndroidManifest.xml
new file mode 100644
index 0000000..e4e6dc1f
--- /dev/null
+++ b/input/input-motionprediction/src/main/AndroidManifest.xml
@@ -0,0 +1,17 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!-- Copyright 2022 The Android Open Source Project
+
+     Licensed under the Apache License, Version 2.0 (the "License");
+     you may not use this file except in compliance with the License.
+     You may obtain a copy of the License at
+
+          http://www.apache.org/licenses/LICENSE-2.0
+
+     Unless required by applicable law or agreed to in writing, software
+     distributed under the License is distributed on an "AS IS" BASIS,
+     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+     See the License for the specific language governing permissions and
+     limitations under the License.
+-->
+
+<manifest />
\ No newline at end of file
diff --git a/input/input-motionprediction/src/main/java/androidx/input/motionprediction/MotionEventPredictor.java b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/MotionEventPredictor.java
new file mode 100644
index 0000000..c27d293
--- /dev/null
+++ b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/MotionEventPredictor.java
@@ -0,0 +1,68 @@
+/*
+ * Copyright 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.input.motionprediction;
+
+import android.view.MotionEvent;
+import android.view.View;
+
+import androidx.annotation.NonNull;
+import androidx.annotation.Nullable;
+import androidx.input.motionprediction.kalman.KalmanMotionEventPredictor;
+
+/**
+ * There is a gap between the time a user touches the screen and the time that information is
+ * reported to the app; a motion predictor is a utility that provides predicted
+ * {@link android.view.MotionEvent}s based on the previously received ones. Obtain a new
+ * predictor instance using {@link #newInstance(android.view.View)}; put the motion events you
+ * receive into it with {@link #recordMovement(android.view.MotionEvent)}, and call
+ * {@link #predict()} to retrieve the predicted {@link android.view.MotionEvent} that would occur
+ * at the moment the next frame is rendered on the display. Once no more predictions are needed,
+ * call {@link #dispose()} to stop it and clean up resources.
+ */
+public interface MotionEventPredictor {
+    /**
+     * Records a user's movement in the predictor. This should be called for every
+     * {@link android.view.MotionEvent} that is received by the associated
+     * {@link android.view.View}.
+     *
+     * @param event the {@link android.view.MotionEvent} received by the associated view.
+     */
+    void recordMovement(@NonNull MotionEvent event);
+
+    /**
+     * Computes a prediction based on the movements recorded so far.
+     * @return the predicted {@link android.view.MotionEvent}, or null if it is not possible to
+     *     make a prediction.
+     */
+    @Nullable
+    MotionEvent predict();
+
+    /**
+     * Notifies the predictor that no more predictions are needed. Any subsequent call to
+     * {@link #predict()} will return null.
+     */
+    void dispose();
+
+    /**
+     * Creates a new motion predictor associated with a specific {@link android.view.View}.
+     * @param view the view to associate with this predictor
+     * @return the new predictor instance
+     */
+    static @NonNull MotionEventPredictor newInstance(@NonNull View view) {
+        return new KalmanMotionEventPredictor();
+    }
+}
diff --git a/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/BatchedMotionEvent.java b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/BatchedMotionEvent.java
new file mode 100644
index 0000000..6392b40
--- /dev/null
+++ b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/BatchedMotionEvent.java
@@ -0,0 +1,118 @@
+/*
+ * Copyright 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.input.motionprediction.kalman;
+
+import static androidx.annotation.RestrictTo.Scope.LIBRARY;
+
+import android.view.MotionEvent;
+
+import androidx.annotation.NonNull;
+import androidx.annotation.RestrictTo;
+
+import java.util.Iterator;
+
+/**
+ * The {@link MotionEvent.PointerCoords} of all pointers of a {@link MotionEvent} at a given time.
+ *
+ * @hide
+ */
+@RestrictTo(LIBRARY)
+public class BatchedMotionEvent {
+    /**
+     * Pointer coordinate data as per {@link MotionEvent#getPointerCoords}, indexed by pointer
+     * index, for the sample this object represents (a historical sample or the current one of
+     * the source event).
+     */
+    public final MotionEvent.PointerCoords[] coords;
+    /**
+     * The time this sample occurred in the {@link android.os.SystemClock#uptimeMillis} time base.
+     */
+    public long timeMs;
+
+    public BatchedMotionEvent(int pointerCount) {
+        coords = new MotionEvent.PointerCoords[pointerCount];
+        for (int i = 0; i < pointerCount; ++i) {
+            coords[i] = new MotionEvent.PointerCoords();
+        }
+    }
+
+    /**
+     * Creates an {@link Iterable} over every sample of {@code ev}: its historical samples in
+     * history-index order, followed by its current sample.
+     */
+    public static @NonNull IterableMotionEvent iterate(@NonNull MotionEvent ev) {
+        return new IterableMotionEvent(ev);
+    }
+
+    /** An {@link Iterable} view over the samples of a single {@link MotionEvent}. */
+    public static class IterableMotionEvent implements Iterable<BatchedMotionEvent> {
+        private final int mPointerCount;
+        private final MotionEvent mMotionEvent;
+
+        IterableMotionEvent(@NonNull MotionEvent motionEvent) {
+            mMotionEvent = motionEvent;
+            mPointerCount = motionEvent.getPointerCount();
+        }
+
+        public @NonNull MotionEvent getMotionEvent() {
+            return mMotionEvent;
+        }
+
+        public int getPointerCount() {
+            return mPointerCount;
+        }
+
+        @Override
+        @NonNull
+        public Iterator<BatchedMotionEvent> iterator() {
+            return new Iterator<BatchedMotionEvent>() {
+                private int mHistoryId = 0;
+
+                @Override
+                public boolean hasNext() {
+                    return mHistoryId < (getMotionEvent().getHistorySize() + 1);
+                }
+
+                @Override
+                public BatchedMotionEvent next() {
+                    MotionEvent motionEvent = getMotionEvent();
+                    int pointerCount = getPointerCount();
+                    // Iterator contract: signal exhaustion with an exception, never with null.
+                    if (mHistoryId > motionEvent.getHistorySize()) {
+                        throw new java.util.NoSuchElementException();
+                    }
+                    BatchedMotionEvent batchedEvent = new BatchedMotionEvent(pointerCount);
+                    if (mHistoryId < motionEvent.getHistorySize()) {
+                        for (int pointerIndex = 0; pointerIndex < pointerCount; ++pointerIndex) {
+                            motionEvent.getHistoricalPointerCoords(
+                                    pointerIndex, mHistoryId, batchedEvent.coords[pointerIndex]);
+                        }
+                        batchedEvent.timeMs = motionEvent.getHistoricalEventTime(mHistoryId);
+                    } else { // mHistoryId == getHistorySize(): the current (latest) sample.
+                        for (int pointerIndex = 0; pointerIndex < pointerCount; ++pointerIndex) {
+                            motionEvent.getPointerCoords(
+                                    pointerIndex, batchedEvent.coords[pointerIndex]);
+                        }
+                        batchedEvent.timeMs = motionEvent.getEventTime();
+                    }
+                    mHistoryId++;
+                    return batchedEvent;
+                }
+            };
+        }
+    }
+}
diff --git a/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/InkPredictor.java b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/InkPredictor.java
new file mode 100644
index 0000000..98f4a44
--- /dev/null
+++ b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/InkPredictor.java
@@ -0,0 +1,50 @@
+/*
+ * Copyright 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.input.motionprediction.kalman;
+
+import static androidx.annotation.RestrictTo.Scope.LIBRARY;
+
+import android.view.MotionEvent;
+
+import androidx.annotation.NonNull;
+import androidx.annotation.Nullable;
+import androidx.annotation.RestrictTo;
+
+/**
+ * Predicts future ink points from the motion events reported to it.
+ *
+ * @hide
+ */
+@RestrictTo(LIBRARY)
+public interface InkPredictor {
+    /** Returns the current prediction target, in milliseconds. */
+    int getPredictionTarget();
+
+    /** Sets how far into the future, in milliseconds, predictions should reach. */
+    void setPredictionTarget(int predictionTargetMillis);
+
+    /** Sets the interval, in milliseconds, between two consecutive input reports. */
+    void setReportRate(int reportRateMs);
+
+    /** Reports a motion event; returns false if the implementation did not record it. */
+    boolean onTouchEvent(@NonNull MotionEvent event);
+
+    /**
+     * Returns the predicted motion event, or null if not possible to make a prediction.
+     */
+    @Nullable MotionEvent predict();
+}
diff --git a/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/KalmanFilter.java b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/KalmanFilter.java
new file mode 100644
index 0000000..cfcec00
--- /dev/null
+++ b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/KalmanFilter.java
@@ -0,0 +1,94 @@
+/*
+ * Copyright 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.input.motionprediction.kalman;
+
+import static androidx.annotation.RestrictTo.Scope.LIBRARY;
+
+import androidx.annotation.NonNull;
+import androidx.annotation.RestrictTo;
+import androidx.input.motionprediction.kalman.matrix.Matrix;
+
+/**
+ * Linear Kalman filter implementation following http://filterpy.readthedocs.io/en/latest/
+ *
+ * <p>Fields use the single-letter matrix names of the usual Kalman filter notation rather than
+ * the android naming conventions, so the code reads like the reference equations.
+ *
+ * <p>The state dimension ({@code xDim}) and measurement dimension ({@code zDim}) are set by the
+ * caller at construction time.
+ *
+ * @hide
+ */
+@RestrictTo(LIBRARY)
+public class KalmanFilter {
+    // State estimate, (xDim, 1)
+    public @NonNull Matrix x;
+
+    // State estimate covariance, (xDim, xDim)
+    public @NonNull Matrix P;
+
+    // Process noise covariance, (xDim, xDim)
+    public @NonNull Matrix Q;
+
+    // Measurement noise covariance, (zDim, zDim)
+    public @NonNull Matrix R;
+
+    // State transition matrix, (xDim, xDim)
+    public @NonNull Matrix F;
+
+    // Measurement matrix, (zDim, xDim)
+    public @NonNull Matrix H;
+
+    // Kalman gain, (xDim, zDim)
+    public @NonNull Matrix K;
+
+    public KalmanFilter(int xDim, int zDim) {
+        x = new Matrix(xDim, 1);
+        P = Matrix.identity(xDim);
+        Q = Matrix.identity(xDim);
+        R = Matrix.identity(zDim);
+        F = new Matrix(xDim, xDim);
+        H = new Matrix(zDim, xDim);
+        K = new Matrix(xDim, zDim);
+    }
+
+    /** Resets the internal state of this Kalman filter. */
+    public void reset() {
+        // NOTE: It is not necessary to reset Q, R, F, and H matrices.
+        x.fill(0);
+        Matrix.setIdentity(P);
+        K.fill(0);
+    }
+
+    /**
+     * Performs the prediction phase of the filter, using the state estimate to produce a new
+     * estimate for the current timestep.
+     */
+    public void predict() {
+        x = F.dot(x);
+        P = F.dot(P).dotTranspose(F).plus(Q);
+    }
+
+    /** Updates the state estimate to incorporate the new observation z. */
+    public void update(@NonNull Matrix z) {
+        Matrix y = z.minus(H.dot(x)); // Residual between the observation and the prediction.
+        Matrix tS = H.dot(P).dotTranspose(H).plus(R); // System uncertainty (S in filterpy).
+        K = P.dotTranspose(H).dot(tS.inverse());
+        x = x.plus(K.dot(y));
+        P = P.minus(K.dot(H).dot(P));
+    }
+}
diff --git a/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/KalmanInkPredictor.java b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/KalmanInkPredictor.java
new file mode 100644
index 0000000..73c1f4a
--- /dev/null
+++ b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/KalmanInkPredictor.java
@@ -0,0 +1,326 @@
+/*
+ * Copyright 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.input.motionprediction.kalman;
+
+import static androidx.annotation.RestrictTo.Scope.LIBRARY;
+
+import android.util.Log;
+import android.view.MotionEvent;
+
+import androidx.annotation.NonNull;
+import androidx.annotation.Nullable;
+import androidx.annotation.RestrictTo;
+import androidx.input.motionprediction.kalman.matrix.DVector2;
+
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Locale;
+
+/**
+ * Kalman filter based {@link InkPredictor} that tracks a single pointer of a gesture.
+ *
+ * @hide
+ */
+@RestrictTo(LIBRARY)
+public class KalmanInkPredictor implements InkPredictor {
+    private static final String TAG = "KalmanInkPredictor";
+
+    // Influence of jank during each prediction sample
+    private static final float JANK_INFLUENCE = 0.1f;
+
+    // Influence of acceleration during each prediction sample
+    private static final float ACCELERATION_INFLUENCE = 0.5f;
+
+    // Influence of velocity during each prediction sample
+    private static final float VELOCITY_INFLUENCE = 1.0f;
+
+    // Range of jank values to expect.
+    // Low value will use maximum prediction, high value will use no prediction.
+    private static final float LOW_JANK = 0.02f;
+    private static final float HIGH_JANK = 0.2f;
+
+    // Range of pen speed to expect (in dp / ms).
+    // Low value will not use prediction, high value will use full prediction.
+    private static final float LOW_SPEED = 0.0f;
+    private static final float HIGH_SPEED = 2.0f;
+
+    private static final int EVENT_TIME_IGNORED_THRESHOLD_MS = 20;
+
+    // Minimum number of Kalman filter samples needed for predicting the next point
+    private static final int MIN_KALMAN_FILTER_ITERATIONS = 4;
+
+    // Target time in milliseconds to predict.
+    private float mPredictionTargetMs = 0.0f;
+
+    // The Kalman filter is tuned to smooth noise while maintaining fast reaction to direction
+    // changes. The stronger the filter, the smoother the prediction result will be, at the
+    // cost of possible prediction errors.
+    private final PenKalmanFilter mKalman = new PenKalmanFilter(0.01, 1.0);
+
+    private final DVector2 mLastPosition = new DVector2();
+    private long mPrevEventTime;
+    private List<Float> mReportRates = new LinkedList<>(); // null once setReportRate() is called.
+    private int mExpectedPredictionSampleSize = -1; // -1 until setReportRate() is called.
+    private float mReportRateMs = 0; // Interval between input reports in ms; 0 if unknown.
+
+    private final DVector2 mPosition = new DVector2();
+    private final DVector2 mVelocity = new DVector2();
+    private final DVector2 mAcceleration = new DVector2();
+    private final DVector2 mJank = new DVector2();
+
+    /* Pointer of the gesture that requires prediction. */
+    private int mPointerId = 0;
+
+    private double mPressure = 0;
+
+    /**
+     * Kalman based ink predictor, predicting the location of the pen
+     * {@link #setPredictionTarget prediction target} milliseconds into the future.
+     *
+     * <p>This filter can provide solid prediction up to 25ms into the future. If you are not
+     * achieving close-to-zero latency, prediction errors can be more visible and the target should
+     * be reduced to 20ms.
+     */
+    public KalmanInkPredictor() {
+        mKalman.reset();
+        mPrevEventTime = 0;
+    }
+
+    /** Resets the filter and starts tracking the pointer with id {@code pointerId}. */
+    void initStrokePrediction(int pointerId) {
+        mKalman.reset();
+        mPrevEventTime = 0;
+        mPointerId = pointerId;
+    }
+
+    /** Feeds a pointer sample to the filter unless it repeats the previous one. */
+    private void update(float x, float y, float pressure, long eventTime) {
+        if (x == mLastPosition.a1
+                && y == mLastPosition.a2
+                && (eventTime <= (mPrevEventTime + EVENT_TIME_IGNORED_THRESHOLD_MS))) {
+            // Reduce Kalman filter jank by ignoring an input event with the same coordinates as,
+            // and an eventTime close to, the previous input event. This is particularly useful
+            // when multiple pointers are on screen, as the application then simultaneously
+            // receives multiple ACTION_MOVE MotionEvents with unchanged position and eventTime.
+            // This behavior happens only in ARC++ and is likely due to the Chrome Aura
+            // implementation.
+            return;
+        }
+
+        mKalman.update(x, y, pressure);
+        mLastPosition.a1 = x;
+        mLastPosition.a2 = y;
+
+        // Estimate the report interval as the mean of the first 20 intervals between events; most
+        // sensors neither timestamp reliably nor report evenly, so this is only an estimate.
+        if (mReportRates != null && mReportRates.size() < 20) {
+            if (mPrevEventTime > 0) {
+                float dt = eventTime - mPrevEventTime;
+                mReportRates.add(dt);
+                float sum = 0;
+                for (float rate : mReportRates) {
+                    sum += rate;
+                }
+                mReportRateMs = sum / mReportRates.size();
+            }
+        }
+        mPrevEventTime = eventTime;
+    }
+
+    @Override
+    public int getPredictionTarget() {
+        // Prediction target should always be an int, so no precision lost in the cast
+        return (int) mPredictionTargetMs;
+    }
+
+    @Override
+    public void setPredictionTarget(int predictionTargetMillis) {
+        if (predictionTargetMillis < 0) {
+            predictionTargetMillis = 0;
+        }
+        mPredictionTargetMs = predictionTargetMillis;
+        if (mReportRates == null) {
+            mExpectedPredictionSampleSize = (int) Math.ceil(mPredictionTargetMs / mReportRateMs);
+        }
+    }
+
+    @Override
+    public void setReportRate(int reportRateMs) {
+        if (reportRateMs <= 0) {
+            throw new IllegalArgumentException(
+                    "reportRateMs should always be a strictly positive number");
+        }
+        mReportRateMs = reportRateMs;
+        // A fixed rate disables the running estimate computed in update().
+        mReportRates = null;
+        mExpectedPredictionSampleSize = (int) Math.ceil(mPredictionTargetMs / mReportRateMs);
+    }
+
+    @Override
+    public boolean onTouchEvent(@NonNull MotionEvent event) {
+        if (event.getActionMasked() == MotionEvent.ACTION_CANCEL) {
+            mKalman.reset();
+            mPrevEventTime = 0;
+            return false;
+        }
+        int pointerIndex = event.findPointerIndex(mPointerId);
+        if (pointerIndex == -1) {
+            Log.i(TAG, String.format(Locale.ROOT,
+                    "onTouchEvent: Cannot find pointerId=%d in motionEvent=%s",
+                    mPointerId, event));
+            return false;
+        }
+        for (BatchedMotionEvent ev : BatchedMotionEvent.iterate(event)) {
+            MotionEvent.PointerCoords pointerCoords = ev.coords[pointerIndex];
+            update(pointerCoords.x, pointerCoords.y, pointerCoords.pressure, ev.timeMs);
+        }
+        return true;
+    }
+
+    @Override
+    public @Nullable MotionEvent predict() {
+        // Without a positive report interval the sample arithmetic below would divide by zero.
+        if (mReportRateMs <= 0 || (mExpectedPredictionSampleSize == -1
+                && mKalman.getNumIterations() < MIN_KALMAN_FILTER_ITERATIONS)) {
+            return null;
+        }
+
+        mPosition.set(mLastPosition);
+        mVelocity.set(mKalman.getVelocity());
+        mAcceleration.set(mKalman.getAcceleration());
+        mJank.set(mKalman.getJank());
+
+        mPressure = mKalman.getPressure();
+        double pressureChange = mKalman.getPressureChange();
+
+        // Adjust prediction distance based on confidence of mKalman filter as well as movement
+        // speed.
+        double speedAbs = mVelocity.magnitude() / mReportRateMs;
+        double speedFactor = normalizeRange(speedAbs, LOW_SPEED, HIGH_SPEED);
+        double jankAbs = mJank.magnitude();
+        double jankFactor = 1.0 - normalizeRange(jankAbs, LOW_JANK, HIGH_JANK);
+        double confidenceFactor = speedFactor * jankFactor;
+
+        MotionEvent predictedEvent = null;
+        final MotionEvent.PointerProperties[] pointerProperties =
+                new MotionEvent.PointerProperties[1];
+        pointerProperties[0] = new MotionEvent.PointerProperties();
+        pointerProperties[0].id = mPointerId;
+
+        // Project physical state of the pen into the future.
+        int predictionTargetInSamples =
+                (int) Math.ceil(mPredictionTargetMs / mReportRateMs * confidenceFactor);
+
+        // Normally this is false since confidenceFactor is at most 1.0 (both factors are clamped).
+        if (mExpectedPredictionSampleSize != -1
+                && predictionTargetInSamples > mExpectedPredictionSampleSize) {
+            predictionTargetInSamples = mExpectedPredictionSampleSize;
+        }
+
+        // Each iteration extrapolates one more sample and appends it to predictedEvent.
+        for (int i = 0; i < predictionTargetInSamples; i++) {
+            mAcceleration.a1 += mJank.a1 * JANK_INFLUENCE;
+            mAcceleration.a2 += mJank.a2 * JANK_INFLUENCE;
+            mVelocity.a1 += mAcceleration.a1 * ACCELERATION_INFLUENCE;
+            mVelocity.a2 += mAcceleration.a2 * ACCELERATION_INFLUENCE;
+            mPosition.a1 += mVelocity.a1 * VELOCITY_INFLUENCE;
+            mPosition.a2 += mVelocity.a2 * VELOCITY_INFLUENCE;
+            mPressure += pressureChange;
+
+            // Abort prediction if the pen is to be lifted.
+            if (mPressure < 0.1) {
+                //TODO: Should we generate ACTION_UP MotionEvent instead of ACTION_MOVE?
+                break;
+            }
+            mPressure = Math.min(mPressure, 1.0f);
+
+            MotionEvent.PointerCoords[] coords = {new MotionEvent.PointerCoords()};
+            coords[0].x = (float) mPosition.a1;
+            coords[0].y = (float) mPosition.a2;
+            coords[0].pressure = (float) mPressure;
+            if (predictedEvent == null) {
+                predictedEvent =
+                        MotionEvent.obtain(
+                                0 /* downTime */,
+                                0 /* eventTime */,
+                                MotionEvent.ACTION_MOVE /* action */,
+                                1 /* pointerCount */,
+                                pointerProperties /* pointer properties */,
+                                coords /* pointerCoords */,
+                                0 /* metaState */,
+                                0 /* button state */,
+                                1.0f /* xPrecision */,
+                                1.0f /* yPrecision */,
+                                0 /* deviceId */,
+                                0 /* edgeFlags */,
+                                0 /* source */,
+                                0 /* flags */);
+            } else {
+                predictedEvent.addBatch(0, coords, 0);
+            }
+        }
+
+        return predictedEvent;
+    }
+
+    private static double normalizeRange(double x, double min, double max) {
+        double normalized = (x - min) / (max - min);
+        return Math.min(1.0, Math.max(normalized, 0.0));
+    }
+
+    /**
+     * Pads {@code predictedEvent} with samples repeating the last predicted position and pressure
+     * until it holds the expected number of samples; returns the (possibly newly created) event.
+     *
+     * @param predictedEvent the event returned by {@link #predict()}, possibly null
+     */
+    protected @Nullable MotionEvent appendPredictedEvent(@Nullable MotionEvent predictedEvent) {
+        // A MotionEvent holds getHistorySize() historical samples plus its current sample.
+        int sampleCount = (predictedEvent == null) ? 0 : predictedEvent.getHistorySize() + 1;
+        for (int i = sampleCount; i < mExpectedPredictionSampleSize; i++) {
+            MotionEvent.PointerCoords[] coords = {new MotionEvent.PointerCoords()};
+            coords[0].x = (float) mPosition.a1;
+            coords[0].y = (float) mPosition.a2;
+            coords[0].pressure = (float) mPressure;
+            if (predictedEvent == null) {
+                final MotionEvent.PointerProperties[] pointerProperties =
+                        new MotionEvent.PointerProperties[1];
+                pointerProperties[0] = new MotionEvent.PointerProperties();
+                pointerProperties[0].id = mPointerId;
+                predictedEvent =
+                        MotionEvent.obtain(
+                                0 /* downTime */,
+                                0 /* eventTime */,
+                                MotionEvent.ACTION_MOVE /* action */,
+                                1 /* pointerCount */,
+                                pointerProperties /* pointer properties */,
+                                coords /* pointerCoords */,
+                                0 /* metaState */,
+                                0 /* buttonState */,
+                                1.0f /* xPrecision */,
+                                1.0f /* yPrecision */,
+                                0 /* deviceId */,
+                                0 /* edgeFlags */,
+                                0 /* source */,
+                                0 /* flags */);
+            } else {
+                predictedEvent.addBatch(0, coords, 0);
+            }
+        }
+        return predictedEvent;
+    }
+}
diff --git a/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/KalmanMotionEventPredictor.java b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/KalmanMotionEventPredictor.java
new file mode 100644
index 0000000..e5abbd9
--- /dev/null
+++ b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/KalmanMotionEventPredictor.java
@@ -0,0 +1,78 @@
+/*
+ * Copyright 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.input.motionprediction.kalman;
+
+import static androidx.annotation.RestrictTo.Scope.LIBRARY;
+
+import android.view.MotionEvent;
+
+import androidx.annotation.NonNull;
+import androidx.annotation.Nullable;
+import androidx.annotation.RestrictTo;
+import androidx.input.motionprediction.MotionEventPredictor;
+
+/**
+ * A {@link MotionEventPredictor} that delegates to a {@link MultiPointerPredictor}, which keeps
+ * one {@code KalmanInkPredictor} per active pointer.
+ *
+ * <p>Once {@link #dispose()} has been called, {@link #predict()} always returns {@code null} and
+ * {@link #recordMovement(MotionEvent)} ignores its input.
+ *
+ * @hide
+ */
+@RestrictTo(LIBRARY)
+public class KalmanMotionEventPredictor implements MotionEventPredictor {
+    private final MultiPointerPredictor mMultiPointerPredictor;
+    private boolean mDisposed = false;
+
+    public KalmanMotionEventPredictor() {
+        mMultiPointerPredictor = new MultiPointerPredictor();
+        // 1 may seem arbitrary, but this basically tells the predictor to
+        // just predict the next MotionEvent.
+        // This will need to change as we want to build a prediction depending
+        // on the expected time that the frame will arrive to the screen.
+        mMultiPointerPredictor.setPredictionTarget(1);
+    }
+
+    /**
+     * Feeds {@code event} to the underlying predictor. Does nothing after {@link #dispose()}: a
+     * disposed predictor never produces predictions again, so recording would be wasted work.
+     */
+    @Override
+    public void recordMovement(@NonNull MotionEvent event) {
+        if (mDisposed) {
+            return;
+        }
+        mMultiPointerPredictor.onTouchEvent(event);
+    }
+
+    /** Returns the predicted event, or {@code null} after disposal or if none is available. */
+    @Nullable
+    @Override
+    public MotionEvent predict() {
+        if (mDisposed) {
+            return null;
+        }
+        return mMultiPointerPredictor.predict();
+    }
+
+    /** Permanently disables this predictor; it cannot be re-enabled. */
+    @Override
+    public void dispose() {
+        mDisposed = true;
+    }
+}
diff --git a/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/MultiPointerPredictor.java b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/MultiPointerPredictor.java
new file mode 100644
index 0000000..d223ddb
--- /dev/null
+++ b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/MultiPointerPredictor.java
@@ -0,0 +1,296 @@
+/*
+ * Copyright 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.input.motionprediction.kalman;
+
+import static androidx.annotation.RestrictTo.Scope.LIBRARY;
+
+import android.util.Log;
+import android.util.SparseArray;
+import android.view.MotionEvent;
+
+import androidx.annotation.NonNull;
+import androidx.annotation.Nullable;
+import androidx.annotation.RestrictTo;
+
+import java.util.Locale;
+
+/**
+ * {@link InkPredictor} that gives every active pointer of a gesture its own
+ * {@link KalmanInkPredictor} and merges their predictions into a single multi-pointer
+ * {@link MotionEvent}.
+ *
+ * @hide
+ */
+@RestrictTo(LIBRARY)
+public class MultiPointerPredictor implements InkPredictor {
+    private static final String TAG = "MultiPointerPredictor";
+    private static final boolean DEBUG_PREDICTION = Log.isLoggable(TAG, Log.DEBUG);
+
+    /** One predictor per active pointer, keyed by pointer id. */
+    private final SparseArray<KalmanInkPredictor> mPredictorMap = new SparseArray<>();
+    private int mPredictionTargetMs = 0;
+    private int mReportRateMs = 0;
+
+    public MultiPointerPredictor() {}
+
+    @Override
+    public int getPredictionTarget() {
+        return mPredictionTargetMs;
+    }
+
+    /**
+     * Sets how far ahead to predict, in milliseconds, for current and future pointers. Negative
+     * values are clamped to 0.
+     */
+    @Override
+    public void setPredictionTarget(int predictionTargetMillis) {
+        if (predictionTargetMillis < 0) {
+            predictionTargetMillis = 0;
+        }
+        mPredictionTargetMs = predictionTargetMillis;
+
+        for (int i = 0; i < mPredictorMap.size(); ++i) {
+            mPredictorMap.valueAt(i).setPredictionTarget(predictionTargetMillis);
+        }
+    }
+
+    /**
+     * Sets the report rate, in milliseconds, for current and future pointers.
+     *
+     * @throws IllegalArgumentException if {@code reportRateMs} is not strictly positive
+     */
+    @Override
+    public void setReportRate(int reportRateMs) {
+        if (reportRateMs <= 0) {
+            throw new IllegalArgumentException(
+                    "reportRateMs should always be a strictly positive number");
+        }
+        mReportRateMs = reportRateMs;
+
+        for (int i = 0; i < mPredictorMap.size(); ++i) {
+            mPredictorMap.valueAt(i).setReportRate(mReportRateMs);
+        }
+    }
+
+    /**
+     * Routes {@code event} to the per-pointer predictors: a predictor is created when a pointer
+     * goes down and discarded when it goes up or the gesture is cancelled.
+     *
+     * @return {@code true} if the event was handled, {@code false} for ignored action types
+     */
+    @Override
+    public boolean onTouchEvent(@NonNull MotionEvent event) {
+        final int action = event.getActionMasked();
+        final int pointerId = event.getPointerId(event.getActionIndex());
+        switch (action) {
+            case MotionEvent.ACTION_DOWN:
+                // A new gesture starts: drop predictors left over from a previous gesture whose
+                // UP/CANCEL was never delivered.
+                mPredictorMap.clear();
+                startPointer(pointerId, event);
+                break;
+            case MotionEvent.ACTION_POINTER_DOWN:
+                startPointer(pointerId, event);
+                break;
+            case MotionEvent.ACTION_UP:
+                endPointer(pointerId, event);
+                // The last pointer left the screen, so the gesture is over.
+                mPredictorMap.clear();
+                break;
+            case MotionEvent.ACTION_POINTER_UP:
+                endPointer(pointerId, event);
+                break;
+            case MotionEvent.ACTION_CANCEL:
+                mPredictorMap.clear();
+                break;
+            case MotionEvent.ACTION_MOVE:
+                for (int i = 0; i < mPredictorMap.size(); ++i) {
+                    mPredictorMap.valueAt(i).onTouchEvent(event);
+                }
+                break;
+            default:
+                // ignore other events
+                return false;
+        }
+        return true;
+    }
+
+    /** Creates, configures and registers a predictor for a pointer that just went down. */
+    private void startPointer(int pointerId, @NonNull MotionEvent event) {
+        KalmanInkPredictor predictor = new KalmanInkPredictor();
+        predictor.setPredictionTarget(mPredictionTargetMs);
+        predictor.setReportRate(mReportRateMs);
+        predictor.initStrokePrediction(pointerId);
+        predictor.onTouchEvent(event);
+        mPredictorMap.put(pointerId, predictor);
+    }
+
+    /** Unregisters the predictor of a pointer that went up and feeds it that final event. */
+    private void endPointer(int pointerId, @NonNull MotionEvent event) {
+        KalmanInkPredictor predictor = mPredictorMap.get(pointerId);
+        if (predictor != null) {
+            mPredictorMap.remove(pointerId);
+            predictor.onTouchEvent(event);
+        }
+    }
+
+    /**
+     * Predicts the upcoming motion of every active pointer.
+     *
+     * @return one {@link MotionEvent} with a pointer per active predictor, or {@code null} when no
+     *     pointer is down or a prediction is unavailable
+     */
+    @Override
+    public @Nullable MotionEvent predict() {
+        final int pointerCount = mPredictorMap.size();
+        // Shortcut for likely case where only zero or one pointer is on the screen
+        // this logic exists only to make sure logic when one pointer is on screen then
+        // there is no performance degradation of using MultiPointerPredictor vs KalmanInkPredictor
+        // TODO: verify performance is not degraded by removing this shortcut logic.
+        if (pointerCount == 0) {
+            if (DEBUG_PREDICTION) {
+                Log.d(TAG, "predict() -> null: no pointer on screen");
+            }
+            return null;
+        }
+        if (pointerCount == 1) {
+            KalmanInkPredictor predictor = mPredictorMap.valueAt(0);
+            MotionEvent predictedEv = predictor.predict();
+            if (DEBUG_PREDICTION) {
+                Log.d(TAG, "predict() -> MotionEvent: " + predictedEv);
+            }
+            return predictedEv;
+        }
+
+        // Predict MotionEvent for each pointer
+        int[] pointerIds = new int[pointerCount];
+        MotionEvent[] singlePointerEvents = new MotionEvent[pointerCount];
+        for (int i = 0; i < pointerCount; ++i) {
+            pointerIds[i] = mPredictorMap.keyAt(i);
+            KalmanInkPredictor predictor = mPredictorMap.valueAt(i);
+            singlePointerEvents[i] = predictor.predict();
+            // If predictor consumer expect more sample, generate sample where position and
+            // pressure are constant
+            singlePointerEvents[i] = predictor.appendPredictedEvent(singlePointerEvents[i]);
+            if (singlePointerEvents[i] == null) {
+                // Without a prediction for every pointer there is nothing consistent to merge.
+                recycleAll(singlePointerEvents);
+                if (DEBUG_PREDICTION) {
+                    Log.d(TAG, "predict() -> null: no prediction for pointer " + pointerIds[i]);
+                }
+                return null;
+            }
+        }
+
+        final MotionEvent multiPointerEvent =
+                mergeSinglePointerEvents(pointerIds, singlePointerEvents);
+        if (DEBUG_PREDICTION) {
+            logMultiPointerPrediction(multiPointerEvent);
+        }
+        return multiPointerEvent;
+    }
+
+    /**
+     * Merges single-pointer predictions into one multi-pointer {@link MotionEvent} whose sample
+     * count is that of the shortest prediction, then recycles the single-pointer events.
+     */
+    private static MotionEvent mergeSinglePointerEvents(
+            int[] pointerIds, MotionEvent[] singlePointerEvents) {
+        final int pointerCount = pointerIds.length;
+
+        // Compute minimal history size for every predicted single pointer MotionEvent
+        int minHistorySize = Integer.MAX_VALUE;
+        for (MotionEvent ev : singlePointerEvents) {
+            minHistorySize = Math.min(minHistorySize, ev.getHistorySize());
+        }
+        // Take into account the current event of each predicted MotionEvent
+        final int sampleCount = minHistorySize + 1;
+
+        // Merge single pointer MotionEvent into a single MotionEvent
+        final MotionEvent.PointerCoords[][] pointerCoords =
+                new MotionEvent.PointerCoords[sampleCount][pointerCount];
+        for (int pointerIndex = 0; pointerIndex < pointerCount; pointerIndex++) {
+            int historyIndex = 0;
+            for (BatchedMotionEvent ev :
+                    BatchedMotionEvent.iterate(singlePointerEvents[pointerIndex])) {
+                pointerCoords[historyIndex][pointerIndex] = ev.coords[0];
+                if (sampleCount <= ++historyIndex) {
+                    break;
+                }
+            }
+        }
+
+        // Recycle single pointer predicted MotionEvent
+        recycleAll(singlePointerEvents);
+
+        // Generate predicted multi-pointer MotionEvent
+        final MotionEvent.PointerProperties[] pointerProperties =
+                new MotionEvent.PointerProperties[pointerCount];
+        for (int i = 0; i < pointerCount; i++) {
+            pointerProperties[i] = new MotionEvent.PointerProperties();
+            pointerProperties[i].id = pointerIds[i];
+        }
+        final MotionEvent multiPointerEvent =
+                MotionEvent.obtain(
+                        0 /* down time */,
+                        0 /* event time */,
+                        MotionEvent.ACTION_MOVE /* action */,
+                        pointerCount /* pointer count */,
+                        pointerProperties /* pointer properties */,
+                        pointerCoords[0] /* pointer coordinates */,
+                        0 /* meta state */,
+                        0 /* button state */,
+                        1.0f /* x precision */,
+                        1.0f /* y precision */,
+                        0 /* device ID */,
+                        0 /* edge flags */,
+                        0 /* source */,
+                        0 /* flags */);
+        for (int historyIndex = 1; historyIndex < sampleCount; historyIndex++) {
+            multiPointerEvent.addBatch(0, pointerCoords[historyIndex], 0);
+        }
+        return multiPointerEvent;
+    }
+
+    /** Recycles every non-null event in {@code events}. */
+    private static void recycleAll(MotionEvent[] events) {
+        for (MotionEvent ev : events) {
+            if (ev != null) {
+                ev.recycle();
+            }
+        }
+    }
+
+    /** Logs the pointer count, history size and coordinates of a merged prediction. */
+    private static void logMultiPointerPrediction(MotionEvent multiPointerEvent) {
+        final StringBuilder builder =
+                new StringBuilder(
+                        String.format(
+                                Locale.ROOT,
+                                "predict() -> MotionEvent: (pointerCount=%d, historySize=%d);",
+                                multiPointerEvent.getPointerCount(),
+                                multiPointerEvent.getHistorySize()));
+        for (BatchedMotionEvent motionEvent : BatchedMotionEvent.iterate(multiPointerEvent)) {
+            builder.append("      ");
+            for (MotionEvent.PointerCoords coord : motionEvent.coords) {
+                builder.append(String.format(Locale.ROOT, "(%f, %f)", coord.x, coord.y));
+            }
+            builder.append("\n");
+        }
+        Log.d(TAG, builder.toString());
+    }
+}
diff --git a/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/PenKalmanFilter.java b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/PenKalmanFilter.java
new file mode 100644
index 0000000..a70b21e
--- /dev/null
+++ b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/PenKalmanFilter.java
@@ -0,0 +1,202 @@
+/*
+ * Copyright 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.input.motionprediction.kalman;
+
+import static androidx.annotation.RestrictTo.Scope.LIBRARY;
+
+import androidx.annotation.NonNull;
+import androidx.annotation.RestrictTo;
+import androidx.input.motionprediction.kalman.matrix.DVector2;
+import androidx.input.motionprediction.kalman.matrix.Matrix;
+
+/**
+ * Tracks a pen with three independent Kalman filters: one for the x coordinate, one for the y
+ * coordinate and one for the pressure.
+ *
+ * <p>Every filter has a four-element state: the value and its first three time derivatives. The
+ * third derivative is exposed as "jank" (physically, the jerk).
+ *
+ * @hide
+ */
+@RestrictTo(LIBRARY)
+public class PenKalmanFilter {
+    // Filter configuration, fixed at construction.
+    private final double mSigmaProcess;
+    private final double mSigmaMeasurement;
+
+    // One independent filter per tracked quantity.
+    private final KalmanFilter mXKalman;
+    private final KalmanFilter mYKalman;
+    private final KalmanFilter mPKalman;
+
+    // Most recent filtered estimates, refreshed by every update() call.
+    private final DVector2 mPosition = new DVector2();
+    private final DVector2 mVelocity = new DVector2();
+    private final DVector2 mAcceleration = new DVector2();
+    private final DVector2 mJank = new DVector2();
+    private double mPressure = 0;
+    private double mPressureChange = 0;
+
+    private int mNumIterations = 0;
+
+    // Reusable 1x1 measurement holders, so that update() does not allocate.
+    private final Matrix mNewX = new Matrix(1, 1);
+    private final Matrix mNewY = new Matrix(1, 1);
+    private final Matrix mNewP = new Matrix(1, 1);
+
+    /**
+     * Creates a pen filter whose three axes share the same noise parameters.
+     *
+     * @param sigmaProcess lower value = more filtering
+     * @param sigmaMeasurement higher value = more filtering
+     */
+    public PenKalmanFilter(double sigmaProcess, double sigmaMeasurement) {
+        mSigmaProcess = sigmaProcess;
+        mSigmaMeasurement = sigmaMeasurement;
+        // createAxisKalmanFilter() reads the sigmas, so they must be assigned first.
+        mXKalman = createAxisKalmanFilter();
+        mYKalman = createAxisKalmanFilter();
+        mPKalman = createAxisKalmanFilter();
+    }
+
+    /**
+     * Resets the filters into a neutral state. The next {@link #update} call re-seeds them from
+     * its measurement; until then the getters keep returning the previous estimates.
+     */
+    public void reset() {
+        mXKalman.reset();
+        mYKalman.reset();
+        mPKalman.reset();
+        mNumIterations = 0;
+    }
+
+    /**
+     * Feeds a new measurement to the model. The resulting estimates are available through
+     * {@link #getPosition()}, {@link #getVelocity()} and the other getters.
+     */
+    public void update(float x, float y, float pressure) {
+        final boolean isFirstSample = mNumIterations == 0;
+        if (isFirstSample) {
+            // Nothing to correct yet: seed each filter with the raw measurement.
+            mXKalman.x.put(0, 0, x);
+            mYKalman.x.put(0, 0, y);
+            mPKalman.x.put(0, 0, pressure);
+        } else {
+            predictAndCorrect(mXKalman, mNewX, x);
+            predictAndCorrect(mYKalman, mNewY, y);
+            predictAndCorrect(mPKalman, mNewP, pressure);
+        }
+        ++mNumIterations;
+        copyEstimates();
+    }
+
+    /** Loads {@code measurement} into {@code holder}, then runs predict() and update(holder). */
+    private static void predictAndCorrect(
+            KalmanFilter kalman, Matrix holder, double measurement) {
+        holder.put(0, 0, measurement);
+        kalman.predict();
+        kalman.update(holder);
+    }
+
+    /** Copies the filters' state vectors into the cached estimate fields. */
+    private void copyEstimates() {
+        mPosition.a1 = mXKalman.x.get(0, 0);
+        mVelocity.a1 = mXKalman.x.get(1, 0);
+        mAcceleration.a1 = mXKalman.x.get(2, 0);
+        mJank.a1 = mXKalman.x.get(3, 0);
+        mPosition.a2 = mYKalman.x.get(0, 0);
+        mVelocity.a2 = mYKalman.x.get(1, 0);
+        mAcceleration.a2 = mYKalman.x.get(2, 0);
+        mJank.a2 = mYKalman.x.get(3, 0);
+        mPressure = mPKalman.x.get(0, 0);
+        mPressureChange = mPKalman.x.get(1, 0);
+    }
+
+    /** Returns the filtered position; the same instance is updated in place by update(). */
+    public @NonNull DVector2 getPosition() {
+        return mPosition;
+    }
+
+    /** Returns the filtered velocity (first derivative of position), updated in place. */
+    public @NonNull DVector2 getVelocity() {
+        return mVelocity;
+    }
+
+    /** Returns the filtered acceleration (second derivative of position), updated in place. */
+    public @NonNull DVector2 getAcceleration() {
+        return mAcceleration;
+    }
+
+    /** Returns the filtered third derivative of position, updated in place. */
+    public @NonNull DVector2 getJank() {
+        return mJank;
+    }
+
+    /** Returns the filtered pressure. */
+    public double getPressure() {
+        return mPressure;
+    }
+
+    /** Returns the filtered rate of change of the pressure. */
+    public double getPressureChange() {
+        return mPressureChange;
+    }
+
+    /** Returns how many measurements have been fed since construction or the last reset. */
+    public int getNumIterations() {
+        return mNumIterations;
+    }
+
+    /** Builds a constant-jerk Kalman filter for one scalar quantity. */
+    private KalmanFilter createAxisKalmanFilter() {
+        // We tune the filter with a normalized dt=1, then apply the actual report rate during
+        // prediction.
+        final double dt = 1.0;
+        final double halfDtSquared = 0.5 * dt * dt;
+        // NOTE(review): the model below calls for 1/6 * dt^3, but 0.16 is used. Kept unchanged;
+        // confirm whether the approximation is intentional.
+        final double cubicTerm = 0.16 * dt * dt * dt;
+
+        final KalmanFilter kalman = new KalmanFilter(4, 1);
+
+        // State transition matrix from basic kinematics (state = value, v, a, jank):
+        //   new_x    = x + v * dt + 1/2 * a * dt^2 + 1/6 * jank * dt^3
+        //   new_v    = v + a * dt + 1/2 * jank * dt^2
+        //   new_a    = a + jank * dt
+        //   new_jank = jank
+        kalman.F = new Matrix(4, new double[] {
+                1.0, dt, halfDtSquared, cubicTerm,
+                0.0, 1.0, dt, halfDtSquared,
+                0.0, 0.0, 1.0, dt,
+                0.0, 0.0, 0.0, 1.0
+        });
+
+        // The process noise is modeled as a random force on the pen; g maps that force onto
+        // each state component, and Q = sigmaProcess * g * g^T.
+        final Matrix g = new Matrix(1, new double[] {cubicTerm, halfDtSquared, dt, 1});
+        g.dotTranspose(g, kalman.Q);
+        kalman.Q.scale(mSigmaProcess);
+
+        // Only the value itself is measured.
+        kalman.H = new Matrix(4, new double[] {1.0, 0.0, 0.0, 0.0});
+
+        // Measurement noise is a 1-D normal distribution.
+        kalman.R.put(0, 0, mSigmaMeasurement);
+
+        return kalman;
+    }
+}
diff --git a/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/matrix/DVector2.java b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/matrix/DVector2.java
new file mode 100644
index 0000000..fffd454
--- /dev/null
+++ b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/matrix/DVector2.java
@@ -0,0 +1,61 @@
+/*
+ * Copyright 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.input.motionprediction.kalman.matrix;
+
+import static androidx.annotation.RestrictTo.Scope.LIBRARY;
+
+import androidx.annotation.NonNull;
+import androidx.annotation.RestrictTo;
+
+/**
+ * A fixed-size, mutable vector of two doubles. It can stand in for either a (2x1) or a (1x2)
+ * matrix.
+ *
+ * @hide
+ */
+@RestrictTo(LIBRARY)
+public class DVector2 {
+    /** The first element. */
+    public double a1;
+    /** The second element. */
+    public double a2;
+
+    /** Creates a vector whose elements are both zero. */
+    public DVector2() {}
+
+    /**
+     * Returns the Euclidean length of this vector. {@link Math#hypot} is used so that squaring
+     * the elements cannot overflow or underflow.
+     */
+    public double magnitude() {
+        final double first = a1;
+        final double second = a2;
+        return Math.hypot(first, second);
+    }
+
+    /**
+     * Copies both elements of {@code newValue} into this vector.
+     *
+     * @param newValue the vector whose elements are copied
+     */
+    public void set(@NonNull DVector2 newValue) {
+        final double first = newValue.a1;
+        final double second = newValue.a2;
+        a1 = first;
+        a2 = second;
+    }
+}
diff --git a/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/matrix/Matrix.java b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/matrix/Matrix.java
new file mode 100644
index 0000000..d41772e
--- /dev/null
+++ b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/matrix/Matrix.java
@@ -0,0 +1,458 @@
+/*
+ * Copyright 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.input.motionprediction.kalman.matrix;
+
+import static androidx.annotation.RestrictTo.Scope.LIBRARY;
+
+import androidx.annotation.NonNull;
+import androidx.annotation.RestrictTo;
+
+import java.util.Arrays;
+import java.util.Locale;
+
+// Based on http://androidxref.com/9.0.0_r3/xref/frameworks/opt/net/wifi/service/java/com/android/server/wifi/util/Matrix.java
+/**
+ * Utility for basic Matrix calculations.
+ *
+ * @hide
+ */
+@RestrictTo(LIBRARY)
+public class Matrix {
+
+    private final int mRows;
+    private final int mCols;
+    private final double[] mMem;
+
+    /**
+     * Allocates a {@code rows} x {@code cols} matrix whose elements all start at zero.
+     *
+     * @param rows number of rows
+     * @param cols number of columns
+     */
+    public Matrix(int rows, int cols) {
+        mMem = new double[rows * cols];
+        mCols = cols;
+        mRows = rows;
+    }
+
+    /**
+     * Creates a new matrix that wraps {@code values} (the array is not copied).
+     *
+     * @param stride the number of columns; must be positive
+     * @param values the elements in row-major order; the length must be a multiple of stride
+     * @throws IllegalArgumentException if stride is not positive or does not divide values.length
+     */
+    public Matrix(int stride, @NonNull double[] values) {
+        if (stride <= 0) {
+            throw new IllegalArgumentException("stride must be positive, was " + stride);
+        }
+        mRows = (values.length + stride - 1) / stride;
+        mCols = stride;
+        mMem = values;
+        if (mMem.length != mRows * mCols) {
+            throw new IllegalArgumentException(
+                    String.format(
+                            Locale.ROOT,
+                            "Invalid number of elements in 'values' Expected:%d Actual:%d",
+                            mRows * mCols, mMem.length));
+        }
+    }
+
+    /**
+     * Copy constructor: builds an independent matrix with the same shape and elements as
+     * {@code src}.
+     *
+     * @param src the matrix to copy from
+     */
+    public Matrix(@NonNull Matrix src) {
+        mMem = src.mMem.clone();
+        mCols = src.mCols;
+        mRows = src.mRows;
+    }
+
+    /** Returns how many rows this matrix has. */
+    public int getNumRows() {
+        return this.mRows;
+    }
+
+    /** Returns how many columns this matrix has. */
+    public int getNumCols() {
+        return this.mCols;
+    }
+
+    /**
+     * Builds a square identity matrix.
+     *
+     * @param width the number of rows and of columns
+     * @return a new {@code width} x {@code width} matrix with ones on the diagonal
+     */
+    public static @NonNull Matrix identity(int width) {
+        final Matrix eye = new Matrix(width, width);
+        setIdentity(eye);
+        return eye;
+    }
+
+    /**
+     * Overwrites {@code matrix} so that every element on the main diagonal is one and every
+     * other element is zero. A square input therefore becomes an identity matrix.
+     *
+     * @param matrix the matrix to overwrite in place
+     */
+    public static void setIdentity(@NonNull Matrix matrix) {
+        matrix.fill(0.0);
+        final int diagonalLength = Math.min(matrix.mRows, matrix.mCols);
+        for (int d = 0; d < diagonalLength; ++d) {
+            matrix.put(d, d, 1.0);
+        }
+    }
+
+    /**
+     * Reads the element stored at row {@code i}, column {@code j}.
+     *
+     * @param i row number
+     * @param j column number
+     * @return the element at (i, j)
+     * @throws IndexOutOfBoundsException if either index falls outside the matrix
+     */
+    public double get(int i, int j) {
+        final boolean inBounds = i >= 0 && i < mRows && j >= 0 && j < mCols;
+        if (!inBounds) {
+            final String message = String.format(Locale.ROOT,
+                    "Invalid matrix index value. i:%d j:%d not available in %s",
+                    i,
+                    j,
+                    shortString());
+            throw new IndexOutOfBoundsException(message);
+        }
+        return mMem[i * mCols + j];
+    }
+
+    /**
+     * Writes {@code v} into row {@code i}, column {@code j}.
+     *
+     * @param i row number
+     * @param j column number
+     * @param v the value to store
+     * @throws IndexOutOfBoundsException if either index falls outside the matrix
+     */
+    public void put(int i, int j, double v) {
+        final boolean outOfBounds = i < 0 || i >= mRows || j < 0 || j >= mCols;
+        if (outOfBounds) {
+            final String message = String.format(Locale.ROOT,
+                    "Invalid matrix index value. i:%d j:%d not available in %s",
+                    i,
+                    j,
+                    shortString());
+            throw new IndexOutOfBoundsException(message);
+        }
+        mMem[i * mCols + j] = v;
+    }
+
+    /**
+     * Overwrites every element with {@code value}.
+     *
+     * @param value the value written to each element
+     */
+    public void fill(double value) {
+        Arrays.fill(mMem, 0, mMem.length, value);
+    }
+
+    /**
+     * Multiplies every element by {@code alpha}, in place.
+     *
+     * @param alpha the factor applied to each element
+     */
+    public void scale(double alpha) {
+        // Elements are independent, so the traversal order does not affect the result.
+        for (int k = mMem.length - 1; k >= 0; --k) {
+            mMem[k] = mMem[k] * alpha;
+        }
+    }
+
+    /**
+     * Adds {@code that} to this matrix element by element, in place.
+     *
+     * @param that the other matrix
+     * @return this matrix, now holding the sum (no new matrix is allocated)
+     * @throws IllegalArgumentException if the dimensions differ
+     */
+    public @NonNull Matrix plus(@NonNull Matrix that) {
+        if (!(mRows == that.mRows && mCols == that.mCols)) {
+            throw new IllegalArgumentException(
+                    String.format(
+                            Locale.ROOT,
+                            "The matrix dimensions are not the same. this:%s that:%s",
+                            shortString(),
+                            that.shortString()));
+        }
+        for (int i = 0; i < mMem.length; i++) {
+            mMem[i] += that.mMem[i];
+        }
+        return this;
+    }
+
+    /**
+     * Subtracts {@code that} from this matrix element by element, in place.
+     *
+     * @param that the other matrix
+     * @return this matrix, now holding the difference (no new matrix is allocated)
+     * @throws IllegalArgumentException if the dimensions differ
+     */
+    public @NonNull Matrix minus(@NonNull Matrix that) {
+        if (!(mRows == that.mRows && mCols == that.mCols)) {
+            throw new IllegalArgumentException(
+                    String.format(
+                            Locale.ROOT,
+                            "The matrix dimensions are not the same. this:%s that:%s",
+                            shortString(),
+                            that.shortString()));
+        }
+        for (int i = 0; i < mMem.length; i++) {
+            mMem[i] -= that.mMem[i];
+        }
+        return this;
+    }
+
+    /**
+     * Calculates the matrix product of this matrix and {@code that}, into a new matrix.
+     *
+     * @param that the other matrix
+     * @return newly created matrix representing the matrix product of this and that
+     * @throws IllegalArgumentException if this.cols != that.rows; the cause has the full detail
+     */
+    public @NonNull Matrix dot(@NonNull Matrix that) {
+        try {
+            return dot(that, new Matrix(mRows, that.mCols));
+        } catch (IllegalArgumentException e) {
+            throw new IllegalArgumentException(
+                    String.format(
+                            Locale.ROOT,
+                            "The matrices dimensions are not conformant for a dot matrix "
+                                    + "operation. this:%s that:%s",
+                            shortString(),
+                            that.shortString()), e);
+        }
+    }
+
+    /**
+     * Calculates the matrix product of this matrix and {@code that} into {@code result}.
+     *
+     * @param that the other matrix
+     * @param result matrix to hold the result; should not alias this matrix or {@code that}
+     * @return result, filled with the matrix product
+     * @throws IllegalArgumentException if the dimensions are not conformant
+     */
+    public @NonNull Matrix dot(@NonNull Matrix that, @NonNull Matrix result) {
+        if (mRows != result.mRows || mCols != that.mRows || that.mCols != result.mCols) {
+            throw new IllegalArgumentException(
+                    String.format(
+                            Locale.ROOT,
+                            "The matrices dimensions are not conformant for a dot matrix "
+                                    + "operation. this:%s that:%s result:%s",
+                            shortString(),
+                            that.shortString(),
+                            result.shortString()));
+        }
+        for (int row = 0; row < mRows; row++) {
+            for (int col = 0; col < that.mCols; col++) {
+                double sum = 0.0;
+                for (int k = 0; k < mCols; k++) {
+                    sum += mMem[row * mCols + k] * that.mMem[k * that.mCols + col];
+                }
+                result.mMem[row * result.mCols + col] = sum;
+            }
+        }
+        return result;
+    }
+
+    /**
+     * Inverts this square matrix in place, by Gauss-Jordan elimination with partial pivoting.
+     * @return this matrix, whose elements have been overwritten with the inverse
+     * @throws ArithmeticException if a pivot is exactly zero, i.e. the matrix is singular
+     * @throws IllegalArgumentException if the matrix is not square
+     */
+    public @NonNull Matrix inverse() {
+        if (!(mRows == mCols)) {
+            throw new IllegalArgumentException(
+                    String.format(Locale.ROOT, "The matrix is not square. this:%s", shortString()));
+        }
+        final Matrix scratch = new Matrix(mRows, 2 * mCols); // augmented matrix [this | I]
+
+        for (int i = 0; i < mRows; i++) {
+            for (int j = 0; j < mCols; j++) {
+                scratch.put(i, j, get(i, j));
+                scratch.put(i, mCols + j, i == j ? 1.0 : 0.0);
+            }
+        }
+
+        for (int i = 0; i < mRows; i++) { // forward elimination, one pivot column at a time
+            int ibest = i; // partial pivoting: pick the row with the largest |value| in column i
+            double vbest = Math.abs(scratch.get(ibest, ibest));
+            for (int ii = i + 1; ii < mRows; ii++) {
+                double v = Math.abs(scratch.get(ii, i));
+                if (v > vbest) {
+                    ibest = ii;
+                    vbest = v;
+                }
+            }
+            if (ibest != i) { // swap the pivot row into row i
+                for (int j = 0; j < scratch.mCols; j++) {
+                    double t = scratch.get(i, j);
+                    scratch.put(i, j, scratch.get(ibest, j));
+                    scratch.put(ibest, j, t);
+                }
+            }
+            double d = scratch.get(i, i);
+            if (d == 0.0) { // exact test: nearly singular input is not rejected
+                throw new ArithmeticException("Singular matrix");
+            }
+            for (int j = 0; j < scratch.mCols; j++) { // scale the pivot row so the pivot is 1
+                scratch.put(i, j, scratch.get(i, j) / d);
+            }
+            for (int ii = i + 1; ii < mRows; ii++) { // clear column i below the pivot
+                d = scratch.get(ii, i);
+                for (int j = 0; j < scratch.mCols; j++) {
+                    scratch.put(ii, j, scratch.get(ii, j) - d * scratch.get(i, j));
+                }
+            }
+        }
+        for (int i = mRows - 1; i >= 0; i--) { // back elimination above each pivot
+            for (int ii = 0; ii < i; ii++) {
+                double d = scratch.get(ii, i);
+                for (int j = 0; j < scratch.mCols; j++) {
+                    scratch.put(ii, j, scratch.get(ii, j) - d * scratch.get(i, j));
+                }
+            }
+        }
+        for (int i = 0; i < mRows; i++) { // copy the right half (the inverse) into this matrix
+            for (int j = 0; j < mCols; j++) {
+                put(i, j, scratch.get(i, mCols + j));
+            }
+        }
+        return this;
+    }
+
+    /**
+     * Calculates the matrix product with the transpose of a second matrix.
+     *
+     * @param that the other matrix; must have the same number of columns as this matrix
+     * @return newly created matrix representing the matrix product of this and that.transpose()
+     * @throws IllegalArgumentException if shapes are not conformant (the original failure is
+     *     attached as the cause)
+     */
+    public @NonNull Matrix dotTranspose(@NonNull Matrix that) {
+        try {
+            return dotTranspose(that, new Matrix(mRows, that.mRows));
+        } catch (IllegalArgumentException e) {
+            // Re-throw with a message that omits the internally allocated result matrix, but
+            // keep the original exception as the cause so no diagnostic detail is lost.
+            throw new IllegalArgumentException(String.format(Locale.ROOT,
+                    "The matrices dimensions are not conformant for a transpose "
+                            + "operation. this:%s that:%s",
+                    shortString(), that.shortString()), e);
+        }
+    }
+
+    /**
+     * Multiplies this matrix by the transpose of {@code that}, writing into {@code result}.
+     *
+     * <p>Aliasing is not checked: passing {@code this} or {@code that} as {@code result}
+     * overwrites operand entries before they have all been read.
+     *
+     * @param that the right-hand operand, used transposed; needs as many columns as this matrix
+     * @param result destination of shape (this.rows x that.rows); every entry is overwritten
+     * @return {@code result}, holding the product of this and that.transpose()
+     * @throws IllegalArgumentException if the three shapes are not conformant
+     */
+    public @NonNull Matrix dotTranspose(@NonNull Matrix that, @NonNull Matrix result) {
+        if (mRows != result.mRows || mCols != that.mCols || that.mRows != result.mCols) {
+            throw new IllegalArgumentException(String.format(Locale.ROOT,
+                    "The matrices dimensions are not conformant for a transpose "
+                            + "operation. this:%s that:%s result:%s",
+                    shortString(), that.shortString(), result.shortString()));
+        }
+        for (int row = 0; row < mRows; row++) {
+            for (int col = 0; col < that.mRows; col++) {
+                // Dot product of row `row` of this with row `col` of that.
+                double acc = 0.0;
+                for (int k = 0; k < mCols; k++) {
+                    acc += get(row, k) * that.get(col, k);
+                }
+                result.put(row, col, acc);
+            }
+        }
+        return result;
+    }
+
+    /** Tests for equality. */
+    @Override
+    public boolean equals(Object that) {
+        if (this == that) {
+            return true;
+        }
+        if (!(that instanceof Matrix)) {
+            return false;
+        }
+        Matrix other = (Matrix) that;
+        if (mRows != other.mRows) {
+            return false;
+        }
+        if (mCols != other.mCols) {
+            return false;
+        }
+        for (int i = 0; i < mMem.length; i++) {
+            if (mMem[i] != other.mMem[i]) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    /** Calculates a hash code; -0.0 hashes like 0.0 because {@link #equals} uses {@code ==}. */
+    @Override
+    public int hashCode() {
+        int h = mRows * 101 + mCols;
+        for (double m : mMem) {
+            h = h * 37 + Double.hashCode(m == 0.0 ? 0.0 : m);
+        }
+        return h;
+    }
+
+    /**
+     * Returns a string representation of this matrix.
+     *
+     * @return string like "2x2 [a, b; c, d]"
+     */
+    @Override
+    public String toString() {
+        StringBuilder sb = new StringBuilder(mRows * mCols * 8);
+        sb.append(mRows).append("x").append(mCols).append(" [");
+        for (int i = 0; i < mMem.length; i++) {
+            if (i > 0) {
+                sb.append(i % mCols == 0 ? "; " : ", ");
+            }
+            sb.append(mMem[i]);
+        }
+        sb.append("]");
+        return sb.toString();
+    }
+
+    /** Returns the size of the matrix as a String. */
+    private String shortString() {
+        return "(" + mRows + "x" + mCols + ")";
+    }
+}
diff --git a/libraryversions.toml b/libraryversions.toml
index 7007c50..814b43c 100644
--- a/libraryversions.toml
+++ b/libraryversions.toml
@@ -63,6 +63,7 @@
 HEIFWRITER = "1.1.0-alpha02"
 HILT = "1.1.0-alpha01"
 HILT_NAVIGATION_COMPOSE = "1.1.0-alpha01"
+INPUT_MOTIONPREDICTION = "1.0.0-alpha01"
 INSPECTION = "1.0.0"
 INTERPOLATOR = "1.1.0-alpha01"
 JAVASCRIPTENGINE = "1.0.0-alpha03"
@@ -195,6 +196,7 @@
 HEALTH_CONNECT = { group = "androidx.health.connect", atomicGroupVersion = "versions.HEALTH_CONNECT" }
 HEIFWRITER = { group = "androidx.heifwriter", atomicGroupVersion = "versions.HEIFWRITER" }
 HILT = { group = "androidx.hilt" }
+INPUT = { group = "androidx.input" }
 INSPECTION = { group = "androidx.inspection", atomicGroupVersion = "versions.INSPECTION" }
 INSPECTION_EXTENSIONS = { group = "androidx.inspection.extensions", atomicGroupVersion = "versions.SQLITE_INSPECTOR" }
 INTERPOLATOR = { group = "androidx.interpolator", atomicGroupVersion = "versions.INTERPOLATOR" }
diff --git a/settings.gradle b/settings.gradle
index f924444..4079f8f 100644
--- a/settings.gradle
+++ b/settings.gradle
@@ -674,6 +674,7 @@
 includeProject(":hilt:hilt-work", [BuildType.MAIN])
 includeProject(":hilt:integration-tests:hilt-testapp-viewmodel", "hilt/integration-tests/viewmodelapp", [BuildType.MAIN])
 includeProject(":hilt:integration-tests:hilt-testapp-worker", "hilt/integration-tests/workerapp", [BuildType.MAIN])
+includeProject(":input:input-motionprediction", [BuildType.MAIN])
 includeProject(":inspection:inspection", [BuildType.MAIN, BuildType.COMPOSE])
 includeProject(":inspection:inspection-gradle-plugin", [BuildType.MAIN])
 includeProject(":inspection:inspection-testing", [BuildType.MAIN, BuildType.COMPOSE])