diff --git a/drivers/base/memory.c b/drivers/base/memory.c
index 6ee2181adc3feb7df73db767f6c85ea198bc3c5e..f75e3467cb59b2a36430b6a64d368f7481d7ae8c 100644
--- a/drivers/base/memory.c
+++ b/drivers/base/memory.c
@@ -215,6 +215,7 @@ static int memory_block_online(struct memory_block *mem)
 		adjust_present_page_count(pfn_to_page(start_pfn), mem->group,
 					  nr_vmemmap_pages);
 
+	mem->zone = zone;
 	return ret;
 }
 
@@ -225,6 +226,9 @@ static int memory_block_offline(struct memory_block *mem)
 	unsigned long nr_vmemmap_pages = mem->nr_vmemmap_pages;
 	int ret;
 
+	if (!mem->zone)
+		return -EINVAL;
+
 	/*
 	 * Unaccount before offlining, such that unpopulated zone and kthreads
 	 * can properly be torn down in offline_pages().
@@ -234,7 +238,7 @@ static int memory_block_offline(struct memory_block *mem)
 					  -nr_vmemmap_pages);
 
 	ret = offline_pages(start_pfn + nr_vmemmap_pages,
-			    nr_pages - nr_vmemmap_pages, mem->group);
+			    nr_pages - nr_vmemmap_pages, mem->zone, mem->group);
 	if (ret) {
 		/* offline_pages() failed. Account back. */
 		if (nr_vmemmap_pages)
@@ -246,6 +250,7 @@ static int memory_block_offline(struct memory_block *mem)
 	if (nr_vmemmap_pages)
 		mhp_deinit_memmap_on_memory(start_pfn, nr_vmemmap_pages);
 
+	mem->zone = NULL;
 	return ret;
 }
 
@@ -411,11 +416,10 @@ static ssize_t valid_zones_show(struct device *dev,
 	 */
 	if (mem->state == MEM_ONLINE) {
 		/*
-		 * The block contains more than one zone can not be offlined.
-		 * This can happen e.g. for ZONE_DMA and ZONE_DMA32
+		 * If !mem->zone, the memory block spans multiple zones and
+		 * cannot be offlined.
 		 */
-		default_zone = test_pages_in_a_zone(start_pfn,
-						    start_pfn + nr_pages);
+		default_zone = mem->zone;
 		if (!default_zone)
 			return sysfs_emit(buf, "%s\n", "none");
 		len += sysfs_emit_at(buf, len, "%s", default_zone->name);
@@ -643,6 +647,82 @@ int register_memory(struct memory_block *memory)
 	return ret;
 }
 
+static struct zone *early_node_zone_for_memory_block(struct memory_block *mem,
+						     int nid)
+{
+	const unsigned long start_pfn = section_nr_to_pfn(mem->start_section_nr);
+	const unsigned long nr_pages = PAGES_PER_SECTION * sections_per_block;
+	struct zone *zone, *matching_zone = NULL;
+	pg_data_t *pgdat = NODE_DATA(nid);
+	int i;
+
+	/*
+	 * This logic only works for early memory, when the applicable zones
+	 * already span the memory block. We don't expect overlapping zones on
+	 * a single node for early memory. So if we're told that some PFNs
+	 * of a node fall into this memory block, we can assume that all node
+	 * zones that intersect with the memory block are actually applicable.
+	 * No need to look at the memmap.
+	 */
+	for (i = 0; i < MAX_NR_ZONES; i++) {
+		zone = pgdat->node_zones + i;
+		if (!populated_zone(zone))
+			continue;
+		if (!zone_intersects(zone, start_pfn, nr_pages))
+			continue;
+		if (!matching_zone) {
+			matching_zone = zone;
+			continue;
+		}
+		/* Spans multiple zones ... */
+		matching_zone = NULL;
+		break;
+	}
+	return matching_zone;
+}
+
+#ifdef CONFIG_NUMA
+/**
+ * memory_block_add_nid() - Indicate that system RAM falling into this memory
+ *			    block device (partially) belongs to the given node.
+ * @mem: The memory block device.
+ * @nid: The node id.
+ * @context: The memory initialization context.
+ *
+ * Indicate that system RAM falling into this memory block (partially) belongs
+ * to the given node. If the context is MEMINIT_EARLY, i.e., we are adding the
+ * node during node device subsystem initialization, this will also properly
+ * set/adjust mem->zone based on the zone ranges of the given node.
+ */
+void memory_block_add_nid(struct memory_block *mem, int nid,
+			  enum meminit_context context)
+{
+	if (context == MEMINIT_EARLY && mem->nid != nid) {
+		/*
+		 * For early memory we have to determine the zone when setting
+		 * the node id and handle multiple nodes spanning a single
+		 * memory block by indicating via zone == NULL that we're not
+		 * dealing with a single zone. So if we're setting the node id
+		 * the first time, determine if there is a single zone. If we're
+		 * setting the node id a second time to a different node,
+		 * invalidate the single detected zone.
+		 */
+		if (mem->nid == NUMA_NO_NODE)
+			mem->zone = early_node_zone_for_memory_block(mem, nid);
+		else
+			mem->zone = NULL;
+	}
+
+	/*
+	 * If this memory block spans multiple nodes, we only indicate
+	 * the last processed node. Spanning multiple nodes can only happen
+	 * for early memory, not for hotplugged memory; zone == NULL will
+	 * then prohibit memory offlining and, consequently, memory unplug.
+	 */
+	mem->nid = nid;
+}
+#endif
+
 static int init_memory_block(unsigned long block_id, unsigned long state,
 			     unsigned long nr_vmemmap_pages,
 			     struct memory_group *group)
@@ -665,6 +745,17 @@ static int init_memory_block(unsigned long block_id, unsigned long state,
 	mem->nr_vmemmap_pages = nr_vmemmap_pages;
 	INIT_LIST_HEAD(&mem->group_next);
 
+#ifndef CONFIG_NUMA
+	if (state == MEM_ONLINE)
+		/*
+		 * MEM_ONLINE at this point implies early memory. With NUMA,
+		 * we'll determine the zone when setting the node id via
+		 * memory_block_add_nid(). Memory hotplug updates the zone
+		 * manually when memory onlining/offlining succeeds.
+		 */
+		mem->zone = early_node_zone_for_memory_block(mem, NUMA_NO_NODE);
+#endif /* !CONFIG_NUMA */
+
 	ret = register_memory(mem);
 	if (ret)
 		return ret;
diff --git a/drivers/base/node.c b/drivers/base/node.c
index 5d75341413ce3b3913bde077e2b9395a480cd9ce..ec8bb24a5a227a80173c58d299da925bd48010d6 100644
--- a/drivers/base/node.c
+++ b/drivers/base/node.c
@@ -796,15 +796,12 @@ static int __ref get_nid_for_pfn(unsigned long pfn)
 }
 
 static void do_register_memory_block_under_node(int nid,
-						struct memory_block *mem_blk)
+						struct memory_block *mem_blk,
+						enum meminit_context context)
 {
 	int ret;
 
-	/*
-	 * If this memory block spans multiple nodes, we only indicate
-	 * the last processed node.
-	 */
-	mem_blk->nid = nid;
+	memory_block_add_nid(mem_blk, nid, context);
 
 	ret = sysfs_create_link_nowarn(&node_devices[nid]->dev.kobj,
 				       &mem_blk->dev.kobj,
@@ -857,7 +854,7 @@ static int register_mem_block_under_node_early(struct memory_block *mem_blk,
 		if (page_nid != nid)
 			continue;
 
-		do_register_memory_block_under_node(nid, mem_blk);
+		do_register_memory_block_under_node(nid, mem_blk, MEMINIT_EARLY);
 		return 0;
 	}
 	/* mem section does not span the specified node */
@@ -873,7 +870,7 @@ static int register_mem_block_under_node_hotplug(struct memory_block *mem_blk,
 {
 	int nid = *(int *)arg;
 
-	do_register_memory_block_under_node(nid, mem_blk);
+	do_register_memory_block_under_node(nid, mem_blk, MEMINIT_HOTPLUG);
 	return 0;
 }
 
diff --git a/include/linux/memory.h b/include/linux/memory.h
index 88eb587b514382b7b5aed884237250550b9356ec..aa619464a1df0cc2cfa7a938042c168eac0d2bbc 100644
--- a/include/linux/memory.h
+++ b/include/linux/memory.h
@@ -70,6 +70,13 @@ struct memory_block {
 	unsigned long state;		/* serialized by the dev->lock */
 	int online_type;		/* for passing data to online routine */
 	int nid;			/* NID for this memory block */
+	/*
+	 * The single zone of this memory block if all of its PFNs that are
+	 * System RAM (not a memory hole, not ZONE_DEVICE ranges) are managed
+	 * by a single zone. NULL if multiple zones apply (e.g., the block
+	 * spans multiple nodes).
+	 */
+	struct zone *zone;
 	struct device dev;
 	/*
 	 * Number of vmemmap pages. These pages
@@ -161,6 +168,11 @@ int walk_dynamic_memory_groups(int nid, walk_memory_groups_func_t func,
 })
 #define register_hotmemory_notifier(nb)		register_memory_notifier(nb)
 #define unregister_hotmemory_notifier(nb) 	unregister_memory_notifier(nb)
+
+#ifdef CONFIG_NUMA
+void memory_block_add_nid(struct memory_block *mem, int nid,
+			  enum meminit_context context);
+#endif /* CONFIG_NUMA */
 #endif	/* CONFIG_MEMORY_HOTPLUG */
 
 /*
diff --git a/include/linux/memory_hotplug.h b/include/linux/memory_hotplug.h
index 76bf2de86defc26ea44cc489ebac1437580a6a2b..1ce6f8044f1ebcd376a22dc7f164c1ae4491ef56 100644
--- a/include/linux/memory_hotplug.h
+++ b/include/linux/memory_hotplug.h
@@ -163,8 +163,6 @@ extern int mhp_init_memmap_on_memory(unsigned long pfn, unsigned long nr_pages,
 extern void mhp_deinit_memmap_on_memory(unsigned long pfn, unsigned long nr_pages);
 extern int online_pages(unsigned long pfn, unsigned long nr_pages,
 			struct zone *zone, struct memory_group *group);
-extern struct zone *test_pages_in_a_zone(unsigned long start_pfn,
-					 unsigned long end_pfn);
 extern void __offline_isolated_pages(unsigned long start_pfn,
 				     unsigned long end_pfn);
 
@@ -293,7 +291,7 @@ static inline void pgdat_resize_init(struct pglist_data *pgdat) {}
 
 extern void try_offline_node(int nid);
 extern int offline_pages(unsigned long start_pfn, unsigned long nr_pages,
-			 struct memory_group *group);
+			 struct zone *zone, struct memory_group *group);
 extern int remove_memory(u64 start, u64 size);
 extern void __remove_memory(u64 start, u64 size);
 extern int offline_and_remove_memory(u64 start, u64 size);
@@ -302,7 +300,7 @@ extern int offline_and_remove_memory(u64 start, u64 size);
 static inline void try_offline_node(int nid) {}
 
 static inline int offline_pages(unsigned long start_pfn, unsigned long nr_pages,
-				struct memory_group *group)
+				struct zone *zone, struct memory_group *group)
 {
 	return -EINVAL;
 }
diff --git a/mm/memory_hotplug.c b/mm/memory_hotplug.c
index ed1a5dac67978c9f5a1504842104c6a0331f69ce..aee69281dad682560afbb7838786e51b07c58a4e 100644
--- a/mm/memory_hotplug.c
+++ b/mm/memory_hotplug.c
@@ -1548,38 +1548,6 @@ bool mhp_range_allowed(u64 start, u64 size, bool need_mapping)
 }
 
 #ifdef CONFIG_MEMORY_HOTREMOVE
-/*
- * Confirm all pages in a range [start, end) belong to the same zone (skipping
- * memory holes). When true, return the zone.
- */
-struct zone *test_pages_in_a_zone(unsigned long start_pfn,
-				  unsigned long end_pfn)
-{
-	unsigned long pfn, sec_end_pfn;
-	struct zone *zone = NULL;
-	struct page *page;
-
-	for (pfn = start_pfn, sec_end_pfn = SECTION_ALIGN_UP(start_pfn + 1);
-	     pfn < end_pfn;
-	     pfn = sec_end_pfn, sec_end_pfn += PAGES_PER_SECTION) {
-		/* Make sure the memory section is present first */
-		if (!present_section_nr(pfn_to_section_nr(pfn)))
-			continue;
-		for (; pfn < sec_end_pfn && pfn < end_pfn;
-		     pfn += MAX_ORDER_NR_PAGES) {
-			/* Check if we got outside of the zone */
-			if (zone && !zone_spans_pfn(zone, pfn))
-				return NULL;
-			page = pfn_to_page(pfn);
-			if (zone && page_zone(page) != zone)
-				return NULL;
-			zone = page_zone(page);
-		}
-	}
-
-	return zone;
-}
-
 /*
  * Scan pfn range [start,end) to find movable/migratable pages (LRU pages,
  * non-lru movable pages and hugepages). Will skip over most unmovable
@@ -1803,15 +1771,15 @@ static int count_system_ram_pages_cb(unsigned long start_pfn,
 }
 
 int __ref offline_pages(unsigned long start_pfn, unsigned long nr_pages,
-			struct memory_group *group)
+			struct zone *zone, struct memory_group *group)
 {
 	const unsigned long end_pfn = start_pfn + nr_pages;
 	unsigned long pfn, system_ram_pages = 0;
+	const int node = zone_to_nid(zone);
 	unsigned long flags;
-	struct zone *zone;
 	struct memory_notify arg;
-	int ret, node;
 	char *reason;
+	int ret;
 
 	/*
 	 * {on,off}lining is constrained to full memory sections (or more
@@ -1843,15 +1811,17 @@ int __ref offline_pages(unsigned long start_pfn, unsigned long nr_pages,
 		goto failed_removal;
 	}
 
-	/* This makes hotplug much easier...and readable.
-	   we assume this for now. .*/
-	zone = test_pages_in_a_zone(start_pfn, end_pfn);
-	if (!zone) {
+	/*
+	 * We only support offlining of memory blocks managed by a single zone,
+	 * checked by calling code. This is just a sanity check that we might
+	 * want to remove in the future.
+	 */
+	if (WARN_ON_ONCE(page_zone(pfn_to_page(start_pfn)) != zone ||
+			 page_zone(pfn_to_page(end_pfn - 1)) != zone)) {
 		ret = -EINVAL;
 		reason = "multizone range";
 		goto failed_removal;
 	}
-	node = zone_to_nid(zone);
 
 	/*
 	 * Disable pcplists so that page isolation cannot race with freeing