diff --git a/drivers/net/dsa/b53/b53_common.c b/drivers/net/dsa/b53/b53_common.c
index 7f26f5dafca7e6430f7e218423bf2da400a4dfa1..561b05089cb68d98313e978626db09435e06d637 100644
--- a/drivers/net/dsa/b53/b53_common.c
+++ b/drivers/net/dsa/b53/b53_common.c
@@ -1029,8 +1029,7 @@ int b53_vlan_filtering(struct dsa_switch *ds, int port, bool vlan_filtering)
 EXPORT_SYMBOL(b53_vlan_filtering);
 
 int b53_vlan_prepare(struct dsa_switch *ds, int port,
-		     const struct switchdev_obj_port_vlan *vlan,
-		     struct switchdev_trans *trans)
+		     const struct switchdev_obj_port_vlan *vlan)
 {
 	struct b53_device *dev = ds->priv;
 
@@ -1047,8 +1046,7 @@ int b53_vlan_prepare(struct dsa_switch *ds, int port,
 EXPORT_SYMBOL(b53_vlan_prepare);
 
 void b53_vlan_add(struct dsa_switch *ds, int port,
-		  const struct switchdev_obj_port_vlan *vlan,
-		  struct switchdev_trans *trans)
+		  const struct switchdev_obj_port_vlan *vlan)
 {
 	struct b53_device *dev = ds->priv;
 	bool untagged = vlan->flags & BRIDGE_VLAN_INFO_UNTAGGED;
diff --git a/drivers/net/dsa/b53/b53_priv.h b/drivers/net/dsa/b53/b53_priv.h
index 2af0155efce2e9088ac80cf9f95803585a501da4..d954cf36ecd805b8c2a83371d3c9e443b04506fe 100644
--- a/drivers/net/dsa/b53/b53_priv.h
+++ b/drivers/net/dsa/b53/b53_priv.h
@@ -295,11 +295,9 @@ void b53_br_set_stp_state(struct dsa_switch *ds, int port, u8 state);
 void b53_br_fast_age(struct dsa_switch *ds, int port);
 int b53_vlan_filtering(struct dsa_switch *ds, int port, bool vlan_filtering);
 int b53_vlan_prepare(struct dsa_switch *ds, int port,
-		     const struct switchdev_obj_port_vlan *vlan,
-		     struct switchdev_trans *trans);
+		     const struct switchdev_obj_port_vlan *vlan);
 void b53_vlan_add(struct dsa_switch *ds, int port,
-		  const struct switchdev_obj_port_vlan *vlan,
-		  struct switchdev_trans *trans);
+		  const struct switchdev_obj_port_vlan *vlan);
 int b53_vlan_del(struct dsa_switch *ds, int port,
 		 const struct switchdev_obj_port_vlan *vlan);
 int b53_fdb_add(struct dsa_switch *ds, int port,
diff --git a/drivers/net/dsa/dsa_loop.c b/drivers/net/dsa/dsa_loop.c
index bb71d3d6f65b008dc5195b625ef579af848fea26..7aa84ee4e771d97b031e1e1e66910f81c1958828 100644
--- a/drivers/net/dsa/dsa_loop.c
+++ b/drivers/net/dsa/dsa_loop.c
@@ -174,9 +174,9 @@ static int dsa_loop_port_vlan_filtering(struct dsa_switch *ds, int port,
 	return 0;
 }
 
-static int dsa_loop_port_vlan_prepare(struct dsa_switch *ds, int port,
-				      const struct switchdev_obj_port_vlan *vlan,
-				      struct switchdev_trans *trans)
+static int
+dsa_loop_port_vlan_prepare(struct dsa_switch *ds, int port,
+			   const struct switchdev_obj_port_vlan *vlan)
 {
 	struct dsa_loop_priv *ps = ds->priv;
 	struct mii_bus *bus = ps->bus;
@@ -193,8 +193,7 @@ static int dsa_loop_port_vlan_prepare(struct dsa_switch *ds, int port,
 }
 
 static void dsa_loop_port_vlan_add(struct dsa_switch *ds, int port,
-				   const struct switchdev_obj_port_vlan *vlan,
-				   struct switchdev_trans *trans)
+				   const struct switchdev_obj_port_vlan *vlan)
 {
 	bool untagged = vlan->flags & BRIDGE_VLAN_INFO_UNTAGGED;
 	bool pvid = vlan->flags & BRIDGE_VLAN_INFO_PVID;
diff --git a/drivers/net/dsa/lan9303-core.c b/drivers/net/dsa/lan9303-core.c
index b24566bb74d2b789d821ca35d995001b3db3c8ae..ea59dadefb337bb6b00e88d3c5c63458a214fac9 100644
--- a/drivers/net/dsa/lan9303-core.c
+++ b/drivers/net/dsa/lan9303-core.c
@@ -1217,8 +1217,7 @@ static int lan9303_port_fdb_dump(struct dsa_switch *ds, int port,
 }
 
 static int lan9303_port_mdb_prepare(struct dsa_switch *ds, int port,
-				    const struct switchdev_obj_port_mdb *mdb,
-				    struct switchdev_trans *trans)
+				    const struct switchdev_obj_port_mdb *mdb)
 {
 	struct lan9303 *chip = ds->priv;
 
@@ -1235,8 +1234,7 @@ static int lan9303_port_mdb_prepare(struct dsa_switch *ds, int port,
 }
 
 static void lan9303_port_mdb_add(struct dsa_switch *ds, int port,
-				 const struct switchdev_obj_port_mdb *mdb,
-				 struct switchdev_trans *trans)
+				 const struct switchdev_obj_port_mdb *mdb)
 {
 	struct lan9303 *chip = ds->priv;
 
diff --git a/drivers/net/dsa/microchip/ksz_common.c b/drivers/net/dsa/microchip/ksz_common.c
index b5be93a1e0df88a5fcf44dce1d85fdfb4df17ff4..663b0d5b982b127f934a8c87e64f5aa8209a6b0b 100644
--- a/drivers/net/dsa/microchip/ksz_common.c
+++ b/drivers/net/dsa/microchip/ksz_common.c
@@ -559,8 +559,7 @@ static int ksz_port_vlan_filtering(struct dsa_switch *ds, int port, bool flag)
 }
 
 static int ksz_port_vlan_prepare(struct dsa_switch *ds, int port,
-				 const struct switchdev_obj_port_vlan *vlan,
-				 struct switchdev_trans *trans)
+				 const struct switchdev_obj_port_vlan *vlan)
 {
 	/* nothing needed */
 
@@ -568,8 +567,7 @@ static int ksz_port_vlan_prepare(struct dsa_switch *ds, int port,
 }
 
 static void ksz_port_vlan_add(struct dsa_switch *ds, int port,
-			      const struct switchdev_obj_port_vlan *vlan,
-			      struct switchdev_trans *trans)
+			      const struct switchdev_obj_port_vlan *vlan)
 {
 	struct ksz_device *dev = ds->priv;
 	u32 vlan_table[3];
@@ -858,16 +856,14 @@ static int ksz_port_fdb_dump(struct dsa_switch *ds, int port,
 }
 
 static int ksz_port_mdb_prepare(struct dsa_switch *ds, int port,
-				const struct switchdev_obj_port_mdb *mdb,
-				struct switchdev_trans *trans)
+				const struct switchdev_obj_port_mdb *mdb)
 {
 	/* nothing to do */
 	return 0;
 }
 
 static void ksz_port_mdb_add(struct dsa_switch *ds, int port,
-			     const struct switchdev_obj_port_mdb *mdb,
-			     struct switchdev_trans *trans)
+			     const struct switchdev_obj_port_mdb *mdb)
 {
 	struct ksz_device *dev = ds->priv;
 	u32 static_table[4];
diff --git a/drivers/net/dsa/mv88e6xxx/chip.c b/drivers/net/dsa/mv88e6xxx/chip.c
index 8171055fde7a0238fb2fbc691a482c211d4d8d5b..b5e0987c88f01430f2a8fd2336d156d2d7496d89 100644
--- a/drivers/net/dsa/mv88e6xxx/chip.c
+++ b/drivers/net/dsa/mv88e6xxx/chip.c
@@ -1185,8 +1185,7 @@ static int mv88e6xxx_port_vlan_filtering(struct dsa_switch *ds, int port,
 
 static int
 mv88e6xxx_port_vlan_prepare(struct dsa_switch *ds, int port,
-			    const struct switchdev_obj_port_vlan *vlan,
-			    struct switchdev_trans *trans)
+			    const struct switchdev_obj_port_vlan *vlan)
 {
 	struct mv88e6xxx_chip *chip = ds->priv;
 	int err;
@@ -1295,8 +1294,7 @@ static int _mv88e6xxx_port_vlan_add(struct mv88e6xxx_chip *chip, int port,
 }
 
 static void mv88e6xxx_port_vlan_add(struct dsa_switch *ds, int port,
-				    const struct switchdev_obj_port_vlan *vlan,
-				    struct switchdev_trans *trans)
+				    const struct switchdev_obj_port_vlan *vlan)
 {
 	struct mv88e6xxx_chip *chip = ds->priv;
 	bool untagged = vlan->flags & BRIDGE_VLAN_INFO_UNTAGGED;
@@ -3788,8 +3786,7 @@ static const char *mv88e6xxx_drv_probe(struct device *dsa_dev,
 }
 
 static int mv88e6xxx_port_mdb_prepare(struct dsa_switch *ds, int port,
-				      const struct switchdev_obj_port_mdb *mdb,
-				      struct switchdev_trans *trans)
+				      const struct switchdev_obj_port_mdb *mdb)
 {
 	/* We don't need any dynamic resource from the kernel (yet),
 	 * so skip the prepare phase.
@@ -3799,8 +3796,7 @@ static int mv88e6xxx_port_mdb_prepare(struct dsa_switch *ds, int port,
 }
 
 static void mv88e6xxx_port_mdb_add(struct dsa_switch *ds, int port,
-				   const struct switchdev_obj_port_mdb *mdb,
-				   struct switchdev_trans *trans)
+				   const struct switchdev_obj_port_mdb *mdb)
 {
 	struct mv88e6xxx_chip *chip = ds->priv;
 
diff --git a/include/net/dsa.h b/include/net/dsa.h
index 2a05738570d83c4976b764c86ecb1b12aa697896..6700dff46a80747396e4037730959b0fa1188890 100644
--- a/include/net/dsa.h
+++ b/include/net/dsa.h
@@ -412,12 +412,10 @@ struct dsa_switch_ops {
 	 */
 	int	(*port_vlan_filtering)(struct dsa_switch *ds, int port,
 				       bool vlan_filtering);
-	int	(*port_vlan_prepare)(struct dsa_switch *ds, int port,
-				     const struct switchdev_obj_port_vlan *vlan,
-				     struct switchdev_trans *trans);
-	void	(*port_vlan_add)(struct dsa_switch *ds, int port,
-				 const struct switchdev_obj_port_vlan *vlan,
-				 struct switchdev_trans *trans);
+	int (*port_vlan_prepare)(struct dsa_switch *ds, int port,
+				 const struct switchdev_obj_port_vlan *vlan);
+	void (*port_vlan_add)(struct dsa_switch *ds, int port,
+			      const struct switchdev_obj_port_vlan *vlan);
 	int	(*port_vlan_del)(struct dsa_switch *ds, int port,
 				 const struct switchdev_obj_port_vlan *vlan);
 	/*
@@ -433,12 +431,10 @@ struct dsa_switch_ops {
 	/*
 	 * Multicast database
 	 */
-	int	(*port_mdb_prepare)(struct dsa_switch *ds, int port,
-				    const struct switchdev_obj_port_mdb *mdb,
-				    struct switchdev_trans *trans);
-	void	(*port_mdb_add)(struct dsa_switch *ds, int port,
-				const struct switchdev_obj_port_mdb *mdb,
-				struct switchdev_trans *trans);
+	int (*port_mdb_prepare)(struct dsa_switch *ds, int port,
+				const struct switchdev_obj_port_mdb *mdb);
+	void (*port_mdb_add)(struct dsa_switch *ds, int port,
+			     const struct switchdev_obj_port_mdb *mdb);
 	int	(*port_mdb_del)(struct dsa_switch *ds, int port,
 				const struct switchdev_obj_port_mdb *mdb);
 	/*
diff --git a/net/dsa/switch.c b/net/dsa/switch.c
index 29608d087a7c56a9be475fca32a69f52de985f00..9a01514ea9f3312c41ba94d08da63943c8a2ceae 100644
--- a/net/dsa/switch.c
+++ b/net/dsa/switch.c
@@ -108,13 +108,42 @@ static int dsa_switch_fdb_del(struct dsa_switch *ds,
 				     info->vid);
 }
 
+static int
+dsa_switch_mdb_prepare_bitmap(struct dsa_switch *ds,
+			      const struct switchdev_obj_port_mdb *mdb,
+			      const unsigned long *bitmap)
+{
+	int port, err;
+
+	if (!ds->ops->port_mdb_prepare || !ds->ops->port_mdb_add)
+		return -EOPNOTSUPP;
+
+	for_each_set_bit(port, bitmap, ds->num_ports) {
+		err = ds->ops->port_mdb_prepare(ds, port, mdb);
+		if (err)
+			return err;
+	}
+
+	return 0;
+}
+
+static void dsa_switch_mdb_add_bitmap(struct dsa_switch *ds,
+				      const struct switchdev_obj_port_mdb *mdb,
+				      const unsigned long *bitmap)
+{
+	int port;
+
+	for_each_set_bit(port, bitmap, ds->num_ports)
+		ds->ops->port_mdb_add(ds, port, mdb);
+}
+
 static int dsa_switch_mdb_add(struct dsa_switch *ds,
 			      struct dsa_notifier_mdb_info *info)
 {
 	const struct switchdev_obj_port_mdb *mdb = info->mdb;
 	struct switchdev_trans *trans = info->trans;
 	DECLARE_BITMAP(group, ds->num_ports);
-	int port, err;
+	int port;
 
 	/* Build a mask of Multicast group members */
 	bitmap_zero(group, ds->num_ports);
@@ -124,21 +153,10 @@ static int dsa_switch_mdb_add(struct dsa_switch *ds,
 		if (dsa_is_dsa_port(ds, port))
 			set_bit(port, group);
 
-	if (switchdev_trans_ph_prepare(trans)) {
-		if (!ds->ops->port_mdb_prepare || !ds->ops->port_mdb_add)
-			return -EOPNOTSUPP;
-
-		for_each_set_bit(port, group, ds->num_ports) {
-			err = ds->ops->port_mdb_prepare(ds, port, mdb, trans);
-			if (err)
-				return err;
-		}
-
-		return 0;
-	}
+	if (switchdev_trans_ph_prepare(trans))
+		return dsa_switch_mdb_prepare_bitmap(ds, mdb, group);
 
-	for_each_set_bit(port, group, ds->num_ports)
-		ds->ops->port_mdb_add(ds, port, mdb, trans);
+	dsa_switch_mdb_add_bitmap(ds, mdb, group);
 
 	return 0;
 }
@@ -157,13 +175,43 @@ static int dsa_switch_mdb_del(struct dsa_switch *ds,
 	return 0;
 }
 
+static int
+dsa_switch_vlan_prepare_bitmap(struct dsa_switch *ds,
+			       const struct switchdev_obj_port_vlan *vlan,
+			       const unsigned long *bitmap)
+{
+	int port, err;
+
+	if (!ds->ops->port_vlan_prepare || !ds->ops->port_vlan_add)
+		return -EOPNOTSUPP;
+
+	for_each_set_bit(port, bitmap, ds->num_ports) {
+		err = ds->ops->port_vlan_prepare(ds, port, vlan);
+		if (err)
+			return err;
+	}
+
+	return 0;
+}
+
+static void
+dsa_switch_vlan_add_bitmap(struct dsa_switch *ds,
+			   const struct switchdev_obj_port_vlan *vlan,
+			   const unsigned long *bitmap)
+{
+	int port;
+
+	for_each_set_bit(port, bitmap, ds->num_ports)
+		ds->ops->port_vlan_add(ds, port, vlan);
+}
+
 static int dsa_switch_vlan_add(struct dsa_switch *ds,
 			       struct dsa_notifier_vlan_info *info)
 {
 	const struct switchdev_obj_port_vlan *vlan = info->vlan;
 	struct switchdev_trans *trans = info->trans;
 	DECLARE_BITMAP(members, ds->num_ports);
-	int port, err;
+	int port;
 
 	/* Build a mask of VLAN members */
 	bitmap_zero(members, ds->num_ports);
@@ -173,21 +221,10 @@ static int dsa_switch_vlan_add(struct dsa_switch *ds,
 		if (dsa_is_cpu_port(ds, port) || dsa_is_dsa_port(ds, port))
 			set_bit(port, members);
 
-	if (switchdev_trans_ph_prepare(trans)) {
-		if (!ds->ops->port_vlan_prepare || !ds->ops->port_vlan_add)
-			return -EOPNOTSUPP;
-
-		for_each_set_bit(port, members, ds->num_ports) {
-			err = ds->ops->port_vlan_prepare(ds, port, vlan, trans);
-			if (err)
-				return err;
-		}
-
-		return 0;
-	}
+	if (switchdev_trans_ph_prepare(trans))
+		return dsa_switch_vlan_prepare_bitmap(ds, vlan, members);
 
-	for_each_set_bit(port, members, ds->num_ports)
-		ds->ops->port_vlan_add(ds, port, vlan, trans);
+	dsa_switch_vlan_add_bitmap(ds, vlan, members);
 
 	return 0;
 }