天天看点

LeetCode5852之最小化目标值与所选元素的差(相关话题:回溯,记忆化搜索,剪枝,归并排序)

题目描述

给你一个大小为 

m x n

 的整数矩阵 

mat

 和一个整数 

target

 。

从矩阵的 每一行 中选择一个整数,你的目标是 最小化 所有选中元素之 和 与目标值 

target

 的 绝对差 。返回 最小的绝对差 。

a

 和 

b

 两数字的 绝对差 是 

a - b

 的绝对值。

示例 1:

LeetCode5852之最小化目标值与所选元素的差(相关话题:回溯,记忆化搜索,剪枝,归并排序)
输入:mat = [[1,2,3],[4,5,6],[7,8,9]], target = 13
输出:0
解释:一种可能的最优选择方案是:
- 第一行选出 1
- 第二行选出 5
- 第三行选出 7
所选元素的和是 13 ,等于目标值,所以绝对差是 0 。
           

示例 2:

LeetCode5852之最小化目标值与所选元素的差(相关话题:回溯,记忆化搜索,剪枝,归并排序)
输入:mat = [[1],[2],[3]], target = 100
输出:94
解释:唯一一种选择方案是:
- 第一行选出 1
- 第二行选出 2
- 第三行选出 3
所选元素的和是 6 ,绝对差是 94 。
           

示例 3:

LeetCode5852之最小化目标值与所选元素的差(相关话题:回溯,记忆化搜索,剪枝,归并排序)
输入:mat = [[1,2,9,8,7]], target = 6
输出:1
解释:最优的选择方案是选出第一行的 7 。
绝对差是 1 。
提示:           
  • m == mat.length

  • n == mat[i].length

  • 1 <= m, n <= 70

  • 1 <= mat[i][j] <= 70

  • 1 <= target <= 800

思路分析

这题第一眼看用回溯算法,但是会超时,因为不同行比如(3,4),(2,5)选择后剩余的值都为target-7,这里面存在大量重复的计算。所以引入记忆化搜索的技巧:因为

1 <= mat[i][j] <= 70,所以定义int[][] mem = new int[row][target + 70 * row + 1];

dfs函数index代表矩阵的行数,targer代表剩余的目标值

int result = Integer.MAX_VALUE;
	int[][] mem;

	public int minimizeTheDifference(int[][] mat, int target) {

		int row = mat.length;
		for (int i = 0; i < row; i++) {
			Arrays.sort(mat[i]);
		}
		mem = new int[row][target + 70 * row + 1];
		// 注意一点要有这句,,否者力扣编译器里无法识别mem的初始值0
		for (int i = 0; i < row; i++) {
			Arrays.fill(mem[i], -1);
		}
		result = backTrace(mat, 0, target);

		return result;

	}

	public int backTrace(int[][] mat, int index, int target) {

		int row = mat.length;
		int col = mat[0].length;

		if (index < row && mem[index][target + (70 * row)] != -1) {
			return mem[index][target + (70 * row)];
		}
		if (index == row) {
			return Math.abs(target);
		}

		for (int i = 0; i < col; i++) {

			int temp = backTrace(mat, index + 1, target - mat[index][i]);
			result = Math.min(result, temp);

		}
		mem[index][target + (70 * row)] = result;
		return result;
	}           

还有一个优化点在于剪枝

// 前面的数如果比target大,那么当前数肯定不是最优解可以直接跳过后面的循环判断(没这个剪枝其实也能通过)
if (i > 0 && target < mat[index][i - 1]) {
	break;
}           
// 注意一点要有这句,,否者力扣编译器里无法识别mem的初始值0
for (int i = 0; i < row; i++) {
	Arrays.fill(mem[i], -1);
}           

代码实现

int result = Integer.MAX_VALUE;
	int[][] mem;

	public int minimizeTheDifference(int[][] mat, int target) {

		int row = mat.length;
		for (int i = 0; i < row; i++) {
			Arrays.sort(mat[i]);
		}
		mem = new int[row][target + 70 * row + 1];
		// 注意一点要有这句,,否者力扣编译器里无法识别mem的初始值0
		for (int i = 0; i < row; i++) {
			Arrays.fill(mem[i], -1);
		}
		result = dfs(mat, 0, target);

		return result;

	}

	public int backTrace(int[][] mat, int index, int target) {

		int row = mat.length;
		int col = mat[0].length;

		if (index < row && mem[index][target + (70 * row)] != -1) {
			return mem[index][target + (70 * row)];
		}
		if (index == row) {
			return Math.abs(target);
		}

		for (int i = 0; i < col; i++) {

			// 前面的数如果比target大,那么当前数肯定不是最优解可以直接跳过后面的循环判断(没这个剪枝其实也能通过)
			if (i > 0 && target < mat[index][i - 1]) {
				break;
			}
			int temp = backTrace(mat, index + 1, target - mat[index][i]);
			result = Math.min(result, temp);

		}
		mem[index][target + (70 * row)] = result;
		return result;
	}           

思维拓展

class Solution {

     public int minimizeTheDifference(int[][] mat, int target) {
            int m=mat.length,n=mat[0].length;
            int i = 0, j = 0;
            for(i=0;i<m;i++){
                Arrays.sort(mat[i]);
            }
            int k=0;
            Set<Integer>[] sets=new HashSet[m];
            for(i=0;i<m;i++){
                sets[i]=new HashSet<>();
                for(j=0;j<n;j++){
                    sets[i].add(mat[i][j]);
                }
            }

            merge(sets,0,m-1);

            int min=Integer.MAX_VALUE,ans=0;
            for (Integer num : sets[0]) {
                if(Math.abs(num-target)<min){
                    min=Math.abs(num-target);
                    ans=num;
                }
            }

            //System.out.println(sets[0]);

            return min;
        }

        public void merge(Set<Integer>[] sets,int l,int r){
            if(l>=r){
                return;
            }
            int mid=l+r>>1;
            merge(sets,l,mid);
            merge(sets,mid+1,r);
            int m=sets[l].size(),n=sets[mid+1].size();
            int i=0,j=0;
            Set<Integer> temp=new HashSet<>();
            int[] a = new int[m];
            int[] b = new int[n];
            for (Integer num : sets[l]) {
                a[i++]=num;
            }
            for (Integer num : sets[mid + 1]) {
                b[j++]=num;
            }
            for(i=0;i<m;i++){
                for(j=0;j<n;j++){
                    temp.add(a[i]+b[j]);
                }
            }
            sets[l]=new HashSet<>(temp);
        }
 
}