天天看點

LeetCode5852之最小化目标值與所選元素的差(相關話題:回溯,記憶化搜尋,剪枝,歸并排序)

題目描述

給你一個大小為 

m x n

 的整數矩陣 

mat

 和一個整數 

target

 。

從矩陣的 每一行 中選擇一個整數,你的目标是 最小化 所有選中元素之 和 與目标值 

target

 的 絕對差 。傳回 最小的絕對差 。

a

 和 

b

 兩數字的 絕對差 是 

a - b

 的絕對值。

示例 1:

LeetCode5852之最小化目标值與所選元素的差(相關話題:回溯,記憶化搜尋,剪枝,歸并排序)
輸入:mat = [[1,2,3],[4,5,6],[7,8,9]], target = 13
輸出:0
解釋:一種可能的最優選擇方案是:
- 第一行選出 1
- 第二行選出 5
- 第三行選出 7
所選元素的和是 13 ,等于目标值,是以絕對差是 0 。
           

示例 2:

LeetCode5852之最小化目标值與所選元素的差(相關話題:回溯,記憶化搜尋,剪枝,歸并排序)
輸入:mat = [[1],[2],[3]], target = 100
輸出:94
解釋:唯一一種選擇方案是:
- 第一行選出 1
- 第二行選出 2
- 第三行選出 3
所選元素的和是 6 ,絕對差是 94 。
           

示例 3:

LeetCode5852之最小化目标值與所選元素的差(相關話題:回溯,記憶化搜尋,剪枝,歸并排序)
輸入:mat = [[1,2,9,8,7]], target = 6
輸出:1
解釋:最優的選擇方案是選出第一行的 7 。
絕對差是 1 。
提示:           
  • m == mat.length

  • n == mat[i].length

  • 1 <= m, n <= 70

  • 1 <= mat[i][j] <= 70

  • 1 <= target <= 800

思路分析

這題第一眼看用回溯算法,但是會逾時,因為不同行比如(3,4),(2,5)選擇後剩餘的值都為target-7,這裡面存在大量重複的計算。是以引入記憶化搜尋的技巧:因為

1 <= mat[i][j] <= 70,是以定義int[][] mem = new int[row][target + 70 * row + 1];

dfs函數index代表矩陣的行數,targer代表剩餘的目标值

int result = Integer.MAX_VALUE;
	int[][] mem;

	public int minimizeTheDifference(int[][] mat, int target) {

		int row = mat.length;
		for (int i = 0; i < row; i++) {
			Arrays.sort(mat[i]);
		}
		mem = new int[row][target + 70 * row + 1];
		// 注意一點要有這句,,否者力扣編譯器裡無法識别mem的初始值0
		for (int i = 0; i < row; i++) {
			Arrays.fill(mem[i], -1);
		}
		result = backTrace(mat, 0, target);

		return result;

	}

	public int backTrace(int[][] mat, int index, int target) {

		int row = mat.length;
		int col = mat[0].length;

		if (index < row && mem[index][target + (70 * row)] != -1) {
			return mem[index][target + (70 * row)];
		}
		if (index == row) {
			return Math.abs(target);
		}

		for (int i = 0; i < col; i++) {

			int temp = backTrace(mat, index + 1, target - mat[index][i]);
			result = Math.min(result, temp);

		}
		mem[index][target + (70 * row)] = result;
		return result;
	}           

還有一個優化點在于剪枝

// 前面的數如果比target大,那麼目前數肯定不是最優解可以直接跳過後面的循環判斷(沒這個剪枝其實也能通過)
if (i > 0 && target < mat[index][i - 1]) {
	break;
}           
// 注意一點要有這句,,否者力扣編譯器裡無法識别mem的初始值0
for (int i = 0; i < row; i++) {
	Arrays.fill(mem[i], -1);
}           

代碼實作

int result = Integer.MAX_VALUE;
	int[][] mem;

	public int minimizeTheDifference(int[][] mat, int target) {

		int row = mat.length;
		for (int i = 0; i < row; i++) {
			Arrays.sort(mat[i]);
		}
		mem = new int[row][target + 70 * row + 1];
		// 注意一點要有這句,,否者力扣編譯器裡無法識别mem的初始值0
		for (int i = 0; i < row; i++) {
			Arrays.fill(mem[i], -1);
		}
		result = dfs(mat, 0, target);

		return result;

	}

	public int backTrace(int[][] mat, int index, int target) {

		int row = mat.length;
		int col = mat[0].length;

		if (index < row && mem[index][target + (70 * row)] != -1) {
			return mem[index][target + (70 * row)];
		}
		if (index == row) {
			return Math.abs(target);
		}

		for (int i = 0; i < col; i++) {

			// 前面的數如果比target大,那麼目前數肯定不是最優解可以直接跳過後面的循環判斷(沒這個剪枝其實也能通過)
			if (i > 0 && target < mat[index][i - 1]) {
				break;
			}
			int temp = backTrace(mat, index + 1, target - mat[index][i]);
			result = Math.min(result, temp);

		}
		mem[index][target + (70 * row)] = result;
		return result;
	}           

思維拓展

class Solution {

     public int minimizeTheDifference(int[][] mat, int target) {
            int m=mat.length,n=mat[0].length;
            int i = 0, j = 0;
            for(i=0;i<m;i++){
                Arrays.sort(mat[i]);
            }
            int k=0;
            Set<Integer>[] sets=new HashSet[m];
            for(i=0;i<m;i++){
                sets[i]=new HashSet<>();
                for(j=0;j<n;j++){
                    sets[i].add(mat[i][j]);
                }
            }

            merge(sets,0,m-1);

            int min=Integer.MAX_VALUE,ans=0;
            for (Integer num : sets[0]) {
                if(Math.abs(num-target)<min){
                    min=Math.abs(num-target);
                    ans=num;
                }
            }

            //System.out.println(sets[0]);

            return min;
        }

        public void merge(Set<Integer>[] sets,int l,int r){
            if(l>=r){
                return;
            }
            int mid=l+r>>1;
            merge(sets,l,mid);
            merge(sets,mid+1,r);
            int m=sets[l].size(),n=sets[mid+1].size();
            int i=0,j=0;
            Set<Integer> temp=new HashSet<>();
            int[] a = new int[m];
            int[] b = new int[n];
            for (Integer num : sets[l]) {
                a[i++]=num;
            }
            for (Integer num : sets[mid + 1]) {
                b[j++]=num;
            }
            for(i=0;i<m;i++){
                for(j=0;j<n;j++){
                    temp.add(a[i]+b[j]);
                }
            }
            sets[l]=new HashSet<>(temp);
        }
 
}