diff --git a/arch/x86/kvm/mmu/tdp_mmu.c b/arch/x86/kvm/mmu/tdp_mmu.c
index f59f3ff5cb7582116d1e1f1fb0e1f1cf87bd8114..970376297b308f8e0fdb8f100d370d2189ef24c1 100644
--- a/arch/x86/kvm/mmu/tdp_mmu.c
+++ b/arch/x86/kvm/mmu/tdp_mmu.c
@@ -56,10 +56,6 @@ void kvm_mmu_uninit_tdp_mmu(struct kvm *kvm)
 	rcu_barrier();
 }
 
-static bool zap_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
-			  gfn_t start, gfn_t end, bool can_yield, bool flush,
-			  bool shared);
-
 static void tdp_mmu_free_sp(struct kvm_mmu_page *sp)
 {
 	free_page((unsigned long)sp->spt);
@@ -82,6 +78,9 @@ static void tdp_mmu_free_sp_rcu_callback(struct rcu_head *head)
 	tdp_mmu_free_sp(sp);
 }
 
+static void tdp_mmu_zap_root(struct kvm *kvm, struct kvm_mmu_page *root,
+			     bool shared);
+
 void kvm_tdp_mmu_put_root(struct kvm *kvm, struct kvm_mmu_page *root,
 			  bool shared)
 {
@@ -104,7 +103,7 @@ void kvm_tdp_mmu_put_root(struct kvm *kvm, struct kvm_mmu_page *root,
 	 * intermediate paging structures, that may be zapped, as such entries
 	 * are associated with the ASID on both VMX and SVM.
 	 */
-	(void)zap_gfn_range(kvm, root, 0, -1ull, false, false, shared);
+	tdp_mmu_zap_root(kvm, root, shared);
 
 	call_rcu(&root->rcu_head, tdp_mmu_free_sp_rcu_callback);
 }
@@ -737,6 +736,76 @@ static inline bool __must_check tdp_mmu_iter_cond_resched(struct kvm *kvm,
 	return iter->yielded;
 }
 
+static inline gfn_t tdp_mmu_max_gfn_host(void)
+{
+	/*
+	 * Bound TDP MMU walks at host.MAXPHYADDR, guest accesses beyond that
+	 * will hit a #PF(RSVD) and never hit an EPT Violation/Misconfig / #NPF,
+	 * and so KVM will never install a SPTE for such addresses.
+	 */
+	return 1ULL << (shadow_phys_bits - PAGE_SHIFT);
+}
+
+static void tdp_mmu_zap_root(struct kvm *kvm, struct kvm_mmu_page *root,
+			     bool shared)
+{
+	bool root_is_unreachable = !refcount_read(&root->tdp_mmu_root_count);
+	struct tdp_iter iter;
+
+	gfn_t end = tdp_mmu_max_gfn_host();
+	gfn_t start = 0;
+
+	kvm_lockdep_assert_mmu_lock_held(kvm, shared);
+
+	rcu_read_lock();
+
+	/*
+	 * No need to try to step down in the iterator when zapping an entire
+	 * root, zapping an upper-level SPTE will recurse on its children.
+	 */
+	for_each_tdp_pte_min_level(iter, root, root->role.level, start, end) {
+retry:
+		/*
+		 * Yielding isn't allowed when zapping an unreachable root as
+		 * the root won't be processed by mmu_notifier callbacks.  When
+		 * handling an unmap/release mmu_notifier command, KVM must
+		 * drop all references to relevant pages prior to completing
+		 * the callback.  Dropping mmu_lock can result in zapping SPTEs
+		 * for an unreachable root after a relevant callback completes,
+		 * which leads to use-after-free as zapping a SPTE triggers
+		 * "writeback" of dirty/accessed bits to the SPTE's associated
+		 * struct page.
+		 */
+		if (!root_is_unreachable &&
+		    tdp_mmu_iter_cond_resched(kvm, &iter, false, shared))
+			continue;
+
+		if (!is_shadow_present_pte(iter.old_spte))
+			continue;
+
+		if (!shared) {
+			tdp_mmu_set_spte(kvm, &iter, 0);
+		} else if (tdp_mmu_set_spte_atomic(kvm, &iter, 0)) {
+			/*
+			 * cmpxchg() shouldn't fail if the root is unreachable.
+			 * Retry so as not to leak the page and its children.
+			 */
+			WARN_ONCE(root_is_unreachable,
+				  "Contended TDP MMU SPTE in unreachable root.");
+			goto retry;
+		}
+
+		/*
+		 * WARN if the root is invalid and is unreachable, all SPTEs
+		 * should've been zapped by kvm_tdp_mmu_zap_invalidated_roots(),
+		 * and inserting new SPTEs under an invalid root is a KVM bug.
+		 */
+		WARN_ON_ONCE(root_is_unreachable && root->role.invalid);
+	}
+
+	rcu_read_unlock();
+}
+
 bool kvm_tdp_mmu_zap_sp(struct kvm *kvm, struct kvm_mmu_page *sp)
 {
 	u64 old_spte;
@@ -785,8 +854,7 @@ static bool zap_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
 			  gfn_t start, gfn_t end, bool can_yield, bool flush,
 			  bool shared)
 {
-	gfn_t max_gfn_host = 1ULL << (shadow_phys_bits - PAGE_SHIFT);
-	bool zap_all = (start == 0 && end >= max_gfn_host);
+	bool zap_all = (start == 0 && end >= tdp_mmu_max_gfn_host());
 	struct tdp_iter iter;
 
 	/*
@@ -795,12 +863,7 @@ static bool zap_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
 	 */
 	int min_level = zap_all ? root->role.level : PG_LEVEL_4K;
 
-	/*
-	 * Bound the walk at host.MAXPHYADDR, guest accesses beyond that will
-	 * hit a #PF(RSVD) and never get to an EPT Violation/Misconfig / #NPF,
-	 * and so KVM will never install a SPTE for such addresses.
-	 */
-	end = min(end, max_gfn_host);
+	end = min(end, tdp_mmu_max_gfn_host());
 
 	kvm_lockdep_assert_mmu_lock_held(kvm, shared);
 
@@ -860,6 +923,7 @@ bool __kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, int as_id, gfn_t start,
 
 void kvm_tdp_mmu_zap_all(struct kvm *kvm)
 {
+	struct kvm_mmu_page *root;
 	int i;
 
 	/*
@@ -867,8 +931,10 @@ void kvm_tdp_mmu_zap_all(struct kvm *kvm)
 	 * is being destroyed or the userspace VMM has exited.  In both cases,
 	 * KVM_RUN is unreachable, i.e. no vCPUs will ever service the request.
 	 */
-	for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++)
-		(void)kvm_tdp_mmu_zap_gfn_range(kvm, i, 0, -1ull, false);
+	for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++) {
+		for_each_tdp_mmu_root_yield_safe(kvm, root, i)
+			tdp_mmu_zap_root(kvm, root, false);
+	}
 }
 
 static struct kvm_mmu_page *next_invalidated_root(struct kvm *kvm,
@@ -925,7 +991,7 @@ void kvm_tdp_mmu_zap_invalidated_roots(struct kvm *kvm)
 		 * will still flush on yield, but that's a minor performance
 		 * blip and not a functional issue.
 		 */
-		(void)zap_gfn_range(kvm, root, 0, -1ull, true, false, true);
+		tdp_mmu_zap_root(kvm, root, true);
 
 		/*
 		 * Put the reference acquired in