diff --git a/drivers/net/ethernet/mellanox/mlxsw/spectrum_buffers.c b/drivers/net/ethernet/mellanox/mlxsw/spectrum_buffers.c
index 637151682cf22002ebe97410ddf08bdabc9c09fc..5fd9a72c8471454633e9ac6843ef1c92873c540f 100644
--- a/drivers/net/ethernet/mellanox/mlxsw/spectrum_buffers.c
+++ b/drivers/net/ethernet/mellanox/mlxsw/spectrum_buffers.c
@@ -35,6 +35,7 @@ struct mlxsw_sp_sb_cm {
 };
 
 #define MLXSW_SP_SB_INFI -1U
+#define MLXSW_SP_SB_REST -2U
 
 struct mlxsw_sp_sb_pm {
 	u32 min_buff;
@@ -421,19 +422,16 @@ static void mlxsw_sp_sb_ports_fini(struct mlxsw_sp *mlxsw_sp)
 		.freeze_size = _freeze_size,				\
 	}
 
-#define MLXSW_SP1_SB_PR_INGRESS_SIZE	13768608
-#define MLXSW_SP1_SB_PR_EGRESS_SIZE	13768608
 #define MLXSW_SP1_SB_PR_CPU_SIZE	(256 * 1000)
 
 /* Order according to mlxsw_sp1_sb_pool_dess */
 static const struct mlxsw_sp_sb_pr mlxsw_sp1_sb_prs[] = {
-	MLXSW_SP_SB_PR(MLXSW_REG_SBPR_MODE_DYNAMIC,
-		       MLXSW_SP1_SB_PR_INGRESS_SIZE),
+	MLXSW_SP_SB_PR(MLXSW_REG_SBPR_MODE_DYNAMIC, MLXSW_SP_SB_REST),
 	MLXSW_SP_SB_PR(MLXSW_REG_SBPR_MODE_DYNAMIC, 0),
 	MLXSW_SP_SB_PR(MLXSW_REG_SBPR_MODE_DYNAMIC, 0),
 	MLXSW_SP_SB_PR(MLXSW_REG_SBPR_MODE_DYNAMIC, 0),
-	MLXSW_SP_SB_PR_EXT(MLXSW_REG_SBPR_MODE_DYNAMIC,
-			   MLXSW_SP1_SB_PR_EGRESS_SIZE, true, false),
+	MLXSW_SP_SB_PR_EXT(MLXSW_REG_SBPR_MODE_DYNAMIC, MLXSW_SP_SB_REST,
+			   true, false),
 	MLXSW_SP_SB_PR(MLXSW_REG_SBPR_MODE_DYNAMIC, 0),
 	MLXSW_SP_SB_PR(MLXSW_REG_SBPR_MODE_DYNAMIC, 0),
 	MLXSW_SP_SB_PR(MLXSW_REG_SBPR_MODE_DYNAMIC, 0),
@@ -445,19 +443,16 @@ static const struct mlxsw_sp_sb_pr mlxsw_sp1_sb_prs[] = {
 			   MLXSW_SP1_SB_PR_CPU_SIZE, true, false),
 };
 
-#define MLXSW_SP2_SB_PR_INGRESS_SIZE	34084800
-#define MLXSW_SP2_SB_PR_EGRESS_SIZE	34084800
 #define MLXSW_SP2_SB_PR_CPU_SIZE	(256 * 1000)
 
 /* Order according to mlxsw_sp2_sb_pool_dess */
 static const struct mlxsw_sp_sb_pr mlxsw_sp2_sb_prs[] = {
-	MLXSW_SP_SB_PR(MLXSW_REG_SBPR_MODE_DYNAMIC,
-		       MLXSW_SP2_SB_PR_INGRESS_SIZE),
+	MLXSW_SP_SB_PR(MLXSW_REG_SBPR_MODE_DYNAMIC, MLXSW_SP_SB_REST),
 	MLXSW_SP_SB_PR(MLXSW_REG_SBPR_MODE_STATIC, 0),
 	MLXSW_SP_SB_PR(MLXSW_REG_SBPR_MODE_STATIC, 0),
 	MLXSW_SP_SB_PR(MLXSW_REG_SBPR_MODE_STATIC, 0),
-	MLXSW_SP_SB_PR_EXT(MLXSW_REG_SBPR_MODE_DYNAMIC,
-			   MLXSW_SP2_SB_PR_EGRESS_SIZE, true, false),
+	MLXSW_SP_SB_PR_EXT(MLXSW_REG_SBPR_MODE_DYNAMIC, MLXSW_SP_SB_REST,
+			   true, false),
 	MLXSW_SP_SB_PR(MLXSW_REG_SBPR_MODE_STATIC, 0),
 	MLXSW_SP_SB_PR(MLXSW_REG_SBPR_MODE_STATIC, 0),
 	MLXSW_SP_SB_PR(MLXSW_REG_SBPR_MODE_STATIC, 0),
@@ -471,11 +466,33 @@ static const struct mlxsw_sp_sb_pr mlxsw_sp2_sb_prs[] = {
 
 static int mlxsw_sp_sb_prs_init(struct mlxsw_sp *mlxsw_sp,
 				const struct mlxsw_sp_sb_pr *prs,
+				const struct mlxsw_sp_sb_pool_des *pool_dess,
 				size_t prs_len)
 {
+	/* Round down, unlike mlxsw_sp_bytes_cells(). */
+	u32 sb_cells = mlxsw_sp->sb->sb_size / mlxsw_sp->sb->cell_size;
+	u32 rest_cells[2] = {sb_cells, sb_cells};
 	int i;
 	int err;
 
+	/* Calculate how much space to give to the "REST" pools in either
+	 * direction.
+	 */
+	for (i = 0; i < prs_len; i++) {
+		enum mlxsw_reg_sbxx_dir dir = pool_dess[i].dir;
+		u32 size = prs[i].size;
+		u32 size_cells;
+
+		if (size == MLXSW_SP_SB_INFI || size == MLXSW_SP_SB_REST)
+			continue;
+
+		size_cells = mlxsw_sp_bytes_cells(mlxsw_sp, size);
+		if (WARN_ON_ONCE(size_cells > rest_cells[dir]))
+			continue;
+
+		rest_cells[dir] -= size_cells;
+	}
+
 	for (i = 0; i < prs_len; i++) {
 		u32 size = prs[i].size;
 		u32 size_cells;
@@ -483,6 +500,10 @@ static int mlxsw_sp_sb_prs_init(struct mlxsw_sp *mlxsw_sp,
 		if (size == MLXSW_SP_SB_INFI) {
 			err = mlxsw_sp_sb_pr_write(mlxsw_sp, i, prs[i].mode,
 						   0, true);
+		} else if (size == MLXSW_SP_SB_REST) {
+			size_cells = rest_cells[pool_dess[i].dir];
+			err = mlxsw_sp_sb_pr_write(mlxsw_sp, i, prs[i].mode,
+						   size_cells, false);
 		} else {
 			size_cells = mlxsw_sp_bytes_cells(mlxsw_sp, size);
 			err = mlxsw_sp_sb_pr_write(mlxsw_sp, i, prs[i].mode,
@@ -926,6 +947,7 @@ int mlxsw_sp_buffers_init(struct mlxsw_sp *mlxsw_sp)
 	if (err)
 		goto err_sb_ports_init;
 	err = mlxsw_sp_sb_prs_init(mlxsw_sp, mlxsw_sp->sb_vals->prs,
+				   mlxsw_sp->sb_vals->pool_dess,
 				   mlxsw_sp->sb_vals->pool_count);
 	if (err)
 		goto err_sb_prs_init;